import torch

from megatron.core import parallel_state
+
+# from megatron.core.pipeline_parallel.utils import (
+#     is_pp_first_stage,
+#     is_pp_last_stage,
+#     is_vp_first_stage,
+#     is_vp_last_stage,
+# )
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.rerun_state_machine import RerunDataIterator
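The import block above is commented out pending use in the pipeline-stage gating sketched later in this diff. As a rough reference, a minimal sketch of what such stage checks typically reduce to, assuming a process-group object exposing rank() and size() like the groups used below (these helper bodies are an assumption, not the actual megatron.core implementation):

def is_pp_first_stage(pp_group) -> bool:
    # First pipeline stage: rank 0 within the pipeline-parallel group.
    return pp_group.rank() == 0

def is_pp_last_stage(pp_group) -> bool:
    # Last pipeline stage: highest rank within the pipeline-parallel group.
    return pp_group.rank() == pp_group.size() - 1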
@@ -293,17 +300,24 @@ def _broadcast(item):
        dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
        dp_group = parallel_state.get_data_parallel_group()
        tp_group = parallel_state.get_tensor_model_parallel_group()
+       pp_group = parallel_state.get_pipeline_model_parallel_group()
    else:
        dp_cp_group = pg_collection.dp_cp
        dp_group = pg_collection.dp
        tp_group = pg_collection.tp
+       pp_group = pg_collection.pp
    assert (
        dp_cp_group is not None and dp_group is not None and tp_group is not None
    ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"

    total_hdp_gpus = dp_cp_group.size()
    dev = torch.cuda.current_device()

+   # if is_pp_first_stage(pp_group) or is_pp_last_stage(pp_group) and tp_group.rank() == 0:
+   #     # do what data_iterator is doing
+
+   #     # first stage tp-0 broadcast num_micro_batches cu_seqlens to
+
    if data_iterator is None:
        # TP-0 reads from data_iterator, others receive via broadcast.
        sample_id_groups, batch = None, None
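The hunk above wires pp_group up alongside the dp/tp groups, and the commented-out block hints at gating the data read to specific pipeline stages. For reference, the "TP-0 reads, others receive via broadcast" pattern named in the comment above can be sketched as follows; read_or_receive is a hypothetical helper and assumes an initialized torch.distributed backend:

import torch.distributed as dist

def read_or_receive(data_iterator, tp_group):
    # Only tensor-parallel rank 0 consumes the iterator; the result is then
    # broadcast object-wise to the remaining ranks of the TP group.
    obj = [next(data_iterator)] if tp_group.rank() == 0 else [None]
    src = dist.get_global_rank(tp_group, 0)  # translate TP rank 0 to a global rank
    dist.broadcast_object_list(obj, src=src, group=tp_group)
    return obj[0]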
@@ -329,6 +343,16 @@ def _broadcast(item):

    groups, sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens, config)

+    # debugmtl: sanity-check that every global sample id is scheduled exactly once
+    set_gbs = set()
+    for group in sample_id_groups:
+        for sub in group:
+            set_gbs.update(sub)
+    assert len(set_gbs) == len(
+        global_id_seqlens
+    ), f"set_gbs length: {len(set_gbs)} != global_id_seqlens length: {len(global_id_seqlens)}"
+
    batch = _unpack_batch(batch)
    samples_this_rank_with_id = _reroute_samples_to_hdp_ranks(
        batch,
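The set-based assert above only compares counts, so it cannot report which ids were duplicated or dropped. A standalone multiset variant for debugging (check_coverage is a hypothetical helper; ids are assumed to be the 0-based sample indices the scheduler emits):

from collections import Counter

def check_coverage(sample_id_groups, num_samples):
    # Count how many times each sample id was scheduled across all groups.
    counts = Counter(i for group in sample_id_groups for sub in group for i in sub)
    dupes = sorted(i for i, c in counts.items() if c > 1)
    missing = sorted(i for i in range(num_samples) if i not in counts)
    assert not dupes and not missing, f"duplicates={dupes}, missing={missing}"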
@@ -384,9 +408,10 @@ def _pack_tensors(tensors):
        new_sample["labels"] = labels
        new_sample["loss_mask"] = loss_mask
        new_sample["position_ids"] = position_ids
-       new_sample["local_cp_size"] = torch.tensor(
-           partner_cp_size, dtype=torch.int32, device=dev
-       )
+       if scheduler_type is PackingScheduler.HYBRID_CP:
+           new_sample["local_cp_size"] = torch.tensor(
+               partner_cp_size, dtype=torch.int32, device=dev
+           )

        # create cu_seqlens_padded
        lengths_padding = np.fromiter(
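Since "local_cp_size" is now attached only under the HYBRID_CP scheduler, downstream consumers need a default for samples packed by other schedulers. A minimal sketch under that assumption (get_local_cp_size is a hypothetical helper, not part of this PR):

def get_local_cp_size(sample) -> int:
    # Samples packed without hybrid CP carry no "local_cp_size" tensor;
    # treat them as context-parallel size 1.
    t = sample.get("local_cp_size")
    return int(t.item()) if t is not None else 1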
@@ -415,7 +440,9 @@ def _pack_tensors(tensors):
        new_sample["cu_seqlens"] = cu_seqlens

        new_samples.append(new_sample)
-
+    # debugmtl
+    # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} "
+    #       f"new_samples length: {len(new_samples)}")
    new_data_iterator = RerunDataIterator(iter(new_samples))

    return (
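For context, RerunDataIterator wraps the packed samples so the rerun state machine can replay them on retry; assuming it supports plain iteration as used in Megatron training loops, consumption looks like:

new_data_iterator = RerunDataIterator(iter(new_samples))
microbatch = next(new_data_iterator)  # same dict layout as new_sample above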
@@ -460,15 +487,30 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
        sum_seqlen = 0
        single_microbatch = []

+        # # debugmtl: fallback that packs one sequence per microbatch
+        # num_micro_batches = len(sample_id_seqlens) // self.dp_size
+        # for i in range(num_micro_batches):
+        #     for j in range(self.dp_size):
+        #         packed_id_groups.append([i + j * num_micro_batches])
+
        for i in range(len(sample_id_seqlens)):
            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
                single_microbatch.append(i)
                sum_seqlen += sample_id_seqlens[i][1]
            else:
-                groups.append(single_microbatch)
                packed_id_groups.append(single_microbatch)
                single_microbatch = [i]
                sum_seqlen = sample_id_seqlens[i][1]
+        # flush the trailing, partially filled microbatch
+        if len(single_microbatch) > 0:
+            packed_id_groups.append(single_microbatch)
+
+        # debugmtl: every input sample must land in exactly one pack
+        gbs_sum = 0
+        for i in packed_id_groups:
+            gbs_sum += len(i)
+        assert gbs_sum == len(
+            sample_id_seqlens
+        ), f"gbs_sum: {gbs_sum} != sample_id_seqlens length: {len(sample_id_seqlens)}"

        # we want the number of packed sequences to be a multiple of dp_size,
        # so we move a few samples out of the previous microbatches
@@ -482,7 +524,7 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
            assert i > 0, "Not enough samples to move"
            if len(packed_id_groups[i]) > 1:
                seq_id = packed_id_groups[i].pop()
-                packed_id_groups[i].append(seq_id)
+                packed_id_groups.append([seq_id])
                num_to_move -= 1
            else:
                i -= 1
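The one-line change above turns the rebalancing from a no-op into an actual split: the old line popped a sample and appended it straight back to the same pack, so num_to_move was decremented without the pack count ever growing. A standalone sketch of the corrected behavior (pad_to_multiple is a hypothetical name):

def pad_to_multiple(packs, dp_size):
    # Split singletons off the tail packs until len(packs) % dp_size == 0.
    num_to_move = -len(packs) % dp_size
    i = len(packs) - 1
    while num_to_move > 0:
        assert i > 0, "Not enough samples to move"
        if len(packs[i]) > 1:
            packs.append([packs[i].pop()])  # spin off a new singleton pack
            num_to_move -= 1
        else:
            i -= 1  # this pack cannot be split further; step back one pack
    return packs

print(pad_to_multiple([[0, 1], [2, 3], [4]], dp_size=2))  # [[0, 1], [2], [4], [3]]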
@@ -493,7 +535,9 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
            for j in range(self.cp_size * self.dp_size):
                seq_id = int(i * self.dp_size + j / self.cp_size)
                sample_id_groups[i].append(packed_id_groups[seq_id])
-
+        # debugmtl
+        # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} "
+        #       f"sample_id_groups: {len(sample_id_groups)}")
        return groups, sample_id_groups
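The index expression above deals each microbatch dp_size consecutive packs, with the cp_size ranks of one context-parallel group sharing a pack. A worked example of the arithmetic, assuming dp_size = 2, cp_size = 2, and that consecutive j indices belong to the same CP group:

dp_size, cp_size = 2, 2
for i in range(2):  # two microbatches
    print(i, [int(i * dp_size + j / cp_size) for j in range(cp_size * dp_size)])
# 0 [0, 0, 1, 1]  -> packs 0 and 1 consumed in microbatch 0
# 1 [2, 2, 3, 3]  -> packs 2 and 3 consumed in microbatch 1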