@@ -34,7 +34,7 @@ def get_tokenizer(
3434def get_random_length (reqs_num : int , length : int , range_ratio : float ) -> List [int ]:
3535 lens = []
3636 lens = np .random .randint (
37- int (length * range_ratio ),
37+ max ( int (length * range_ratio ), 1 ),
3838 length + 1 ,
3939 size = reqs_num ,
4040 )
@@ -64,6 +64,21 @@ def gen_random_data(
6464 return prompts , output_lens
6565
6666
def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
    """Load chat-style requests from a JSONL file and sample an output length per request.

    Each non-empty line of ``data_path`` must be a JSON object with a
    ``"messages"`` field (chat-message list). The tokenizer's chat template is
    applied with the generation prompt appended, and the token count of the
    rendered prompt is recorded alongside it.

    Args:
        data_path: Path to the JSONL file, one request per line.
        output_len: Upper bound passed to ``get_random_length`` when sampling
            per-request output lengths.
        tokenizer: HF-style tokenizer providing ``apply_chat_template`` and
            ``encode``.
        range_ratio: Lower-bound ratio forwarded to ``get_random_length``.

    Returns:
        Tuple ``(prompts, output_lens)`` where ``prompts`` is a list of
        ``[rendered_text, token_count]`` pairs and ``output_lens`` has one
        sampled length per prompt.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
        KeyError: If a line lacks the ``"messages"`` field.
    """
    prompts = []
    # Stream the file line-by-line instead of readlines() to avoid holding
    # the whole file in memory; be explicit about the text encoding.
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines (json.loads("") would raise).
                continue
            data_line = json.loads(line)
            input_data = tokenizer.apply_chat_template(
                data_line["messages"], add_generation_prompt=True, tokenize=False
            )
            input_len = len(tokenizer.encode(input_data))
            prompts.append([input_data, input_len])
    output_lens = get_random_length(len(prompts), output_len, range_ratio)
    # Bug fix: this is the custom-data loader, not the random-data generator.
    print("Load custom data finish.")
    return prompts, output_lens
80+
81+
# Mutable module-level list holding the tokenizer/model path; main() appends
# args.tokenizer_path to it (presumably so helpers elsewhere in the file can
# read it without parameter plumbing — usage outside this view not confirmed).
model_name = []
@@ -115,13 +130,13 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
115130 "do_sample" : False ,
116131 "ignore_eos" : True ,
117132 "max_new_tokens" : max_new_tokens ,
133+ "add_special_tokens" : False ,
118134 },
119135 }
120136 headers = {"Content-Type" : "application/json" }
121137 used_time = []
122138 start_time = time .time ()
123139 last_time = start_time
124-
125140 async with session .post (url , headers = headers , json = data ) as response :
126141 if response .status != 200 :
127142 return []
@@ -189,12 +204,11 @@ async def response_collector(
189204 result , input_len = await task
190205 request_queue .task_done ()
191206 assert result is not None
192- if len (result ) > 1 and not stop_send .is_set ():
207+ if len (result ) >= 1 and not stop_send .is_set ():
193208 results .append ((result , input_len ))
194209 current_count = counter [0 ] + 1
195210 counter [0 ] = current_count
196211 print (f"\r finished_reqs:{ current_count } / target_reqs:{ reqs_num } / sent_reqs:{ sent_count [0 ]} " , end = "" )
197-
198212 if len (results ) >= reqs_num and not stop_send .is_set ():
199213 end_time [0 ] = time .time ()
200214 print ("\n Reached target number of responses" )
@@ -292,6 +306,7 @@ def main():
292306 )
293307 parser .add_argument ("--num_clients" , type = int , default = 100 )
294308 parser .add_argument ("--tokenizer_path" , type = str , default = None )
309+ parser .add_argument ("--data_path" , type = str , default = None )
295310 parser .add_argument ("--input_num" , type = int , default = 2000 )
296311 parser .add_argument ("--input_qps" , type = float , default = 30.0 )
297312 parser .add_argument ("--input_len" , type = int , default = 1024 )
@@ -323,17 +338,21 @@ def main():
323338
324339 assert args .tokenizer_path is not None
325340 model_name .append (args .tokenizer_path )
326- seed_all (args .seed )
341+ # seed_all(args.seed)
327342 url = args .url
328343 tokenizer = get_tokenizer (args .tokenizer_path )
329- # qps发送模式发送请求的数量不固定,这里暂定为input_num的10倍
330- prompts , max_new_tokens = gen_random_data (
331- args .input_len ,
332- args .output_len ,
333- args .input_num if not args .continuous_send else 10 * args .input_num ,
334- tokenizer ,
335- args .range_ratio ,
336- )
344+ if args .data_path is not None :
345+ prompts , max_new_tokens = get_custom_input_data (args .data_path , args .output_len , tokenizer , args .range_ratio )
346+ args .input_num = len (prompts )
347+ else :
348+ # qps发送模式发送请求的数量不固定,这里暂定为input_num的10倍
349+ prompts , max_new_tokens = gen_random_data (
350+ args .input_len ,
351+ args .output_len ,
352+ args .input_num if not args .continuous_send else 10 * args .input_num ,
353+ tokenizer ,
354+ args .range_ratio ,
355+ )
337356
338357 percentiles = [25 , 50 , 75 , 90 , 95 , 99 , 100 ]
339358 if args .server_api == "lightllm" :
@@ -364,7 +383,7 @@ def main():
364383 )
365384 )
366385 loop .close ()
367-
386+ print ( len ( results ))
368387 first_token_time = []
369388 decode_token_time = []
370389 request_time = []
@@ -379,6 +398,13 @@ def main():
379398 final_output_lens .append (len (result ))
380399 input_lens .append (input_len )
381400 valid_num += 1
401+ else :
402+ first_token_time .append (result [0 ])
403+ decode_token_time .append (0 ) # no decode
404+ request_time .append (sum (result ))
405+ final_output_lens .append (len (result ))
406+ input_lens .append (input_len )
407+ valid_num += 1
382408
383409 print (
384410 f"\n \n valid num = { valid_num } ; all data num = { len (results )} ; valid ratio = { valid_num * 1.0 / len (results )} \n "
0 commit comments