 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
+from tensorrt_llm._torch.hostfunc import hostfunc
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -782,18 +783,57 @@ def pp_broadcast(self, obj, root=0):
         return ret[0]
 
 
-class PPCommNCCL:
+class PPCommBase:
 
     def __init__(self, global_mapping: Mapping):
         self.mapping = global_mapping
+        self.tensor_ready_event = torch.cuda.Event()
+        self.send_stream = torch.cuda.Stream()
+        self.tensor_cache = {}
+
+    def _cache_tensor(self, tensor: torch.Tensor):
+        cache_id = id(tensor)
+        self.tensor_cache[cache_id] = tensor
+
+    @hostfunc
+    def _release_tensor(self, tensor: torch.Tensor):
+        cache_id = id(tensor)
+        del self.tensor_cache[cache_id]
+
+    @abstractmethod
+    def direct_send(self, tensor: torch.Tensor, dest: int):
+        raise NotImplementedError("direct_send is not implemented")
+
+    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
+        if dest is None:
+            dest = self.mapping.next_pp_rank()
+
+        # The NCCL send kernel launched on send_stream cannot be captured into a
+        # CUDA graph, so during graph capture we send on the current stream instead.
+        if torch.cuda.is_current_stream_capturing():
+            self.direct_send(tensor, dest)
+            return
+
+        self.tensor_ready_event.record()
+        with torch.cuda.stream(self.send_stream):
+            self.tensor_ready_event.wait()
+            # The tensor may be freed before the NCCL send finishes, so cache a
+            # reference first and release it only after the send has completed.
+            self._cache_tensor(tensor)
+            self.direct_send(tensor, dest)
+            self._release_tensor(tensor)
+
+
+class PPCommNCCL(PPCommBase):
+
+    def __init__(self, global_mapping: Mapping):
+        super().__init__(global_mapping)
         self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
             self.mapping.world_size,
             self.mapping.rank,
         )
 
-    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
-        if dest is None:
-            dest = self.mapping.next_pp_rank()
+    def direct_send(self, tensor: torch.Tensor, dest: int):
         self.nccl_comm.send(tensor, dest)
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
@@ -802,21 +842,18 @@ def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         self.nccl_comm.recv(tensor, src)
 
 
-class PPCommTorch:
+class PPCommTorch(PPCommBase):
 
     def __init__(self, global_mapping: Mapping):
-        self.mapping = global_mapping
+        super().__init__(global_mapping)
         self.pg = self.mapping.pp_group_pg
         self.pg_group = self.mapping.pp_group
 
     def _global_to_local_rank(self, global_rank: int):
         assert global_rank in self.pg_group
         return self.pg_group.index(global_rank)
 
-    def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
-        if dest is None:
-            dest = self.mapping.next_pp_rank()
-
+    def direct_send(self, tensor: torch.Tensor, dest: int):
         self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
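
For readers skimming the diff: below is a minimal, self-contained sketch (not part of this change) of the stream-ordering pattern that PPCommBase.send sets up, i.e. record an event on the producer stream, make a dedicated side stream wait on it, and run the transfer there. As a stand-in for the NCCL/ProcessGroup send it uses a device-to-device copy into a staging buffer, and instead of the hostfunc-based cache/release above it relies on Tensor.record_stream(), which keeps the allocation alive until the side-stream work has consumed it.

import torch

def async_side_stream_copy(tensor: torch.Tensor,
                           staging: torch.Tensor,
                           side_stream: torch.cuda.Stream,
                           ready_event: torch.cuda.Event) -> None:
    # Mark the tensor as ready on the producer (current) stream.
    ready_event.record()
    with torch.cuda.stream(side_stream):
        # Order the side-stream copy after the producer's work.
        ready_event.wait()
        # Tell the caching allocator the tensor is in use on side_stream,
        # so its memory is not reused before the copy finishes.
        tensor.record_stream(side_stream)
        staging.copy_(tensor, non_blocking=True)

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    ready_event = torch.cuda.Event()
    x = torch.randn(1024, device="cuda")
    staging = torch.empty_like(x)
    async_side_stream_copy(x, staging, side_stream, ready_event)
    torch.cuda.synchronize()
    assert torch.equal(staging, x)

The diff takes the Python-object route instead: it caches a reference when the send is enqueued and drops it from a host function that runs on send_stream after the send, so the tensor object itself, not just its storage, outlives the asynchronous transfer.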