Skip to content

Commit 795c087

Browse files
Merge pull request #73 from opendatahub-io/main
Sync odh/main with odh/release
2 parents daaa6b6 + 102f77d commit 795c087

File tree

23 files changed

+3347
-1197
lines changed

23 files changed

+3347
-1197
lines changed

Cargo.lock

Lines changed: 288 additions & 203 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Dockerfile

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Global Args #################################################################
22
ARG BASE_UBI_IMAGE_TAG=9.3-1610
3-
ARG PROTOC_VERSION=25.2
3+
ARG PROTOC_VERSION=25.3
44
ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
55
# ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
66
ARG AUTO_GPTQ_VERSION=0.7.1
@@ -86,7 +86,7 @@ ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"
8686

8787
## Rust builder ################################################################
8888
# Specific debian version so that compatible glibc version is used
89-
FROM rust:1.77-bullseye as rust-builder
89+
FROM rust:1.77.2-bullseye as rust-builder
9090
ARG PROTOC_VERSION
9191

9292
ENV CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -164,6 +164,9 @@ RUN cd server && \
164164
make gen-server && \
165165
pip install ".[accelerate]" --no-cache-dir
166166

167+
# temp: install newer transformers lib that optimum clashes with
168+
RUN pip install transformers==4.40.0 tokenizers==0.19.1 --no-cache-dir
169+
167170
# Patch codegen model changes into transformers
168171
RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py
169172

@@ -288,6 +291,9 @@ COPY server server
288291
# Ref: https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x
289292
RUN cd server && make gen-server && pip install ".[accelerate, ibm-fms, onnx-gpu, quantize]" --no-cache-dir --extra-index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
290293

294+
# temp: install newer transformers lib that optimum clashes with
295+
RUN pip install transformers==4.40.0 tokenizers==0.19.1 --no-cache-dir
296+
291297
# Patch codegen model changes into transformers 4.35
292298
RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py
293299

integration_tests/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
gen-client:
22
# Compile protos
3-
pip install grpcio-tools==1.60.0 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
3+
pip install grpcio-tools==1.62.2 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
44
mkdir text_generation_tests/pb || true
55
python -m grpc_tools.protoc -I../proto --python_out=text_generation_tests/pb \
66
--grpc_python_out=text_generation_tests/pb --mypy_out=text_generation_tests/pb ../proto/generation.proto

integration_tests/poetry.lock

Lines changed: 303 additions & 308 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

integration_tests/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ python = ">=3.11"
99

1010
[tool.poetry.group.dev.dependencies]
1111
protobuf = "^4.25.3"
12-
grpcio-tools = "^1.62.1"
12+
grpcio-tools = "^1.62.2"
1313
pytest = "^8.1.1"
1414
pytest-asyncio = "^0.23.6"
1515
requests = "^2.31.0"

integration_tests/sample_client.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import time
2+
import grpc
3+
from google.protobuf import json_format
4+
from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2
5+
6+
7+
def get_streaming_response_tgis(response):
8+
stop = False
9+
generated_tokens = 0
10+
while not stop:
11+
try:
12+
x = next(response)
13+
timestamp = time.time_ns()
14+
data = json_format.MessageToDict(x)
15+
# skip first response (tokenizer output only)
16+
if "inputTokenCount" not in data:
17+
n_tokens = data["generatedTokenCount"] - generated_tokens
18+
generated_tokens = data["generatedTokenCount"]
19+
yield data, n_tokens, timestamp, True, None
20+
except Exception as e:
21+
timestamp = time.time_ns()
22+
yield None, 0, timestamp, False, e
23+
24+
25+
channel = grpc.insecure_channel("localhost:8033")
26+
stub = gpb2.GenerationServiceStub(channel)
27+
max_new_tokens = 100
28+
29+
template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
30+
num_req = 0
31+
while True:
32+
prompt_input = input(f"\n{num_req}) Enter a prompt:\n")
33+
34+
print("-" * 40)
35+
print("Output:")
36+
prompt = template.format(prompt_input)
37+
sample_request = {
38+
"model_id": "dummy-model-name",
39+
"request": {"text": prompt},
40+
"params": {
41+
"method": "GREEDY",
42+
"stopping": {
43+
"max_new_tokens": max_new_tokens,
44+
"min_new_tokens": max_new_tokens,
45+
},
46+
},
47+
}
48+
message = json_format.ParseDict(sample_request, pb2.SingleGenerationRequest())
49+
output = []
50+
total_time = 0
51+
response = stub.GenerateStream(message)
52+
response_generator = get_streaming_response_tgis(response)
53+
t0 = time.time_ns()
54+
response = ""
55+
stop = False
56+
while not stop:
57+
r, n_tokens, t, ok, err = next(response_generator)
58+
59+
if not ok:
60+
stop = True
61+
# check if we have reached end of stream
62+
if type(err) is StopIteration:
63+
continue
64+
duration = (t - t0) / 1000.0 / 1000.0
65+
record = {
66+
"response": r,
67+
"ok": ok,
68+
"error": str(err),
69+
"timestamp": t,
70+
"duration_ms": duration,
71+
"n_tokens": n_tokens,
72+
}
73+
total_time += duration
74+
response += r["text"]
75+
output.append(record)
76+
t0 = t
77+
78+
# print(json.dumps(output, indent=4))
79+
print("-" * 40)
80+
print(response)
81+
print("-" * 40)
82+
print(f"Total_time : {total_time}ms")
83+
print(f"Time_per_token : {total_time/max_new_tokens}ms")
84+
print("-" * 40)
85+
num_req += 1

