@@ -307,10 +307,9 @@ def _broadcast(item):
     total_hdp_gpus = dp_cp_group.size()
     dev = torch.cuda.current_device()
 
-    # TODO(tailaim): handle the case when data_iterator is None
     if data_iterator is None:
         # TP-0 reads from data_iterator, others receive via broadcast.
-        sample_id_groups, batch = None, None, None
+        sample_id_groups, batch = None, None
         num_total_groups_broadcast = torch.tensor([0], dtype=torch.int32, device=dev)
         _broadcast(num_total_groups_broadcast)
         num_micro_batches = int(num_total_groups_broadcast.item())
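For reference, the receive path in this hunk works because every rank posts an identically shaped placeholder tensor before the collective overwrites it. A minimal sketch of that handshake, assuming `_broadcast` wraps `torch.distributed.broadcast` from TP rank 0 over the tensor-parallel group (the standalone helpers below are illustrative, not the PR's actual code):

```python
import torch
import torch.distributed as dist

def _broadcast(item, src, group):
    # Broadcast is in-place: src's buffer is sent and every other
    # rank's buffer is overwritten with the received values.
    if item is not None:
        dist.broadcast(item, src=src, group=group)

def exchange_group_count(num_groups, src, group):
    """TP-0 passes the real count; other ranks pass None and receive it."""
    dev = torch.cuda.current_device()
    val = 0 if num_groups is None else num_groups
    buf = torch.tensor([val], dtype=torch.int32, device=dev)
    _broadcast(buf, src, group)
    return int(buf.item())
```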
@@ -352,6 +351,10 @@ def _broadcast(item):
         )
         _broadcast(num_total_groups_broadcast)
 
+        # TODO(tailaim): calculate these two values properly
+        # num_total_tokens_this_GA = losses_reduced.pop(0)
+        # sequence_square_sum_this_GA = losses_reduced.pop(0)
+
         # pack sequences in the same group and create a new data iterator
         new_samples = []
         for i in range(num_micro_batches):
@@ -375,10 +378,10 @@ def _pack_tensors(tensors):
 
             # TODO(tailaim): do we need attention_mask for sequence packing?
             new_sample = {}
-            new_sample["tokens"] = tokens.unsqueeze(0)
-            new_sample["labels"] = labels.unsqueeze(0)
-            new_sample["loss_mask"] = loss_mask.unsqueeze(0)
-            new_sample["position_ids"] = position_ids.unsqueeze(0)
+            new_sample["tokens"] = tokens
+            new_sample["labels"] = labels
+            new_sample["loss_mask"] = loss_mask
+            new_sample["position_ids"] = position_ids
             new_sample["local_cp_size"] = torch.tensor(
                 partner_cp_size, dtype=torch.int32, device=dev
             )
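Dropping `.unsqueeze(0)` means each packed field now stays one-dimensional, with any batching applied downstream. As a hedged sketch of what such a packed sample can look like (`pack_sequences` and its label/mask conventions are assumptions for illustration, not the PR's `_pack_tensors`):

```python
import torch

def pack_sequences(seqs):
    """Concatenate 1-D token tensors into one flat packed sample."""
    tokens = torch.cat(seqs)
    # Next-token labels per sequence; the last position of each sequence
    # is repeated and then excluded from the loss via loss_mask.
    labels = torch.cat([torch.cat([s[1:], s[-1:]]) for s in seqs])
    loss_mask = torch.cat(
        [torch.cat([torch.ones(len(s) - 1), torch.zeros(1)]) for s in seqs]
    )
    # position_ids restart at every sequence boundary, which is what
    # packed-attention kernels keyed on per-sequence offsets expect.
    position_ids = torch.cat([torch.arange(len(s)) for s in seqs])
    return {"tokens": tokens, "labels": labels,
            "loss_mask": loss_mask, "position_ids": position_ids}

sample = pack_sequences([torch.arange(4), torch.arange(3)])
print(sample["position_ids"])  # tensor([0, 1, 2, 3, 0, 1, 2])
```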
@@ -442,6 +445,7 @@ def __init__(self, config):
         super().__init__(config)
         self.max_seq_len_all_ranks = config.max_seqlen_per_dp_cp_rank * config.context_parallel_size
         self.dp_size = parallel_state.get_data_parallel_world_size()
+        self.cp_size = parallel_state.get_context_parallel_world_size()
 
     def get_groups_and_subsamples(self, sample_id_seqlens, config):
         """
@@ -451,35 +455,44 @@ def get_groups_and_subsamples(self, sample_id_seqlens, config):
         """
         groups = []
         sample_id_groups = []
+        packed_id_groups = []
         sum_seqlen = 0
         single_microbatch = []
 
         for i in range(len(sample_id_seqlens)):
-            if sum_seqlen + sample_id_seqlens[i] <= self.max_seq_len_all_ranks:
+            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
                 single_microbatch.append(i)
                 sum_seqlen += sample_id_seqlens[i][1]
             else:
                 groups.append(single_microbatch)
-                sample_id_groups.append(single_microbatch)
+                packed_id_groups.append(single_microbatch)
                 single_microbatch = [i]
                 sum_seqlen = sample_id_seqlens[i][1]
 
-        # we want the number of microbatches to be multiple of dp_size
+        # we want the number of packed sequences to be a multiple of dp_size,
         # so we move a few samples from the previous microbatch
         # to the end of the microbatches if needed
-        num_microbatches_before = len(sample_id_groups)
-        if num_microbatches_before % self.dp_size != 0:
-            remainder = num_microbatches_before % self.dp_size
+        num_packed_sequence = len(packed_id_groups)
+        if num_packed_sequence % self.dp_size != 0:
+            remainder = num_packed_sequence % self.dp_size
             num_to_move = self.dp_size - remainder
-            i = num_microbatches_before - 1
+            i = num_packed_sequence - 1
             while num_to_move > 0:
                 assert i > 0, "Not enough samples to move"
-                if len(sample_id_groups[i]) > 1:
-                    seq_id = sample_id_groups[i].pop()
-                    sample_id_groups[i].append(seq_id)
+                if len(packed_id_groups[i]) > 1:
+                    seq_id = packed_id_groups[i].pop()
+                    packed_id_groups.append([seq_id])  # new trailing group, so the count grows
                     num_to_move -= 1
                 else:
                     i -= 1
+
+        num_micro_batches = len(packed_id_groups) // self.dp_size
+        for i in range(num_micro_batches):
+            sample_id_groups.append([])
+            for j in range(self.cp_size * self.dp_size):
+                seq_id = i * self.dp_size + j // self.cp_size  # same group for all cp ranks of a replica
+                sample_id_groups[i].append(packed_id_groups[seq_id])
+
         return groups, sample_id_groups
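To sanity-check this hunk end to end, here is a standalone sketch of the same three steps (greedy first-fit packing, rebalancing to a multiple of dp_size, replicating each packed group across CP ranks) on toy sizes; the driver function is illustrative, and unlike the loop above it also flushes the final partial group explicitly:

```python
def group_greedy(sample_id_seqlens, max_len, dp_size, cp_size):
    # 1) Greedy first-fit: keep appending to the current group while it fits.
    packed, cur, cur_len = [], [], 0
    for i, (_, seqlen) in enumerate(sample_id_seqlens):
        if cur_len + seqlen <= max_len:
            cur.append(i)
            cur_len += seqlen
        else:
            packed.append(cur)
            cur, cur_len = [i], seqlen
    if cur:  # flush the last partial group
        packed.append(cur)
    # 2) Rebalance: split samples out of trailing multi-sample groups until
    #    the number of groups is a multiple of dp_size.
    i = len(packed) - 1
    while len(packed) % dp_size != 0:
        assert i > 0, "Not enough samples to move"
        if len(packed[i]) > 1:
            packed.append([packed[i].pop()])
        else:
            i -= 1
    # 3) Replicate each group cp_size times so every CP rank of a DP replica
    #    receives the same packed sequence.
    return [
        [packed[m * dp_size + j // cp_size] for j in range(cp_size * dp_size)]
        for m in range(len(packed) // dp_size)
    ]

# Five (sample_id, seqlen) pairs, max packed length 8, dp_size=2, cp_size=2:
groups = group_greedy([(0, 5), (1, 3), (2, 4), (3, 2), (4, 6)], 8, 2, 2)
print(groups)  # [[[0, 1], [0, 1], [2], [2]], [[4], [4], [3], [3]]]
```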