@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
 
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
@@ -407,30 +427,26 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
-    def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.cp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root, chunk_size=chunk_size)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root, chunk_size=chunk_size)
-        return obj
-
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
     def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
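The `**kwargs` added to both signatures here is what lets the single base-class `tp_cp_broadcast` forward extra arguments blindly: this MPI-backed implementation consumes `chunk_size`, while the `torch.distributed` implementations below simply ignore it instead of raising `TypeError`. A sketch of that contract; the function names are illustrative, not from the diff:

```python
def mpi_style(obj, root=0, chunk_size: int = 4 * 1024 * 1024, **kwargs):
    # Chunked broadcast path: chunk_size is a real parameter here.
    return obj

def torchdist_style(obj, root=0, **kwargs):
    # No chunking on this path; chunk_size lands in **kwargs and is dropped.
    return obj

# Both accept the same call, so the shared caller needs no isinstance checks:
for bcast in (mpi_style, torchdist_style):
    bcast({"params": [1, 2, 3]}, root=0, chunk_size=8 * 1024 * 1024)
```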
@@ -715,7 +731,7 @@ def tp_gather(self, obj, dst=0):
         return output_list
 
     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -729,7 +745,7 @@ def tp_broadcast(self, obj, root=0):
         return ret[0]
 
     @log_op
-    def cp_broadcast(self, obj, root=0):
+    def cp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
             return obj
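For non-tensor objects, the context lines around these methods (`ret[0]`, `device=torch.device("cpu")`) suggest the usual one-element object-list broadcast idiom. A hedged sketch of that idiom, assuming a CPU-staged `torch.distributed.broadcast_object_list`; the helper name and group wiring are hypothetical, not from this diff:

```python
import torch
import torch.distributed as dist

def _object_broadcast(obj, root=0, group=None):
    # Root wraps the object; other ranks supply a placeholder slot.
    ret = [obj] if dist.get_rank() == root else [None]
    # Pickled payload is staged through a CPU tensor broadcast.
    dist.broadcast_object_list(ret, src=root, group=group,
                               device=torch.device("cpu"))
    return ret[0]
```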
@@ -742,19 +758,6 @@ def cp_broadcast(self, obj, root=0):
                                          device=torch.device("cpu"))
         return ret[0]
 
-    @log_op
-    def tp_cp_broadcast(self, obj, root=0):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root)
-        return obj
-
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
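With this duplicated override deleted, both backends share the base-class `tp_cp_broadcast`. One behavioral nuance worth noting: the removed copy carried `@log_op`, while the inherited version is itself unlogged, though the inner `tp_broadcast`/`cp_broadcast` calls still are. A tiny dependency-free simulation of why TP-then-CP reaches every rank in a `tp_size × cp_size` grid; the grid layout of ranks is an assumption for illustration:

```python
tp_size, cp_size = 2, 4
reached = {(0, 0)}                                  # root rank holds the object
reached |= {(tp, 0) for tp in range(tp_size)}       # step 1: TP-group broadcast
reached |= {(tp, cp)                                # step 2: each TP rank's
            for tp, _ in list(reached)              #         CP-group broadcast
            for cp in range(cp_size)}
assert len(reached) == tp_size * cp_size            # every rank is covered
```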