"""

This script can be used to help verify correctness of models running in TGIS
with dynamic batching.

False negatives are still possible since there is some amount of inconsistency expected due to
the fixed precision floating point operations, for example in float16 and especially bfloat16.

"""
| 10 | +import concurrent.futures |
| 11 | +import time |
| 12 | + |
| 13 | +import grpc |
| 14 | +import generation_pb2 |
| 15 | +import generation_pb2_grpc |
| 16 | + |
| 17 | + |
class col:
    """ANSI escape sequences used to colorize terminal output."""

    # Reset sequence that ends any active color
    ENDC = '\033[m'

    # Foreground colors
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
| 24 | + |
| 25 | + |
if __name__ == "__main__":
    # Unencrypted gRPC channel to the server under test
    channel = grpc.insecure_channel("localhost:8033")
    stub = generation_pb2_grpc.GenerationServiceStub(channel)

    # Test plan:
    # 1) Common input + output length batch
    #    -- Based on this determine unique stop sequence
    # 2) Variable input length, common output length batch
    # 3) Common input length, variable output length batch (using stop sequence from 1)
    # 4) Concatenation test, long, with short concurrently after tiny delay

    model_id = "unused"

    out_length = 80  # should be at least 50
    long_text = (
        "The core components include: a studio for new foundation models, generative AI and machine learning; "
        "a fit-for-purpose data store built on an open data lakehouse architecture; and a toolkit, to accelerate "
        "AI workflows that are built with responsibility, transparency and explainability."
    )
    short_text = "watsonx is a generative AI and data platform with a set of AI assistants."

    # Get token counts for inputs (long text first, short text second)
    tokenize_request = generation_pb2.BatchedTokenizeRequest(
        model_id=model_id,
        requests=[generation_pb2.TokenizeRequest(text=text) for text in (long_text, short_text)],
        return_tokens=False,
    )
    tresp = stub.Tokenize(tokenize_request)

    long_tokens, short_tokens = (tr.token_count for tr in tresp.responses)

    print(f"token counts: short={short_tokens}, long={long_tokens}")

    # Explicit check rather than `assert`, so the guard still runs under `python -O`.
    # Both inputs are later truncated to `short_tokens - 10` tokens, so the short
    # prompt must have more than 15 tokens and be shorter than the long prompt.
    if not 15 < short_tokens < long_tokens:
        raise ValueError(
            "unsuitable input token counts for this model's tokenizer: "
            f"need 15 < short ({short_tokens}) < long ({long_tokens})"
        )

    # Common length used by the tests that need equal-length inputs
    truncate_to = short_tokens - 10
| 60 | + |
| 61 | + def send_requests(requests, truncate_to=0, stop_seqs=None, out_length=out_length, min_out_length=None): |
| 62 | + if min_out_length is None: |
| 63 | + min_out_length = out_length |
| 64 | + |
| 65 | + return stub.Generate(generation_pb2.BatchedGenerationRequest( |
| 66 | + model_id=model_id, |
| 67 | + requests=requests, |
| 68 | + params=generation_pb2.Parameters( |
| 69 | + truncate_input_tokens=truncate_to, |
| 70 | + stopping=generation_pb2.StoppingCriteria( |
| 71 | + min_new_tokens=min_out_length, |
| 72 | + max_new_tokens=out_length, |
| 73 | + stop_sequences=stop_seqs, |
| 74 | + ), |
| 75 | + ), |
| 76 | + )) |
| 77 | + |
| 78 | + def log_result(batched_results, individual_results): |
| 79 | + matches = [x == y for (x, y) in zip(batched_results, individual_results)] |
| 80 | + if all(matches): |
| 81 | + print(col.GREEN + "PASS" + col.ENDC) |
| 82 | + else: |
| 83 | + print(f"{col.RED}FAIL{col.ENDC}: {matches.count(False)}/{len(matches)} mismatches: {matches}") |
| 84 | + #TODO improve how these diffs are printed |
| 85 | + print(f"BATCHED: {batched_results}") |
| 86 | + print(f"SINGLE : {individual_results}") |
| 87 | + |
| 88 | + |
| 89 | + #TODO first do single request consistency |
| 90 | + |
| 91 | + #TODO can add a test here to check invariance to front-padding |
| 92 | + |
| 93 | + |
| 94 | + greqs = [generation_pb2.GenerationRequest(text=text) for text in (long_text, short_text)] |
| 95 | + |
| 96 | + |
| 97 | + ### Test 1 ################################################################################# |
| 98 | + ############################################################################################ |
| 99 | + test_name = "Common input and output sizes (tests basic batching)" |
| 100 | + print(f"\n{test_name}") |
| 101 | + batched_1 = [gr.text for gr in send_requests(greqs, truncate_to).responses] |
| 102 | + individual_1 = [send_requests([gr], truncate_to).responses[0].text for gr in greqs] |
| 103 | + log_result(batched_1, individual_1) |
| 104 | + |
| 105 | + |
| 106 | + ### Test 2 ################################################################################# |
| 107 | + ############################################################################################ |
| 108 | + test_name = "Variable input lengths, common output length (tests padded batch)" |
| 109 | + print(f"\n{test_name}") |
| 110 | + batched_2 = [gr.text for gr in send_requests(greqs).responses] |
| 111 | + individual_2 = [send_requests([gr]).responses[0].text for gr in greqs] |
| 112 | + log_result(batched_2, individual_2) |
| 113 | + |
| 114 | + |
| 115 | + ### Test 3 ################################################################################# |
| 116 | + ############################################################################################ |
| 117 | + # Find stop seq - a string in one of the output sequences that isn't in the other |
| 118 | + ss = None |
| 119 | + for i in range(20, out_length - 5): |
| 120 | + ss = batched_1[0][i:i+5] |
| 121 | + if ss not in batched_1[1]: |
| 122 | + break |
| 123 | + |
| 124 | + if ss is None: |
| 125 | + print("\ncouldn't find stop seq to use for variable output length test") |
| 126 | + else: |
| 127 | + test_name = "Common input length, variable output lengths (tests batch pruning)" |
| 128 | + print(f"\n{test_name}") |
| 129 | + print(f"Using stop-sequence '{ss}'") |
| 130 | + batched_3 = [gr.text for gr in send_requests(greqs, truncate_to, [ss], min_out_length=1).responses] |
| 131 | + individual_3 = [ |
| 132 | + send_requests([gr], truncate_to, [ss], min_out_length=1).responses[0].text for gr in greqs |
| 133 | + ] |
| 134 | + log_result(batched_3, individual_3) |
| 135 | + |
| 136 | + |
| 137 | + ### Test 4 ################################################################################# |
| 138 | + ############################################################################################ |
| 139 | + test_name = "Short output interrupting long output (tests batch concatenation)" |
| 140 | + print(f"\n{test_name}") |
| 141 | + long_input_req, short_input_req = greqs |
| 142 | + individual_4 = [ |
| 143 | + send_requests([short_input_req], out_length=100).responses[0].text, |
| 144 | + send_requests([long_input_req], out_length=10).responses[0].text, |
| 145 | + ] |
| 146 | + |
| 147 | + with concurrent.futures.ThreadPoolExecutor(1) as executor: |
| 148 | + future = executor.submit(send_requests, [short_input_req], out_length=100) |
| 149 | + time.sleep(0.25) |
| 150 | + short_resp = send_requests([long_input_req], out_length=10).responses[0] |
| 151 | + long_resp = future.result().responses[0] |
| 152 | + batched_4 = [long_resp.text, short_resp.text] |
| 153 | + log_result(batched_4, individual_4) |
0 commit comments