
Commit 102f77d

Merge pull request #72 from IBM/main
[pull] main from IBM:main
Parents: abca422 + 0b7a2db · Commit: 102f77d

20 files changed: 2695 additions & 658 deletions


Cargo.lock

Lines changed: 7 additions & 6 deletions
Some generated files are not rendered by default.

Dockerfile

Lines changed: 6 additions & 0 deletions
@@ -164,6 +164,9 @@ RUN cd server && \
     make gen-server && \
     pip install ".[accelerate]" --no-cache-dir

+# temp: install newer transformers lib that optimum clashes with
+RUN pip install transformers==4.40.0 tokenizers==0.19.1 --no-cache-dir
+
 # Patch codegen model changes into transformers
 RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py

@@ -288,6 +291,9 @@ COPY server server
 # Ref: https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x
 RUN cd server && make gen-server && pip install ".[accelerate, ibm-fms, onnx-gpu, quantize]" --no-cache-dir --extra-index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

+# temp: install newer transformers lib that optimum clashes with
+RUN pip install transformers==4.40.0 tokenizers==0.19.1 --no-cache-dir
+
 # Patch codegen model changes into transformers 4.35
 RUN cp server/transformers_patch/modeling_codegen.py ${SITE_PACKAGES}/transformers/models/codegen/modeling_codegen.py
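
Because optimum can pull in its own transformers/tokenizers pins during dependency resolution, a quick post-install check can confirm that the temporary override above actually sticks in the built image. The snippet below is an illustrative sketch, not part of the commit; the expected version strings are taken from the RUN lines above.

# Illustrative sanity check (not part of this commit): verify that the
# temporary pins from the Dockerfile survived dependency resolution.
import tokenizers
import transformers

assert transformers.__version__ == "4.40.0", transformers.__version__
assert tokenizers.__version__ == "0.19.1", tokenizers.__version__
print("transformers", transformers.__version__, "/ tokenizers", tokenizers.__version__)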

integration_tests/Makefile

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 gen-client:
	# Compile protos
-	pip install grpcio-tools==1.62.1 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
+	pip install grpcio-tools==1.62.2 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
	mkdir text_generation_tests/pb || true
	python -m grpc_tools.protoc -I../proto --python_out=text_generation_tests/pb \
		--grpc_python_out=text_generation_tests/pb --mypy_out=text_generation_tests/pb ../proto/generation.proto

integration_tests/poetry.lock

Lines changed: 138 additions & 122 deletions
Some generated files are not rendered by default.

integration_tests/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ python = ">=3.11"

 [tool.poetry.group.dev.dependencies]
 protobuf = "^4.25.3"
-grpcio-tools = "^1.62.1"
+grpcio-tools = "^1.62.2"
 pytest = "^8.1.1"
 pytest-asyncio = "^0.23.6"
 requests = "^2.31.0"

integration_tests/sample_client.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+import time
+import grpc
+from google.protobuf import json_format
+from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2
+
+
+def get_streaming_response_tgis(response):
+    stop = False
+    generated_tokens = 0
+    while not stop:
+        try:
+            x = next(response)
+            timestamp = time.time_ns()
+            data = json_format.MessageToDict(x)
+            # skip first response (tokenizer output only)
+            if "inputTokenCount" not in data:
+                n_tokens = data["generatedTokenCount"] - generated_tokens
+                generated_tokens = data["generatedTokenCount"]
+                yield data, n_tokens, timestamp, True, None
+        except Exception as e:
+            timestamp = time.time_ns()
+            yield None, 0, timestamp, False, e
+
+
+channel = grpc.insecure_channel("localhost:8033")
+stub = gpb2.GenerationServiceStub(channel)
+max_new_tokens = 100
+
+template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
+num_req = 0
+while True:
+    prompt_input = input(f"\n{num_req}) Enter a prompt:\n")
+
+    print("-" * 40)
+    print("Output:")
+    prompt = template.format(prompt_input)
+    sample_request = {
+        "model_id": "dummy-model-name",
+        "request": {"text": prompt},
+        "params": {
+            "method": "GREEDY",
+            "stopping": {
+                "max_new_tokens": max_new_tokens,
+                "min_new_tokens": max_new_tokens,
+            },
+        },
+    }
+    message = json_format.ParseDict(sample_request, pb2.SingleGenerationRequest())
+    output = []
+    total_time = 0
+    response = stub.GenerateStream(message)
+    response_generator = get_streaming_response_tgis(response)
+    t0 = time.time_ns()
+    response = ""
+    stop = False
+    while not stop:
+        r, n_tokens, t, ok, err = next(response_generator)
+
+        if not ok:
+            stop = True
+            # check if we have reached end of stream
+            if type(err) is StopIteration:
+                continue
+        duration = (t - t0) / 1000.0 / 1000.0
+        record = {
+            "response": r,
+            "ok": ok,
+            "error": str(err),
+            "timestamp": t,
+            "duration_ms": duration,
+            "n_tokens": n_tokens,
+        }
+        total_time += duration
+        response += r["text"]
+        output.append(record)
+        t0 = t
+
+    # print(json.dumps(output, indent=4))
+    print("-" * 40)
+    print(response)
+    print("-" * 40)
+    print(f"Total_time : {total_time}ms")
+    print(f"Time_per_token : {total_time/max_new_tokens}ms")
+    print("-" * 40)
+    num_req += 1
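
For a quick smoke test without the interactive loop, the same generated stubs can be driven with a fixed prompt. The sketch below is built only from the pieces shown in the sample above (GenerationServiceStub, GenerateStream, SingleGenerationRequest); the endpoint, model_id, and prompt are placeholder assumptions, not values mandated by the commit.

# Minimal, non-interactive variant of the sample client above (sketch only).
import grpc
from google.protobuf import json_format
from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2

channel = grpc.insecure_channel("localhost:8033")  # assumed TGIS gRPC endpoint
stub = gpb2.GenerationServiceStub(channel)

request = json_format.ParseDict(
    {
        "model_id": "dummy-model-name",  # placeholder, as in the sample above
        "request": {"text": "Hello"},
        "params": {"method": "GREEDY", "stopping": {"max_new_tokens": 20}},
    },
    pb2.SingleGenerationRequest(),
)

text = ""
for msg in stub.GenerateStream(request, timeout=60):
    data = json_format.MessageToDict(msg)
    # The first streamed message only reports the input token count;
    # subsequent messages carry generated text.
    text += data.get("text", "")
print(text)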

router/Cargo.toml

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@ nohash-hasher = "^0.2.0"
 num = "^0.4.2"
 num_cpus = "^1.16.0"
 hyper = "^0.14.28" # Override to address CVE-2023-26964
+h2 = "^0.3.26" # Override to address CVEs
 openssl = "^0.10.64" # Override to address WS-2023-0082, WS-2023-0083, WS-2023-0195
 openssl-sys = "^0.9.102" # Override to address WS-2023-0082, WS-2023-0083, WS-2023-0195
 rustls-webpki = "0.102.2" # Override to address WS-2023-0305, CVE-2018-16875
@@ -37,7 +38,7 @@ thiserror = "^1.0.57"
 tokenizers = "0.19.1"
 tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
 tokio-rustls = "^0.26.0"
-rustls = "0.22.3"
+rustls = "0.22.4"
 tracing = "^0.1.40"
 prost = "^0.12.4"
 tonic = { version = "^0.11.0", features = ["tls"] }

router/src/batcher.rs

Lines changed: 13 additions & 4 deletions
@@ -839,10 +839,19 @@ impl<'a> TokenProcessor<'a> {
             let request_id = output.request_id;
             let next_token_id = output.token_id;

-            let e = self
-                .entries
-                .get_mut(&request_id)
-                .expect("ID not found. This is a bug.");
+            let e = self.entries.get_mut(&request_id);
+
+            // if a client cancelled a request and speculative decoding is
+            // enabled, it's possible that the request will get removed
+            // from entries table, but there can still be tokens in outputs stream
+            // corresponding to that request. ideally we could defer removing
+            // the request_id from the entries table until all tokens have been
+            // processed...but for now let's just ignore them.
+            if e.is_none() {
+                continue;
+            }
+
+            let e = e.unwrap();

             let is_stream = e.stream_tx.is_some();
             let stop_seqs = &e.request.parameters.stop_seqs;
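
The comment in the hunk describes the race this change tolerates: after a client cancellation the entry is removed from the table while speculative decoding can still emit a few tokens for that request id, so those late outputs are now skipped instead of tripping the old expect(). A small sketch of the same "look up, and silently drop if missing" pattern, written in Python purely for illustration (this is not the router's Rust code):

# Illustrative sketch: ignore stream outputs whose request entry has already
# been removed, e.g. after a client cancelled the request while speculative
# decoding was still producing tokens for it.
entries = {}  # request_id -> per-request state


def process_outputs(outputs):
    for output in outputs:
        entry = entries.get(output["request_id"])
        if entry is None:
            # Entry already removed (cancelled request); drop the stray token
            # instead of treating the missing id as a bug.
            continue
        # ... handle the token for this entry ...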

server/Makefile

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ all: install run-dev
 .PHONY: gen-server
 gen-server:
	# Compile protos
-	pip install grpcio-tools==1.62.1 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
+	pip install grpcio-tools==1.62.2 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
	mkdir -p text_generation_server/pb
	python -m grpc_tools.protoc -I../proto \
		--python_out=text_generation_server/pb \
