@@ -329,6 +329,27 @@ def _broadcast(item):
 
     groups, sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens, config)
 
+    # #debugmtl
+    # if parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0:
+    #     k = 0
+    #     for group in sample_id_groups:
+    #         print(f"group {k}: ", end="")
+    #         for i in range(len(group)):
+    #             print(f"GPU-{i}: [", end="")
+    #             for j in range(len(group[i])):
+    #                 print(f"{group[i][j]}-{global_id_seqlens[group[i][j]][1]}, ", end=" ")
+    #             print("], ")
+    #         k += 1
+    #         print()
+
+    # debugmtl
+    # set_gbs = set()
+    # for group in sample_id_groups:
+    #     for sub in group:
+    #         set_gbs.update(sub)
+    # assert len(set_gbs) == len(global_id_seqlens), (
+    #     f"set_gbs length: {len(set_gbs)} != global_ids_this_rank length: {len(global_id_seqlens)}")
+
     batch = _unpack_batch(batch)
     samples_this_rank_with_id = _reroute_samples_to_hdp_ranks(
         batch,
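The commented-out coverage check above is worth keeping runnable: it verifies that every global sample id lands in exactly one sub-sample list. A minimal standalone sketch, under shapes assumed from the surrounding code (sample_id_groups is [group][gpu][sample_id]; global_id_seqlens is a list of (sample_id, seqlen) tuples, one per global sample):

# Standalone version of the coverage check (assumed shapes as noted above;
# not the module's own helper).
def check_coverage(sample_id_groups, global_id_seqlens):
    seen = set()
    for group in sample_id_groups:
        for sub in group:
            seen.update(sub)
    assert len(seen) == len(global_id_seqlens), (
        f"covered {len(seen)} sample ids, expected {len(global_id_seqlens)}")

# Two groups, two GPU sub-lists each, covering ids 0..3.
check_coverage([[[0], [1]], [[2], [3]]], [(i, 128) for i in range(4)])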
@@ -415,7 +436,9 @@ def _pack_tensors(tensors):
         new_sample["cu_seqlens"] = cu_seqlens
 
         new_samples.append(new_sample)
-
+    # #debugmtl
+    # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} "
+    #       f"new_samples length: {len(new_samples)}")
     new_data_iterator = RerunDataIterator(iter(new_samples))
 
     return (
@@ -460,15 +483,28 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
         sum_seqlen = 0
         single_microbatch = []
 
+        # debugmtl: use 1 seq per microbatch
         for i in range(len(sample_id_seqlens)):
-            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
-                single_microbatch.append(i)
-                sum_seqlen += sample_id_seqlens[i][1]
-            else:
-                groups.append(single_microbatch)
-                packed_id_groups.append(single_microbatch)
-                single_microbatch = [i]
-                sum_seqlen = sample_id_seqlens[i][1]
+            packed_id_groups.append([i])
+
+        # for i in range(len(sample_id_seqlens)):
+        #     if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
+        #         single_microbatch.append(i)
+        #         sum_seqlen += sample_id_seqlens[i][1]
+        #     else:
+        #         packed_id_groups.append(single_microbatch)
+        #         single_microbatch = [i]
+        #         sum_seqlen = sample_id_seqlens[i][1]
+        # if len(single_microbatch) > 0:
+        #     packed_id_groups.append(single_microbatch)
+
+        # debugmtl
+        gbs_sum = 0
+        for i in packed_id_groups:
+            gbs_sum += len(i)
+        assert gbs_sum == len(
+            sample_id_seqlens
+        ), f"gbs_sum: {gbs_sum} != sample_id_seqlens length: {len(sample_id_seqlens)}"
 
         # we want the number of packed sequences to be a multiple of dp_size,
         # so we move a few samples from previous microbatches
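For contrast with the one-sequence-per-microbatch debug path, here is a minimal standalone sketch of the greedy packing that the commented-out loop preserves. Note that the commented version also flushes the final partially filled microbatch, which the removed live code dropped. max_seq_len stands in for self.max_seq_len_all_ranks, and the (id, seqlen) tuple layout is assumed from the surrounding code:

# Sketch of the greedy packing (assumptions as noted above).
def pack_greedy(seqlens, max_seq_len):
    packed, current, total = [], [], 0
    for i in range(len(seqlens)):
        if total + seqlens[i][1] <= max_seq_len:
            current.append(i)
            total += seqlens[i][1]
        else:
            packed.append(current)
            current, total = [i], seqlens[i][1]
    if current:  # flush the last microbatch, as the commented fix does
        packed.append(current)
    return packed

# Lengths 60, 50, 30 with cap 100 -> [[0], [1, 2]]
print(pack_greedy([(0, 60), (1, 50), (2, 30)], 100))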
@@ -482,7 +518,7 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
             assert i > 0, "Not enough samples to move"
             if len(packed_id_groups[i]) > 1:
                 seq_id = packed_id_groups[i].pop()
-                packed_id_groups[i].append(seq_id)
+                packed_id_groups.append([seq_id])
                 num_to_move -= 1
             else:
                 i -= 1
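The one-line change above is the substantive fix: the removed line re-appended the popped sample to the same group it came from (a no-op), while the new line gives it its own packed group, actually increasing the group count. A minimal sketch of the surrounding rebalancing loop; how num_to_move is derived is an assumption, since only the loop body appears in the diff:

# Split samples off into singleton groups until the number of packed
# groups is a multiple of dp_size (num_to_move derivation assumed).
def rebalance(packed_id_groups, dp_size):
    num_to_move = -len(packed_id_groups) % dp_size
    i = len(packed_id_groups) - 1
    while num_to_move > 0:
        assert i > 0, "Not enough samples to move"
        if len(packed_id_groups[i]) > 1:
            seq_id = packed_id_groups[i].pop()
            packed_id_groups.append([seq_id])  # new singleton group (the fix)
            num_to_move -= 1
        else:
            i -= 1
    return packed_id_groups

# 3 groups with dp_size=2 -> one sample split off, giving 4 groups.
print(rebalance([[0, 1], [2], [3, 4]], 2))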
@@ -493,7 +529,9 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
             for j in range(self.cp_size * self.dp_size):
                 seq_id = int(i * self.dp_size + j / self.cp_size)
                 sample_id_groups[i].append(packed_id_groups[seq_id])
-
+        # debugmtl
+        # print(f"rank {parallel_state.get_data_parallel_rank(with_context_parallel=True)} "
+        #       f"sample_id_groups: {len(sample_id_groups)}")
         return groups, sample_id_groups
 