@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
 
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
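For illustration, here is a minimal, self-contained sketch of the template-method shape this hunk introduces. The class names (_BaseComm, _TraceComm) and the recording hook bodies are hypothetical; only the abstract tp_broadcast/cp_broadcast hooks, the tp_size/cp_size gating, and the TP-then-CP order come from the diff. The **kwargs on the hooks is what lets backend-specific options such as chunk_size flow through the shared default unchanged.

from abc import ABC, abstractmethod

class _BaseComm(ABC):
    """Hypothetical stand-in for the abstract communicator in this file."""

    def __init__(self, tp_size, cp_size):
        self.tp_size = tp_size
        self.cp_size = cp_size

    @abstractmethod
    def tp_broadcast(self, obj, root=0, **kwargs):
        pass

    @abstractmethod
    def cp_broadcast(self, obj, root=0, **kwargs):
        pass

    def tp_cp_broadcast(self, obj, root=0, **kwargs):
        # Shared default: TP group first, then CP group; each leg is
        # skipped when that dimension is not actually parallel.
        if self.tp_size > 1:
            obj = self.tp_broadcast(obj, root=root, **kwargs)
        if self.cp_size > 1:
            obj = self.cp_broadcast(obj, root=root, **kwargs)
        return obj

class _TraceComm(_BaseComm):
    """Toy backend that records the call order instead of communicating."""

    def __init__(self, tp_size, cp_size):
        super().__init__(tp_size, cp_size)
        self.calls = []

    def tp_broadcast(self, obj, root=0, **kwargs):
        self.calls.append(("tp", root, kwargs))
        return obj

    def cp_broadcast(self, obj, root=0, **kwargs):
        self.calls.append(("cp", root, kwargs))
        return obj

comm = _TraceComm(tp_size=2, cp_size=2)
comm.tp_cp_broadcast({"step": 1}, root=0, chunk_size=1 << 20)
assert comm.calls == [("tp", 0, {"chunk_size": 1 << 20}),
                      ("cp", 0, {"chunk_size": 1 << 20})]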
@@ -407,30 +427,26 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
-    def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.cp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root, chunk_size=chunk_size)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root, chunk_size=chunk_size)
-        return obj
-
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
     def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
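Usage-wise, the point of widening these signatures is that the inherited tp_cp_broadcast default can now forward backend-specific options verbatim. A short sketch, assuming comm is an instance of this MPI-backed class on every rank of the TP×CP grid:

# chunk_size travels through tp_cp_broadcast's **kwargs into both
# safe_broadcast calls, so a large payload is chunked identically
# on the TP leg and the CP leg.
payload = {"weights_version": 3, "blob": b"\x00" * (16 * 1024 * 1024)}
payload = comm.tp_cp_broadcast(payload, root=0, chunk_size=8 * 1024 * 1024)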
@@ -726,7 +742,7 @@ def tp_gather(self, obj, dst=0):
         return output_list
 
     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -740,7 +756,7 @@ def tp_broadcast(self, obj, root=0):
             return ret[0]
 
     @log_op
-    def cp_broadcast(self, obj, root=0):
+    def cp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
             return obj
@@ -753,19 +769,6 @@ def cp_broadcast(self, obj, root=0):
                                        device=torch.device("cpu"))
             return ret[0]
 
-    @log_op
-    def tp_cp_broadcast(self, obj, root=0):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root)
-        return obj
-
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
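For reference, a hedged standalone sketch of the broadcast_object_list idiom visible in the context lines above (the one-slot list, ret[0], device=torch.device("cpu")). The helper name and the group setup are assumptions, and root is treated as a global rank here, which is what dist.broadcast_object_list expects for src.

import torch
import torch.distributed as dist

def _broadcast_pyobj(obj, root, group):
    # Only the source rank seeds the one-slot list; broadcast_object_list
    # pickles that entry, broadcasts it over `group`, and unpickles it
    # into every other rank's slot, so ret[0] is valid on all ranks after.
    ret = [obj if dist.get_rank() == root else None]
    dist.broadcast_object_list(ret,
                               src=root,
                               group=group,
                               device=torch.device("cpu"))
    return ret[0]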