@@ -600,25 +600,6 @@ def _broadcast_cu_seqlens(cu_seqlens):
600600 _broadcast (batch ['loss_mask' ])
601601 _broadcast (batch ['attention_mask' ])
602602
603- def _broadcast_cu_seqlens (cu_seqlens ):
604- dev = torch .cuda .current_device ()
605-
606- n = 0 if cu_seqlens is None else int (cu_seqlens .numel ())
607- n_tensor = torch .tensor (n , dtype = torch .int64 , device = dev )
608- _broadcast (n_tensor )
609-
610- if n == 0 :
611- buf = torch .empty (0 , dtype = torch .int32 , device = dev )
612- else :
613- assert isinstance (cu_seqlens , torch .Tensor )
614- assert cu_seqlens .dtype == torch .int32
615- assert cu_seqlens .shape [0 ] == 1 , "micro-batch-size must be 1 for packing"
616- buf = cu_seqlens .to (device = dev , non_blocking = True ).contiguous ()
617- _broadcast (buf )
618-
619- _broadcast_cu_seqlens (batch ['cu_seqlens' ])
620- _broadcast (batch ['max_seqlen' ])
621-
622603 else :
623604 if args .hybrid_context_parallel :
624605 seq_len = torch .tensor (0 , dtype = torch .int32 , device = torch .cuda .current_device ())
@@ -657,21 +638,15 @@ def _broadcast_cu_seqlens(cu_seqlens):
657638 device = torch .cuda .current_device (),
658639 )
659640 cu_seqlens = None
660- if args .sft :
641+ if args .hybrid_context_parallel or args . sft :
661642 max_seqlen = torch .empty (
662643 1 ,
663644 dtype = torch .int32 ,
664645 device = torch .cuda .current_device (),
665646 )
666647 else :
667648 max_seqlen = None
668-
669- cu_seqlens = None
670- max_seqlen = torch .empty (
671- 1 ,
672- dtype = torch .int32 ,
673- device = torch .cuda .current_device (),
674- ) if args .hybrid_context_parallel else None
649+
675650 local_cp_size = torch .empty (
676651 1 ,
677652 dtype = torch .int32 ,
@@ -726,24 +701,6 @@ def _broadcast_cu_seqlens():
726701 _broadcast (loss_mask )
727702 _broadcast (attention_mask )
728703
729- def _broadcast_cu_seqlens ():
730- dev = torch .cuda .current_device ()
731-
732- n = torch .empty ((), dtype = torch .int64 , device = dev )
733- _broadcast (n )
734- n = int (n .item ())
735-
736- if n == 0 :
737- cu_seqlens = torch .empty (0 , dtype = torch .int32 , device = dev )
738- else :
739- cu_seqlens = torch .empty ((args .micro_batch_size , n ), dtype = torch .int32 , device = dev )
740- _broadcast (cu_seqlens )
741-
742- return cu_seqlens if n > 0 else None
743-
744- cu_seqlens = _broadcast_cu_seqlens ()
745- _broadcast (max_seqlen )
746-
747704 batch = {
748705 'tokens' : tokens ,
749706 'labels' : labels ,
0 commit comments