Commit cd4e639

[None][feat] Async pp send. (#9952)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 4cc4cbe commit cd4e639

File tree: 1 file changed, +47 -10 lines

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 47 additions & 10 deletions
@@ -16,6 +16,7 @@
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
+from tensorrt_llm._torch.hostfunc import hostfunc
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -782,18 +783,57 @@ def pp_broadcast(self, obj, root=0):
         return ret[0]
 
 
-class PPCommNCCL:
+class PPCommBase:
 
     def __init__(self, global_mapping: Mapping):
         self.mapping = global_mapping
+        self.tensor_ready_event = torch.cuda.Event()
+        self.send_stream = torch.cuda.Stream()
+        self.tensor_cache = {}
+
+    def _cache_tensor(self, tensor: torch.Tensor):
+        cache_id = id(tensor)
+        self.tensor_cache[cache_id] = tensor
+
+    @hostfunc
+    def _release_tensor(self, tensor: torch.Tensor):
+        cache_id = id(tensor)
+        del self.tensor_cache[cache_id]
+
+    @abstractmethod
+    def direct_send(self, tensor: torch.Tensor, dest: int):
+        raise NotImplementedError("direct_send is not implemented")
+
+    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
+        if dest is None:
+            dest = self.mapping.next_pp_rank()
+
+        # The NCCL send kernel on send_stream cannot be captured,
+        # so send on the current stream instead when capturing a CUDA graph.
+        if torch.cuda.is_current_stream_capturing():
+            self.direct_send(tensor, dest)
+            return
+
+        self.tensor_ready_event.record()
+        with torch.cuda.stream(self.send_stream):
+            self.tensor_ready_event.wait()
+            # The tensor may be released before the NCCL send has finished,
+            # so cache it first and release it after the send completes.
+            self._cache_tensor(tensor)
+            self.direct_send(tensor, dest)
+            self._release_tensor(tensor)
+
+
+class PPCommNCCL(PPCommBase):
+
+    def __init__(self, global_mapping: Mapping):
+        super().__init__(global_mapping)
         self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
             self.mapping.world_size,
             self.mapping.rank,
         )
 
-    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
-        if dest is None:
-            dest = self.mapping.next_pp_rank()
+    def direct_send(self, tensor: torch.Tensor, dest: int):
         self.nccl_comm.send(tensor, dest)
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
@@ -802,21 +842,18 @@ def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         self.nccl_comm.recv(tensor, src)
 
 
-class PPCommTorch:
+class PPCommTorch(PPCommBase):
 
     def __init__(self, global_mapping: Mapping):
-        self.mapping = global_mapping
+        super().__init__(global_mapping)
         self.pg = self.mapping.pp_group_pg
         self.pg_group = self.mapping.pp_group
 
     def _global_to_local_rank(self, global_rank: int):
         assert global_rank in self.pg_group
         return self.pg_group.index(global_rank)
 
-    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
-        if dest is None:
-            dest = self.mapping.next_pp_rank()
-
+    def direct_send(self, tensor: torch.Tensor, dest: int):
         self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
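
The change above factors the two pipeline-parallel communicators into a shared PPCommBase whose send() records an event on the producer stream, waits on that event on a dedicated send_stream, and issues the backend-specific direct_send there, so the main stream is not blocked by the point-to-point send. Because the caller may drop its reference to the tensor as soon as send() returns, the base class caches the tensor and releases it through a @hostfunc callback enqueued after the send; during CUDA graph capture it falls back to sending on the current stream. Below is a minimal, single-GPU sketch of that pattern built only from public PyTorch APIs: a non_blocking device-to-host copy stands in for the NCCL send, and deferred event polling stands in for the @hostfunc release. All class and variable names in the sketch are illustrative and are not part of this commit.

import torch


class AsyncSideStreamSender:
    """Illustrative sketch: overlap a 'send' on a side stream with main-stream work."""

    def __init__(self):
        self.send_stream = torch.cuda.Stream()        # dedicated stream for the "send"
        self.tensor_ready_event = torch.cuda.Event()  # producer-stream -> send-stream handoff
        self._inflight = []                           # [(done_event, tensor)] kept alive until done

    def _reclaim(self):
        # Drop references only for copies the GPU has already finished.
        self._inflight = [(ev, t) for ev, t in self._inflight if not ev.query()]

    def direct_send(self, tensor, dst_buffer):
        # Stand-in for nccl_comm.send / ProcessGroup.send: an async D2H copy
        # issued on whatever stream is current.
        dst_buffer.copy_(tensor, non_blocking=True)

    def send(self, tensor, dst_buffer):
        if torch.cuda.is_current_stream_capturing():
            # Side-stream work cannot be captured into a CUDA graph, so fall
            # back to the current stream, mirroring the commit's special case.
            self.direct_send(tensor, dst_buffer)
            return

        self._reclaim()
        self.tensor_ready_event.record()              # tensor contents are ready on the producer stream
        with torch.cuda.stream(self.send_stream):
            self.tensor_ready_event.wait()            # order the copy after the producer work
            self.direct_send(tensor, dst_buffer)
            done = torch.cuda.Event()
            done.record()
            # Keep `tensor` referenced until `done` reports the copy finished;
            # otherwise the caller dropping it could let the caching allocator
            # reuse the memory before the side-stream copy runs.
            self._inflight.append((done, tensor))


if __name__ == "__main__" and torch.cuda.is_available():
    sender = AsyncSideStreamSender()
    out = torch.empty(1024, 1024, pin_memory=True)    # pinned destination for a truly async copy
    x = torch.randn(1024, 1024, device="cuda")
    sender.send(x, out)                               # returns without blocking the host
    del x                                             # safe: sender keeps x alive until the copy is done
    torch.cuda.synchronize()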

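Design note: the refactor is a small template-method split. PPCommNCCL and PPCommTorch now implement only the blocking, backend-specific direct_send (the custom NcclCommunicatorOp versus the ProcessGroup send), while the event handling, side-stream scheduling, tensor caching, and the CUDA-graph fallback live once in PPCommBase. The recv paths are untouched by this diff, so only the sending side of the pipeline handoff gains overlap with ongoing compute on the main stream.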