@@ -71,6 +71,21 @@ def get_token_throughput_latencies(
     req_launcher = RequestsLauncher(clients)
     completed_requests = []
     num_completed_requests = 0
+    # make up prompts outside of send loop for faster benchmarking loop
+    num_output_tokens_list = []
+    prompts = []
+    for i in range(max_num_completed_requests):
+        num_output_tokens = sample_random_positive_int(
+            mean_output_tokens, stddev_output_tokens
+        )
+        num_output_tokens_list.append(num_output_tokens)
+
+        prompts.append(randomly_sample_sonnet_lines_prompt(
+            prompt_tokens_mean=mean_input_tokens,
+            prompt_tokens_stddev=stddev_input_tokens,
+            expect_output_tokens=num_output_tokens,
+            tokenizer=tokenizer
+        ))
     start_time = time.monotonic()
     iter = 0
     pbar = tqdm(total=max_num_completed_requests)
@@ -79,21 +94,12 @@ def get_token_throughput_latencies(
         and len(completed_requests) < max_num_completed_requests
     ):
         iter += 1
-        num_output_tokens = sample_random_positive_int(
-            mean_output_tokens, stddev_output_tokens
-        )
-
-        prompt = randomly_sample_sonnet_lines_prompt(
-            prompt_tokens_mean=mean_input_tokens,
-            prompt_tokens_stddev=stddev_input_tokens,
-            expect_output_tokens=num_output_tokens,
-        )

-        default_sampling_params = {"max_tokens": num_output_tokens}
+        default_sampling_params = {"max_tokens": num_output_tokens_list.pop()}
         default_sampling_params.update(additional_sampling_params)
         request_config = RequestConfig(
             model=model,
-            prompt=prompt,
+            prompt=prompts.pop(),
             sampling_params=default_sampling_params,
             llm_api=llm_api,
         )
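
For reference, the pattern in this diff is to do the expensive request construction (prompt sampling and output-token sampling) before the timed send loop, so only request dispatch counts toward the measured benchmark time. Below is a minimal, self-contained sketch of that pattern; `sample_output_tokens`, `build_prompt`, and `run_benchmark` are hypothetical stand-ins for illustration, not the llmperf API.

```python
import random
import time


def sample_output_tokens(mean: int, stddev: int) -> int:
    # Stand-in for sample_random_positive_int: a positive int from a normal draw.
    return max(1, int(random.gauss(mean, stddev)))


def build_prompt(mean_tokens: int, stddev_tokens: int) -> str:
    # Stand-in for randomly_sample_sonnet_lines_prompt: a prompt of roughly N words.
    return "word " * max(1, int(random.gauss(mean_tokens, stddev_tokens)))


def run_benchmark(max_requests: int, mean_in: int, std_in: int,
                  mean_out: int, std_out: int) -> None:
    # Pre-generate every prompt and output-token target before starting the clock,
    # so prompt construction does not inflate the measured send-loop time.
    num_output_tokens_list = [sample_output_tokens(mean_out, std_out)
                              for _ in range(max_requests)]
    prompts = [build_prompt(mean_in, std_in) for _ in range(max_requests)]

    start = time.monotonic()
    while prompts:
        # Pop a pre-built prompt and its matching max_tokens value; both lists
        # are popped in lockstep, so each prompt keeps its paired token budget.
        sampling_params = {"max_tokens": num_output_tokens_list.pop()}
        prompt = prompts.pop()
        # ... launch the request with (prompt, sampling_params) here ...
    print(f"send loop took {time.monotonic() - start:.6f}s")


if __name__ == "__main__":
    run_benchmark(max_requests=5, mean_in=550, std_in=150, mean_out=150, std_out=10)
```

The actual change keeps the pairing the same way: `num_output_tokens_list.pop()` and `prompts.pop()` are called in the same iteration, so each pre-built prompt is sent with the `max_tokens` value it was generated for.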