@@ -48,6 +48,10 @@ def task():
     '''
 
     state = None
+    # Global MPICommExecutor instance to be reused across multiple MpiCommSession instances.
+    # This is necessary because MPICommExecutor can only be created once per MPI process.
+    _global_comm_executor = None
+    _global_mpi_pool = None
 
     @staticmethod
     def is_initialized() -> bool:
@@ -183,6 +187,7 @@ def __init__(self, comm=None, n_workers: int = 1):
         self.n_workers = n_workers
         self.thread_pool: Optional[ThreadPoolExecutor] = None
         self.mpi_pool: Optional[MPIPoolExecutor] = None
+        self.owns_mpi_pool = False  # Track if this instance owns the mpi_pool
 
         if n_workers <= 0:
             raise ValueError(
@@ -230,9 +235,11 @@ def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
         return [future.result() for future in futures]
 
     def shutdown(self, wait=True):
-        if self.mpi_pool is not None:
+        # Only shut down the mpi_pool if this instance created it;
+        # a shared global mpi_pool is left running.
+        if self.mpi_pool is not None and self.owns_mpi_pool:
             self.mpi_pool.shutdown(wait=wait)
-            self.mpi_pool = None
+        self.mpi_pool = None
         if self.thread_pool is not None:
             self.thread_pool.shutdown(wait=wait)
             self.thread_pool = None
@@ -244,8 +251,36 @@ def _start_mpi_pool(self):
         assert not self.mpi_pool, 'MPI session already started'
 
         self.thread_pool = ThreadPoolExecutor(max_workers=2)
-        comm_executor = MPICommExecutor(self.comm)
-        self.mpi_pool = comm_executor.__enter__()
+
+        # Use the global MPICommExecutor if using COMM_WORLD.
+        # This is necessary because MPICommExecutor can only be created once per MPI process.
+        logger_debug(
+            f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
+            "grey")
+        if ENABLE_MULTI_DEVICE:
+            logger_debug(
+                f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
+                "grey")
+        if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
+            if MPINodeState._global_comm_executor is None:
+                logger_debug("Creating global MPICommExecutor for COMM_WORLD\n",
+                             "yellow")
+                MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
+                MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__(
+                )
+            else:
+                logger_debug("Reusing global MPICommExecutor for COMM_WORLD\n",
+                             "yellow")
+            self.mpi_pool = MPINodeState._global_mpi_pool
+            self.owns_mpi_pool = False
+        else:
+            logger_debug(
+                "_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
+                "grey")
+            # For non-COMM_WORLD communicators, create a new executor.
+            comm_executor = MPICommExecutor(self.comm)
+            self.mpi_pool = comm_executor.__enter__()
+            self.owns_mpi_pool = True
 
     def __del__(self):
         self.shutdown_abort()
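Outside the diff, the reuse-or-create logic above can be illustrated with a small standalone sketch. It assumes mpi4py is installed and the script runs under an MPI launcher; get_executor is an illustrative helper, not part of this patch.

# Sketch of the reuse-or-create pattern for MPICommExecutor (assumes mpi4py).
from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

_global_comm_executor = None  # context manager, kept open for the process lifetime
_global_mpi_pool = None       # executor returned by __enter__() on the root rank

def get_executor(comm):
    """Return (pool, owns_pool); COMM_WORLD shares one process-global pool."""
    global _global_comm_executor, _global_mpi_pool
    if comm == MPI.COMM_WORLD:
        if _global_comm_executor is None:
            _global_comm_executor = MPICommExecutor(comm)
            # On non-root ranks __enter__() returns None; they act as workers.
            _global_mpi_pool = _global_comm_executor.__enter__()
        return _global_mpi_pool, False  # shared pool: caller must not shut it down
    # Sub-communicators get a private executor that the caller owns and shuts down.
    private_executor = MPICommExecutor(comm)
    return private_executor.__enter__(), True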
@@ -264,9 +299,35 @@ class RemoteTask(NamedTuple):
 class RemoteMpiCommSessionClient(MpiSession):
     '''
     RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.
+
+    Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
+    one connection at a time. Multiple LLM instances will reuse the same client connection.
     '''
+    _global_instance = None
+    _global_instance_lock = threading.Lock()
+
+    def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
+        # Implement a singleton pattern to reuse the same client connection for
+        # multiple LLM instances, since PAIR sockets only support one connection.
+        with cls._global_instance_lock:
+            if cls._global_instance is None or cls._global_instance.addr != addr:
+                logger_debug(
+                    f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+                instance = super().__new__(cls)
+                cls._global_instance = instance
+                instance._initialized = False
+            else:
+                logger_debug(
+                    f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+            return cls._global_instance
 
     def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
+        # Only initialize once.
+        if self._initialized:
+            return
+
         # FIXME: this is a hack to avoid circular import, resolve later
         from tensorrt_llm.executor.ipc import ZeroMqQueue
         self.addr = addr
@@ -277,6 +338,7 @@ def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
                                  socket_type=zmq.PAIR,
                                  use_hmac_encryption=bool(hmac_key))
         self._is_shutdown = False
+        self._initialized = True
 
     def submit(self,
                task: Callable[..., T],
@@ -329,10 +391,16 @@ def abort(self):
         self.shutdown()
 
     def shutdown(self, wait=True):
-        pass
+        # NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
+        # The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
+        # LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
+        # using it. The connection stays open for the entire lifetime of the mgmn setup.
+        logger_debug(
+            "RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
+            "grey")
 
     def shutdown_abort(self, grace: float = 60, reason=None):
-        pass
+        self.shutdown()
 
 
 class RemoteMpiCommSessionServer():
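The __new__/_initialized combination used by RemoteMpiCommSessionClient reduces to the following standalone sketch; the Client class and its tcp address are illustrative stand-ins, not names from the patch.

# Sketch of the per-address singleton: __new__ caches one instance and __init__
# returns early on reuse, so the underlying connection is opened only once.
import threading

class Client:
    _global_instance = None
    _global_instance_lock = threading.Lock()

    def __new__(cls, addr: str):
        with cls._global_instance_lock:
            if cls._global_instance is None or cls._global_instance.addr != addr:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._global_instance = instance
            return cls._global_instance

    def __init__(self, addr: str):
        if self._initialized:  # already constructed earlier; keep the open connection
            return
        self.addr = addr       # the real class opens its ZeroMQ PAIR socket here
        self._initialized = True

a = Client("tcp://127.0.0.1:5555")
b = Client("tcp://127.0.0.1:5555")
assert a is b  # both names refer to the same shared client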
@@ -393,7 +461,26 @@ def task_wrapper(task: Callable[..., T], *args, **kwargs) -> T:
     def serve(self):
         logger_debug(f"RemoteMpiCommSessionServer listening on {self.addr}\n",
                      "yellow")
+        pending_futures = []
         while True:
+            # Wait for any pending futures from previous tasks to complete.
+            # This ensures all ranks are ready before accepting the next task.
+            if pending_futures:
+                logger_debug(
+                    f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
+                    "grey")
+                for future in pending_futures:
+                    try:
+                        future.result()  # Wait for completion
+                    except Exception as e:
+                        print_colored(
+                            f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
+                            "red")
+                pending_futures.clear()
+                logger_debug(
+                    "RemoteMpiCommSessionServer all pending futures completed\n",
+                    "grey")
+
             message: Optional[RemoteTask] = self.queue.get()
             if message is None:
                 logger_debug(
@@ -410,6 +497,8 @@ def serve(self):
                     *message.args, **message.kwargs)
                 self.num_results = self.session.n_workers
                 assert len(futures) == self.num_results == mpi_world_size()
+                # Store futures to wait for them before the next task.
+                pending_futures = list(futures)
                 if message.sync:
                     for future in futures:
                         future.add_done_callback(self.mpi_future_callback)
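The drain-then-accept loop added to serve() has the following shape as a standalone sketch; the queue, pool, and n_workers here use concurrent.futures and are illustrative, not the MPI session from the patch.

# Sketch of the serve() loop: wait for all futures from the previous task
# before taking the next message, so every worker is idle again.
import queue
from concurrent.futures import ThreadPoolExecutor

def serve(task_queue: queue.Queue, pool: ThreadPoolExecutor, n_workers: int):
    pending_futures = []
    while True:
        for future in pending_futures:
            try:
                future.result()  # block until the worker finishes
            except Exception as e:
                print(f"worker task failed: {e}")
        pending_futures.clear()

        message = task_queue.get()
        if message is None:  # shutdown signal
            break
        task, args = message
        futures = [pool.submit(task, *args) for _ in range(n_workers)]
        pending_futures = list(futures)  # drained at the top of the next iteration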