launcher/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ authors = ["Olivier Dehaene"]
66
description = "Text Generation Launcher"
77

88
[dependencies]
9-
clap = { version = "4.5.3", features = ["derive", "env"] }
9+
clap = { version = "4.5.4", features = ["derive", "env"] }
1010
ctrlc = { version = "3.4.4", features = ["termination"] }
1111
nix = { version = "0.28.0", features = ["process", "signal"] }
12-
serde_json = "^1.0.114"
12+
serde_json = "^1.0.11"
1313
tracing = "0.1.40"
1414
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
1515
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }

router/Cargo.toml

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,34 @@ path = "src/main.rs"
1717
axum = { version = "0.6.20", features = ["json"] }
1818
axum-tracing-opentelemetry = "0.10.0"
1919
text-generation-client = { path = "client" }
20-
clap = { version = "^4.5.2", features = ["derive", "env"] }
20+
clap = { version = "^4.5.4", features = ["derive", "env"] }
2121
futures = "^0.3.30"
2222
flume = "^0.11.0"
2323
metrics = "0.21.1"
2424
metrics-exporter-prometheus = { version = "0.12.2", features = [] }
25-
moka = { version = "0.12.5", features = ["future"] }
25+
moka = { version = "0.12.6", features = ["future"] }
2626
nohash-hasher = "^0.2.0"
27-
num = "^0.4.1"
27+
num = "^0.4.2"
2828
num_cpus = "^1.16.0"
2929
hyper = "^0.14.28" # Override to address CVE-2023-26964
30+
h2 = "^0.3.26 " # Override to address CVEs
3031
openssl = "^0.10.64" # Override to address WS-2023-0082, WS-2023-0083, WS-2023-0195
31-
openssl-sys = "^0.9.101" # Override to address WS-2023-0082, WS-2023-0083, WS-2023-0195
32+
openssl-sys = "^0.9.102" # Override to address WS-2023-0082, WS-2023-0083, WS-2023-0195
3233
rustls-webpki = "0.102.2" # Override to address WS-2023-0305, CVE-2018-16875
3334
rand = "^0.8.5"
34-
serde = "^1.0.197"
35-
serde_json = "^1.0.114"
35+
serde = "^1.0.198"
36+
serde_json = "^1.0.116"
3637
thiserror = "^1.0.57"
37-
tokenizers = "0.15.2"
38-
tokio = { version = "1.36.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
39-
tokio-rustls = "^0.25.0"
40-
rustls = "0.22.2"
38+
tokenizers = "0.19.1"
39+
tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
40+
tokio-rustls = "^0.26.0"
41+
rustls = "0.22.4"
4142
tracing = "^0.1.40"
42-
prost = "^0.12.3"
43+
prost = "^0.12.4"
4344
tonic = { version = "^0.11.0", features = ["tls"] }
4445
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
4546
tracing-opentelemetry = "0.23.0"
46-
tokio-stream ="^0.1.14"
47+
tokio-stream ="^0.1.15"
4748
unicode-segmentation = "^1.11.0"
4849
unicode-truncate = "^0.2.0"
4950
opentelemetry = "0.22.0"

router/client/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ build="build.rs"
66

77
[dependencies]
88
futures = "^0.3.30"
9-
prost = "^0.12.3"
9+
prost = "^0.12.4"
1010
thiserror = "^1.0.58"
11-
tokio = { version = "1.36.0", features = ["sync"] }
11+
tokio = { version = "1.37.0", features = ["sync"] }
1212
tonic = "^0.11.0"
1313
tower = "^0.4.13"
1414
tracing = "^0.1.40"

router/src/batcher.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,10 +839,19 @@ impl<'a> TokenProcessor<'a> {
839839
let request_id = output.request_id;
840840
let next_token_id = output.token_id;
841841

842-
let e = self
843-
.entries
844-
.get_mut(&request_id)
845-
.expect("ID not found. This is a bug.");
842+
let e = self.entries.get_mut(&request_id);
843+
844+
// if a client cancelled a request and speculative decoding is
845+
// enabled, it's possible that the request will get removed
846+
// from entries table, but there can still be tokens in outputs stream
847+
// corresponding to that request. ideally we could defer removing
848+
// the request_id from the entries table until all tokens have been
849+
// processed...but for now let's just ignore them.
850+
if e.is_none() {
851+
continue;
852+
}
853+
854+
let e = e.unwrap();
846855

847856
let is_stream = e.stream_tx.is_some();
848857
let stop_seqs = &e.request.parameters.stop_seqs;

0 commit comments

Comments
 (0)