Commit defbd4e

make broadcast methods abstract
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 4b78b48

1 file changed, 32 insertions(+), 29 deletions(-)

1 file changed

+32
-29
lines changed

tensorrt_llm/_torch/distributed/communicator.py

@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
 
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
@@ -407,30 +427,26 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
-    def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.cp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root, chunk_size=chunk_size)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root, chunk_size=chunk_size)
-        return obj
-
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
     def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
@@ -715,7 +731,7 @@ def tp_gather(self, obj, dst=0):
         return output_list
 
     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -729,7 +745,7 @@ def tp_broadcast(self, obj, root=0):
         return ret[0]
 
     @log_op
-    def cp_broadcast(self, obj, root=0):
+    def cp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
             return obj
@@ -742,19 +758,6 @@ def cp_broadcast(self, obj, root=0):
                            device=torch.device("cpu"))
         return ret[0]
 
-    @log_op
-    def tp_cp_broadcast(self, obj, root=0):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root)
-        return obj
-
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):

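For context on the resulting structure: after this change the base communicator owns the shared TP-then-CP composition, while each backend only supplies the single-group broadcasts. Below is a minimal, runnable sketch of that template-method shape; the names BaseCommunicator and EchoCommunicator are illustrative stand-ins, not classes from tensorrt_llm.

from abc import ABC, abstractmethod


class BaseCommunicator(ABC):
    """Sketch of a communicator base class shaped like the one in this diff."""

    def __init__(self, tp_size: int, cp_size: int):
        self.tp_size = tp_size
        self.cp_size = cp_size

    @abstractmethod
    def tp_broadcast(self, obj, root=0, **kwargs):
        """Broadcast within the tensor-parallel group (backend-specific)."""

    @abstractmethod
    def cp_broadcast(self, obj, root=0, **kwargs):
        """Broadcast within the context-parallel group (backend-specific)."""

    def tp_cp_broadcast(self, obj, root=0, **kwargs):
        """Shared composition: broadcast in the TP group first, then the CP group."""
        if self.tp_size > 1:
            obj = self.tp_broadcast(obj, root=root, **kwargs)
        if self.cp_size > 1:
            obj = self.cp_broadcast(obj, root=root, **kwargs)
        return obj


class EchoCommunicator(BaseCommunicator):
    """Single-process stand-in: each 'broadcast' returns the object as-is.

    A real subclass would wrap MPI or torch.distributed here; the **kwargs
    pass-through is what lets one backend accept options such as chunk_size
    while another silently ignores them.
    """

    def tp_broadcast(self, obj, root=0, **kwargs):
        return obj

    def cp_broadcast(self, obj, root=0, **kwargs):
        return obj


# Usage: the composed broadcast runs both stages and forwards extra kwargs.
comm = EchoCommunicator(tp_size=2, cp_size=2)
assert comm.tp_cp_broadcast({"step": 1}, chunk_size=4 * 1024 * 1024) == {"step": 1}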