
Commit 2481d0d

make broadcast methods abstract
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 746512c

1 file changed (+32, -29 lines)

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 32 additions & 29 deletions
@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
 
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
@@ -407,30 +427,26 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
-    def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.cp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root, chunk_size=chunk_size)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root, chunk_size=chunk_size)
-        return obj
-
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
     def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
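The MPI-backed methods keep their chunk_size parameter (forwarded to safe_broadcast) and additionally accept **kwargs, so calls routed through the shared tp_cp_broadcast can carry backend-specific keywords without breaking other implementations. As a rough illustration of what a chunked broadcast of a pickled object involves (the real safe_broadcast is not part of this diff, so this toy helper is only a guess at the idea behind chunk_size):

import pickle


def split_into_chunks(payload: bytes, chunk_size: int = 4 * 1024 * 1024):
    # Toy helper: slice a pickled payload into fixed-size chunks. A real
    # implementation would broadcast the chunk count and then each chunk.
    return [payload[i:i + chunk_size] for i in range(0, len(payload), chunk_size)]


blob = pickle.dumps(list(range(100_000)))
chunks = split_into_chunks(blob, chunk_size=64 * 1024)
assert b"".join(chunks) == blob
print(f"{len(blob)} bytes split into {len(chunks)} chunks")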
@@ -726,7 +742,7 @@ def tp_gather(self, obj, dst=0):
         return output_list
 
     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -740,7 +756,7 @@ def tp_broadcast(self, obj, root=0):
         return ret[0]
 
     @log_op
-    def cp_broadcast(self, obj, root=0):
+    def cp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
             return obj
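On the torch.distributed side, tp_broadcast and cp_broadcast only gain a **kwargs catch-all: this backend does not chunk, so keywords such as chunk_size that the shared tp_cp_broadcast may forward are simply absorbed and ignored. A small sketch of that calling convention (hypothetical class names, not the real backends):

class ChunkingBackend:
    # Consumes chunk_size, like the MPI-based implementation above.
    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024, **kwargs):
        print(f"chunked broadcast, chunk_size={chunk_size}")
        return obj


class DirectBackend:
    # Ignores chunk_size, like the torch.distributed implementation: it lands in kwargs.
    def tp_broadcast(self, obj, root=0, **kwargs):
        print("direct broadcast, extra kwargs ignored:", kwargs)
        return obj


for backend in (ChunkingBackend(), DirectBackend()):
    backend.tp_broadcast([1, 2, 3], root=0, chunk_size=1024)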
@@ -753,19 +769,6 @@ def cp_broadcast(self, obj, root=0):
                               device=torch.device("cpu"))
         return ret[0]
 
-    @log_op
-    def tp_cp_broadcast(self, obj, root=0):
-        """Broadcast object across both TP and CP groups.
-
-        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
-        First broadcasts within the TP group, then within the CP group.
-        """
-        if self.tp_size > 1:
-            obj = self.tp_broadcast(obj, root=root)
-        if self.cp_size > 1:
-            obj = self.cp_broadcast(obj, root=root)
-        return obj
-
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
