Skip to content

Commit a39b8a8

Browse files
committed
feat: Basic script to test batching integrity of a deployed model
This runs a series of tests to ensure consistency of output when the same input is included in a (padded) batch, as well as when batches are modified via pruning and concatenation operations while requests are in progress.
1 parent 3e34359 commit a39b8a8

File tree

3 files changed

+171
-0
lines changed

3 files changed

+171
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
generation_pb2.py
2+
generation_pb2_grpc.py
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
## Batching Integrity Verification
2+
3+
This script can be used to help verify correctness of models running in TGIS with dynamic/continuous batching.
4+
5+
Spurious mismatches (false positives) are still possible, since some amount of inconsistency is expected due to fixed-precision floating point operations — for example in float16, and especially bfloat16.
6+
7+
8+
First compile protobuf stubs for the external API:
9+
```
10+
python -m grpc_tools.protoc -I../../proto --python_out=. --grpc_python_out=. generation.proto
11+
```
12+
13+
Then run the script; it currently assumes TGIS is running locally or port-forwarded on port 8033.
14+
```
15+
python batching_integrity_checks.py
16+
```
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
3+
This script can be used to help verify correctness of models running in TGIS
4+
with dynamic batching.
5+
6+
False negatives are still possible since there is some amount of inconsistency expected due to
7+
the fixed precision floating point operations, for example in float16 and especially bfloat16.
8+
9+
"""
10+
import concurrent.futures
11+
import time
12+
13+
import grpc
14+
import generation_pb2
15+
import generation_pb2_grpc
16+
17+
18+
class col:
    """ANSI terminal escape sequences used to colorize PASS/FAIL output."""
    RED = '\033[31m'
    ENDC = '\033[m'  # reset to default color
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
24+
25+
26+
if __name__ == "__main__":
    # Connect to a locally running / port-forwarded TGIS instance.
    grpc_channel = grpc.insecure_channel("localhost:8033")
    stub = generation_pb2_grpc.GenerationServiceStub(grpc_channel)

    # Test plan:
    # 1) Common input + output length batch
    #    -- Based on this determine unique stop sequence
    # 2) Variable input length, common output length batch
    # 3) Common input length, variable output length batch (using stop sequence from 1)
    # 4) Concatenation test, long, with short concurrently after tiny delay

    model_id = "unused"

    out_length = 80  # should be at least 50
    long_text = (
        "The core components include: a studio for new foundation models, generative AI and machine learning; "
        "a fit-for-purpose data store built on an open data lakehouse architecture; and a toolkit, to accelerate "
        "AI workflows that are built with responsibility, transparency and explainability."
    )
    short_text = "watsonx is a generative AI and data platform with a set of AI assistants."

    # Get token counts for inputs
    tokenize_resp = stub.Tokenize(generation_pb2.BatchedTokenizeRequest(
        model_id=model_id,
        requests=[
            generation_pb2.TokenizeRequest(text=text)
            for text in (long_text, short_text)
        ],
        return_tokens=False,
    ))

    long_tokens = tokenize_resp.responses[0].token_count
    short_tokens = tokenize_resp.responses[1].token_count

    print(f"token counts: short={short_tokens}, long={long_tokens}")

    # Sanity check: the short prompt must be non-trivial and shorter than the long one.
    assert 15 < short_tokens < long_tokens

    # Common truncation length shorter than either prompt, used by tests 1 and 3.
    truncate_to = short_tokens - 10
60+
61+
def send_requests(requests, truncate_to=0, stop_seqs=None, out_length=out_length, min_out_length=None):
62+
if min_out_length is None:
63+
min_out_length = out_length
64+
65+
return stub.Generate(generation_pb2.BatchedGenerationRequest(
66+
model_id=model_id,
67+
requests=requests,
68+
params=generation_pb2.Parameters(
69+
truncate_input_tokens=truncate_to,
70+
stopping=generation_pb2.StoppingCriteria(
71+
min_new_tokens=min_out_length,
72+
max_new_tokens=out_length,
73+
stop_sequences=stop_seqs,
74+
),
75+
),
76+
))
77+
78+
def log_result(batched_results, individual_results):
79+
matches = [x == y for (x, y) in zip(batched_results, individual_results)]
80+
if all(matches):
81+
print(col.GREEN + "PASS" + col.ENDC)
82+
else:
83+
print(f"{col.RED}FAIL{col.ENDC}: {matches.count(False)}/{len(matches)} mismatches: {matches}")
84+
#TODO improve how these diffs are printed
85+
print(f"BATCHED: {batched_results}")
86+
print(f"SINGLE : {individual_results}")
87+
88+
89+
#TODO first do single request consistency
90+
91+
#TODO can add a test here to check invariance to front-padding
92+
93+
94+
greqs = [generation_pb2.GenerationRequest(text=text) for text in (long_text, short_text)]
95+
96+
97+
### Test 1 #################################################################################
98+
############################################################################################
99+
test_name = "Common input and output sizes (tests basic batching)"
100+
print(f"\n{test_name}")
101+
batched_1 = [gr.text for gr in send_requests(greqs, truncate_to).responses]
102+
individual_1 = [send_requests([gr], truncate_to).responses[0].text for gr in greqs]
103+
log_result(batched_1, individual_1)
104+
105+
106+
### Test 2 #################################################################################
107+
############################################################################################
108+
test_name = "Variable input lengths, common output length (tests padded batch)"
109+
print(f"\n{test_name}")
110+
batched_2 = [gr.text for gr in send_requests(greqs).responses]
111+
individual_2 = [send_requests([gr]).responses[0].text for gr in greqs]
112+
log_result(batched_2, individual_2)
113+
114+
115+
### Test 3 #################################################################################
116+
############################################################################################
117+
# Find stop seq - a string in one of the output sequences that isn't in the other
118+
ss = None
119+
for i in range(20, out_length - 5):
120+
ss = batched_1[0][i:i+5]
121+
if ss not in batched_1[1]:
122+
break
123+
124+
if ss is None:
125+
print("\ncouldn't find stop seq to use for variable output length test")
126+
else:
127+
test_name = "Common input length, variable output lengths (tests batch pruning)"
128+
print(f"\n{test_name}")
129+
print(f"Using stop-sequence '{ss}'")
130+
batched_3 = [gr.text for gr in send_requests(greqs, truncate_to, [ss], min_out_length=1).responses]
131+
individual_3 = [
132+
send_requests([gr], truncate_to, [ss], min_out_length=1).responses[0].text for gr in greqs
133+
]
134+
log_result(batched_3, individual_3)
135+
136+
137+
### Test 4 #################################################################################
138+
############################################################################################
139+
test_name = "Short output interrupting long output (tests batch concatenation)"
140+
print(f"\n{test_name}")
141+
long_input_req, short_input_req = greqs
142+
individual_4 = [
143+
send_requests([short_input_req], out_length=100).responses[0].text,
144+
send_requests([long_input_req], out_length=10).responses[0].text,
145+
]
146+
147+
with concurrent.futures.ThreadPoolExecutor(1) as executor:
148+
future = executor.submit(send_requests, [short_input_req], out_length=100)
149+
time.sleep(0.25)
150+
short_resp = send_requests([long_input_req], out_length=10).responses[0]
151+
long_resp = future.result().responses[0]
152+
batched_4 = [long_resp.text, short_resp.text]
153+
log_result(batched_4, individual_4)

0 commit comments

Comments
 (0)