@@ -75,7 +75,7 @@ def test_model_inference_mtp(args):
7575 "graph_max_batch_size" : args .graph_max_batch_size ,
7676 "mem_faction" : args .mem_fraction ,
7777 "max_req_num" : 2000 ,
78- "batch_max_tokens" : 16384 ,
78+ "batch_max_tokens" : 2048 ,
7979 "run_mode" : "normal" ,
8080 "max_seq_length" : args .max_req_total_len ,
8181 "spec_algo" : args .spec_algo ,
@@ -110,7 +110,7 @@ def torch_profile(fn, log_dir=None):
     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


-def run_forward_once(input_len, output_len, batch_size, main_model, draft_models, warmup=False):
+def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False):
     import time

     torch.cuda.synchronize()
@@ -166,7 +166,9 @@ def run_forward_once(input_len, output_len, batch_size, main_model, draft_models
     prefill_end_time = time.time()
     if get_current_rank_in_dp() == 0 and not warmup:
         print("prefill time cost:", (prefill_end_time - prefill_start_time) * 1000)
-        print(f"Prefill throughput: {batch_size * input_len / (prefill_end_time - prefill_start_time)} tokens/s")
+        print(
+            f"Prefill throughput: {batch_size * input_len * args.dp / (prefill_end_time - prefill_start_time)} tokens/s"
+        )

     torch.cuda.synchronize()

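The prefill fix multiplies per-replica tokens by `args.dp`: under data parallelism each of the `dp` replicas prefills its own copy of the batch in the same wall-clock window, so aggregate throughput is `batch_size * input_len * dp / elapsed`. A worked sketch of just that arithmetic, with illustrative (not measured) numbers:

```python
def prefill_throughput(batch_size: int, input_len: int, dp: int, elapsed_s: float) -> float:
    # Tokens prefilled per replica, scaled by the number of DP replicas.
    return batch_size * input_len * dp / elapsed_s


# 8 requests of 1024 tokens each, on 2 DP replicas, in 0.5 s:
print(prefill_throughput(8, 1024, 2, 0.5))  # 32768.0 tokens/s
```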
@@ -240,7 +242,7 @@ def run_forward_once(input_len, output_len, batch_size, main_model, draft_models
         if get_current_rank_in_dp() == 0 and not warmup:
             step_time = step_end_time - step_start_time
             print(i, " step cost time:", step_time * 1000)
-            print(f"Decode throughput: {batch_size * (len(draft_models) + 1) / step_time} tokens/s")
+            print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s")

     main_model.mem_manager.free_all()
     main_model.req_manager.free_all()
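The decode formula gets the same DP scaling. In this benchmark each decode step is counted as emitting one main-model token plus one token per draft (MTP) module, per request, hence the `len(draft_models) + 1` factor; real speculative decoding would only count accepted tokens. A sketch of the accounting with illustrative numbers:

```python
def decode_throughput(batch_size: int, num_draft_models: int, dp: int, step_time_s: float) -> float:
    # One main-model token plus one per draft (MTP) module, per request,
    # per step, summed across all DP replicas.
    tokens_per_step = batch_size * (num_draft_models + 1) * dp
    return tokens_per_step / step_time_s


# Batch of 8, one MTP draft module, 2 DP replicas, 20 ms per step:
print(decode_throughput(8, 1, 2, 0.02))  # 1600.0 tokens/s
```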
@@ -273,9 +275,9 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a

     for batch_size in batch_sizes:
         # warm up
-        run_forward_once(input_len, output_len, batch_size, main_model, draft_models, warmup=True)
+        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=True)
         torch.cuda.synchronize()
-        run_forward_once(input_len, output_len, batch_size, main_model, draft_models, warmup=False)
+        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False)
         dist.barrier()

     ans_queue.put(True)
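The warmup-then-measure pattern here (a warmup pass, `torch.cuda.synchronize()`, then the measured pass) is the standard way to time CUDA work, since kernel launches are asynchronous and the first iteration pays compilation and allocation costs. A generic, self-contained sketch of that pattern; the `timed` helper is hypothetical and assumes a CUDA-capable setup:

```python
import time

import torch


def timed(fn, warmup: int = 1, iters: int = 1) -> float:
    """Time fn, discarding warmup runs and syncing around the measured region."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # drain warmup kernels before starting the clock
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for the timed kernels to actually finish
    return (time.time() - start) / iters
```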