from lightllm.utils.envs_utils import get_env_start_args
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from torch.profiler import profile, record_function, ProfilerActivity


def test_model_inference(args, model_class):
@@ -116,6 +117,16 @@ def decode(
    return logits


+# Helper: run fn under torch.profiler and write a trace viewable in TensorBoard.
+def torch_profile(fn, log_dir=None):
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        record_shapes=False,
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),
+    ) as prof:
+        fn()
+    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
    args = get_env_start_args()
    import triton.profiler as proton
@@ -244,6 +255,28 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o
    if args.profile:
        proton.start(name="forward_prefill", context="python")

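+    # Optional one-shot torch.profiler capture of the prefill forward pass.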
+    if args.torch_profile:
+        print("Profile Prefill")
+        try:
+            torch_profile(
+                lambda: model_part.forward(
+                    batch_size,
+                    total_token_num,
+                    input_len,
+                    test_data,
+                    mem_indexes,
+                    b_req_idx,
+                    b_start_loc,
+                    b_seq_len,
+                    b_ready_cache_len=b_ready_cache_len,
+                    is_prefill=True,
+                ),
+                log_dir=f"./logs_decode_overlap/forward_prefill_{model_kvargs['rank_id']}",
+            )
+        except Exception as e:
+            print(str(e))
+            raise
+
    logics = model_part.forward(
        batch_size,
        total_token_num,
@@ -291,6 +324,21 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o
                b_seq_len,
                total_token_num,
            )
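+            # Profile the first overlapped decode step once.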
+            if i == 0 and args.torch_profile:
+                torch_profile(
+                    lambda: overlap_decode(
+                        model_part,
+                        batch_size,
+                        max_len_in_batch,
+                        torch.from_numpy(predict_ids).cuda().reshape(-1),
+                        mem_indexes,
+                        b_req_idx,
+                        b_start_loc,
+                        b_seq_len,
+                        total_token_num,
+                    ),
+                    log_dir=f"./logs_decode_overlap/forward_decode_{model_kvargs['rank_id']}",
+                )
        else:
            logits = decode(
                model_part,
@@ -303,6 +351,21 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o
                b_seq_len,
                total_token_num,
            )
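+            # Profile the first standard decode step once.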
+            if i == 0 and args.torch_profile:
+                torch_profile(
+                    lambda: decode(
+                        model_part,
+                        batch_size,
+                        max_len_in_batch,
+                        torch.from_numpy(predict_ids).cuda().reshape(-1),
+                        mem_indexes,
+                        b_req_idx,
+                        b_start_loc,
+                        b_seq_len,
+                        total_token_num,
+                    ),
+                    log_dir=f"./logs_decode_overlap/forward_decode_{model_kvargs['rank_id']}",
+                )

        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
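
For reference, a minimal standalone sketch of using the torch_profile helper added above, assuming a CUDA device; _matmul_work is a hypothetical stand-in for model_part.forward:

    import torch

    def _matmul_work():
        # Hypothetical workload standing in for a real forward pass.
        a = torch.randn(2048, 2048, device="cuda")
        b = torch.randn(2048, 2048, device="cuda")
        for _ in range(10):
            a = a @ b
        torch.cuda.synchronize()

    torch_profile(_matmul_work, log_dir="./logs_decode_overlap/smoke_test")

The written traces can then be opened in TensorBoard (with the torch-tb-profiler plugin installed): tensorboard --logdir ./logs_decode_overlap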