 
 import torch
 import torch.distributed as dist
+import time
+import os
+import json
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
@@ -202,22 +205,39 @@ def _call_engine(engine, data): # pylint: disable=W0237
     def load_batch(self, engine, data_iter):
         # Pipeline schedule just puts data in memory,
         batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
-
+        batch_seqlist = []
         # Even if 'use_flash_attn' is False, the data seen when 'load_batch' is called is still packed,
         # because internlm's current train dataset is packed, even when using dummy data.
         # The unpack operation is performed in load_micro_batch().
         if check_data_is_packed(batch_data):
             micro_num = actual_batch_size
         else:
             micro_num = actual_batch_size // gpc.config.data["micro_bsz"]
-
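+        # Per-micro-batch sequence lengths, recovered from the packed cu_seqlens offsets.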
+        for micro_batch_cu in batch_data[0]['cu_seqlens']:
+            micro_batch_seqlist = [int(micro_batch_cu[j]) - int(micro_batch_cu[j - 1]) for j in range(1, len(micro_batch_cu))]
+            batch_seqlist.append(micro_batch_seqlist)
+
         self.microbatch_offset = 0
         self.batch_size = actual_batch_size
         self.batch_data, self.batch_label = batch_data
         self.bsz_stride = self.batch_size // micro_num
         # 'num_microbatches' is no longer an initialization parameter,
         # but is determined on the fly by the Scheduler.
         self.num_microbatches = micro_num  # Rampup or variable bsz size.
+
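+        # Dump the per-micro-batch sequence lengths for this step, one file per pipeline rank
+        # (only on the first data/tensor parallel rank, and only when profiling is enabled).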
+        if (gpc.config.profile_fwd_bwd and os.environ.get("CUDA_LAUNCH_BLOCKING") == "1"
+                and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0):
+            output_dir = os.path.join("./micro_record", gpc.config.data.data_name,
+                                      f"B{gpc.config.data.bucket_size}_seq{gpc.config.SEQ_LEN}_mb{gpc.config.data.micro_num}",
+                                      f"S{gpc.batch_count}")
+            os.makedirs(output_dir, exist_ok=True)
+            output_file = os.path.join(output_dir, f"PP_rank_{gpc.get_local_rank(ParallelMode.PIPELINE)}_seq.json")
+
+            # One JSON list of sequence lengths per line, one line per micro batch.
+            with open(output_file, "w") as f:
+                for micro_batch_seqlist in batch_seqlist:
+                    json.dump(micro_batch_seqlist, f)
+                    f.write("\n")
 
     def load_micro_batch(self):
         micro_batch_data, micro_batch_label = self._load_micro_batch(
@@ -592,8 +612,12 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
         input_obj = None
 
         # Run 1F1B in steady state.
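+        # Wall-clock per-micro-step timings; meaningful because CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous.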
+        fwd_times = []
+        bwd_times = []
+
         for i in range(num_1f1b_micropairs):
             # Perform forward computation
+            start_time = time.time()
             output_obj, moe_loss = self._forward_step(
                 engine,
                 input_obj,
@@ -602,6 +626,7 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                 accum_loss=accum_loss,
                 accum_moe_loss=accum_moe_loss,
             )
+            fwd_times.append(time.time() - start_time)
 
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = None
@@ -625,7 +650,9 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
             output_obj = output_objs.pop(0)
             moe_loss = moe_losses.pop(0)
 
+            start_bwd_time = time.time()
             input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
+            bwd_times.append(time.time() - start_bwd_time)
 
             if i == (num_1f1b_micropairs - 1):
                 input_obj = None
@@ -644,6 +671,44 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                     dtype=self.dtype,
                     scatter_gather_tensors=self.scatter_gather_tensors,
                 )
+        if (gpc.config.profile_fwd_bwd and os.environ.get("CUDA_LAUNCH_BLOCKING") == "1"
+                and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0):
+            output_dir = os.path.join("./micro_record", gpc.config.data.data_name,
+                                      f"B{gpc.config.data.bucket_size}_seq{gpc.config.SEQ_LEN}_mb{gpc.config.data.micro_num}",
+                                      f"S{gpc.batch_count}")
+            os.makedirs(output_dir, exist_ok=True)
+            output_file = os.path.join(output_dir, f"PP_rank_{gpc.get_local_rank(ParallelMode.PIPELINE)}.json")
+            gpc.batch_count += 1
+
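+            # 1. Default timing history if no previous record exists for this rank.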
+            history = {
+                "fwd_times": [],
+                "bwd_times": [],
+            }
+
+            # 2. If the file already exists, load the previous records.
+            if os.path.exists(output_file):
+                with open(output_file, "r") as f:
+                    try:
+                        history = json.load(f)
+                    except json.JSONDecodeError:
+                        pass  # Skip if the file is empty or corrupted.
+
+            # 3. Append the new measurements.
+            history["fwd_times"].extend(fwd_times)
+            history["bwd_times"].extend(bwd_times)
+
+            from collections import OrderedDict
+
+            data = OrderedDict()
+            # 4. Update the running averages and the forward:backward ratio (normalized to forward).
+            data["avg_fwd"] = sum(history["fwd_times"]) / len(history["fwd_times"])
+            data["avg_bwd"] = sum(history["bwd_times"]) / len(history["bwd_times"])
+            f_f = round(data["avg_fwd"] / data["avg_fwd"], 3)
+            b_f = round(data["avg_bwd"] / data["avg_fwd"], 3)
+            data["f_b_w"] = (f_f, b_f)
+            data["fwd_times"] = history["fwd_times"]
+            data["bwd_times"] = history["bwd_times"]
+
+            # 5. Write the updated record back to the file.
+            with open(output_file, "w") as f:
+                json.dump(data, f, indent=4)
+
 
         # Run cooldown backward passes.
         for i in range(num_warmup_microsteps):
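
To check the files this change produces, here is a minimal read-back sketch (not part of the commit). The directory layout and JSON keys are taken from the code above; the dataset name, bucket/sequence/micro-batch settings, the step directory S0, and the rank are hypothetical, and it assumes the timing file and the *_seq.json file for a step land in the same S<step> directory.

import json
import os

record_dir = "./micro_record/my_dataset/B16_seq4096_mb4/S0"  # hypothetical run settings
rank = 0  # pipeline-parallel rank to inspect

# Timing record written by _forward_backward_step().
with open(os.path.join(record_dir, f"PP_rank_{rank}.json")) as f:
    record = json.load(f)
print("avg fwd (s):", record["avg_fwd"])
print("avg bwd (s):", record["avg_bwd"])
print("fwd:bwd ratio:", record["f_b_w"])

# Sequence-length record written by load_batch(): JSON Lines, one list per micro batch.
with open(os.path.join(record_dir, f"PP_rank_{rank}_seq.json")) as f:
    seq_lists = [json.loads(line) for line in f if line.strip()]
for i, seqs in enumerate(seq_lists):
    print(f"micro batch {i}: {len(seqs)} sequences, {sum(seqs)} tokens")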