Commit 1f3afb8

[None][feat] Implement send_object for TorchDist. (#10213)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent ec8a388 commit 1f3afb8

2 files changed, +3 −17 lines

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 1 addition & 12 deletions
@@ -622,8 +622,7 @@ def recv_object(self, src, tag=0):
 
     @log_op
     def send_object(self, obj, dest, tag=0):
-        raise NotImplementedError(
-            "send_object is not implemented for TorchDist")
+        self.isend_object(obj, dest, tag).wait()
 
     @log_op
     def isend_object(self, obj, dest, tag=0):
@@ -640,16 +639,6 @@ def isend_object(self, obj, dest, tag=0):
         works.append(torch.distributed.isend(input_tensor, dst=dest, tag=tag))
         return MultiHandleWrapper(works)
 
-    @log_op
-    def recv_object_from_isend(self, src, tag):
-        size_tensor = torch.tensor([0], dtype=torch.int32)
-        torch.distributed.recv(size_tensor, src=src, tag=tag)
-        bytes_size = size_tensor.item()
-        recv_tensor = torch.empty(bytes_size, dtype=torch.uint8)
-        torch.distributed.recv(recv_tensor, src=src, tag=tag)
-        return _tensor_to_object(recv_tensor, bytes_size,
-                                 torch.distributed.group.WORLD)
-
     @log_op
     def allreduce(self,
                   obj: int | float | torch.Tensor,
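
Context for the change above: send_object on TorchDist now delegates to isend_object and blocks on the returned handle, so the blocking and non-blocking sends share one wire format, an int32 size tensor followed by a uint8 payload tensor (visible in the removed recv_object_from_isend). Below is a minimal sketch of that size-then-payload protocol in plain torch.distributed; the helper names send_object_sketch and recv_object_sketch are hypothetical, and it assumes an already-initialized process group (e.g., the gloo backend) rather than TensorRT-LLM's actual classes.

    import pickle

    import torch
    import torch.distributed as dist


    def send_object_sketch(obj, dest, tag=0):
        # Pickle the object, then send its byte size followed by the payload,
        # mirroring the int32-size / uint8-payload pair used by isend_object.
        payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.int32)
        dist.send(size, dst=dest, tag=tag)
        dist.send(payload, dst=dest, tag=tag)


    def recv_object_sketch(src, tag=0):
        # Receive the size first, allocate a matching buffer, then unpickle.
        size = torch.tensor([0], dtype=torch.int32)
        dist.recv(size, src=src, tag=tag)
        buf = torch.empty(int(size.item()), dtype=torch.uint8)
        dist.recv(buf, src=src, tag=tag)
        return pickle.loads(buf.numpy().tobytes())

Matching tags on both ends keep concurrent transfers from pairing with the wrong message, which is why the executor threads a tag through every call.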

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 5 deletions
@@ -22,7 +22,7 @@
 from tensorrt_llm._torch.pyexecutor.resource_manager import (
     ResourceManagerType, request_context)
 from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
-                                 mpi_disabled, nvtx_range, trace_func)
+                                 nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
@@ -229,7 +229,6 @@ def __init__(self,
         self.num_scheduled_requests: int = 0
         self.benchmark_req_queues_size = int(
             os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
-        self._disable_mpi = mpi_disabled()
 
         # list of requests in each PP micro batch
         self.num_micro_batches = self.dist.pp_size
@@ -1094,11 +1093,9 @@ def _executor_loop_pp(self):
                 if previous_batch is not None:
                     sample_state = previous_batch.sample_state
                     if not self.dist.is_last_pp_rank:
-                        recv_object_funct = self.dist.recv_object_from_isend if self._disable_mpi \
-                            else self.dist.recv_object
                         # Receive tokens from previous pp rank (w.r.t model forward direction)
                         with nvtx_range("recv_sample_state"):
-                            sample_state.host = recv_object_funct(
+                            sample_state.host = self.dist.recv_object(
                                 src=self.dist.prev_pp_rank,
                                 tag=tag,
                             )
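
With send_object implemented, the executor no longer needs to branch on mpi_disabled(): recv_object becomes the single receive path on both the MPI and the torch-only backends, and the recv_object_from_isend fallback removed above is redundant. A minimal sketch of the resulting symmetric pattern, with hypothetical names rank, comm, sender_rank, receiver_rank, and payload_obj (the real call site is in _executor_loop_pp):

    # Both backends now pair the same way: the sender blocks in send_object
    # (isend_object(...).wait() under the hood), the receiver in recv_object.
    if rank == sender_rank:
        comm.send_object(payload_obj, dest=receiver_rank, tag=tag)
    else:
        payload_obj = comm.recv_object(src=sender_rank, tag=tag)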
