Skip to content

Commit fcda1a1

Browse files
authored
[None][fix] disable async pp send for ray cases. (#9959)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent f6b0ddd commit fcda1a1

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

jenkins/L0_MergeRequest.groovy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
712712
"tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py",
713713
"tensorrt_llm/_torch/custom_ops/torch_custom_ops.py",
714714
"tensorrt_llm/_torch/custom_ops/userbuffers_custom_ops.py",
715+
"tensorrt_llm/_torch/distributed/",
715716
"tensorrt_llm/_torch/models/modeling_llama.py",
716717
"tensorrt_llm/_torch/models/modeling_qwen3_next.py",
717718
"tensorrt_llm/_torch/modules/fused_moe/",

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,13 @@ def _global_to_local_rank(self, global_rank: int):
856856
def direct_send(self, tensor: torch.Tensor, dest: int):
857857
self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
858858

859+
# TODO: support async pp send for PPCommTorch
860+
def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
861+
if dest is None:
862+
dest = self.mapping.next_pp_rank()
863+
864+
self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
865+
859866
def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
860867
if src is None:
861868
src = self.mapping.prev_pp_rank()

0 commit comments

Comments
 (0)