Commit 1551ed8

[https://nvbugs/5437384][test] CHERRY-PICK: fix trtllm-llmapi-launch multi tests (#8567)
Signed-off-by: Superjomn <[email protected]>
1 parent 4c5a8f4 commit 1551ed8

5 files changed: +231 additions, -6 deletions

tensorrt_llm/llmapi/mpi_session.py

Lines changed: 95 additions & 6 deletions
@@ -48,6 +48,10 @@ def task():
     '''
 
     state = None
+    # Global MPICommExecutor instance to be reused across multiple MpiCommSession instances
+    # This is necessary because MPICommExecutor can only be created once per MPI process
+    _global_comm_executor = None
+    _global_mpi_pool = None
 
     @staticmethod
     def is_initialized() -> bool:
@@ -183,6 +187,7 @@ def __init__(self, comm=None, n_workers: int = 1):
         self.n_workers = n_workers
         self.thread_pool: Optional[ThreadPoolExecutor] = None
         self.mpi_pool: Optional[MPIPoolExecutor] = None
+        self.owns_mpi_pool = False  # Track if this instance owns the mpi_pool
 
         if n_workers <= 0:
             raise ValueError(
@@ -230,9 +235,11 @@ def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
         return [future.result() for future in futures]
 
     def shutdown(self, wait=True):
-        if self.mpi_pool is not None:
+        # Only shutdown the mpi_pool if this instance created it
+        # For shared global mpi_pool, we don't shut it down
+        if self.mpi_pool is not None and self.owns_mpi_pool:
             self.mpi_pool.shutdown(wait=wait)
-            self.mpi_pool = None
+        self.mpi_pool = None
         if self.thread_pool is not None:
             self.thread_pool.shutdown(wait=wait)
             self.thread_pool = None
@@ -244,8 +251,36 @@ def _start_mpi_pool(self):
         assert not self.mpi_pool, 'MPI session already started'
 
         self.thread_pool = ThreadPoolExecutor(max_workers=2)
-        comm_executor = MPICommExecutor(self.comm)
-        self.mpi_pool = comm_executor.__enter__()
+
+        # Use global MPICommExecutor if using COMM_WORLD
+        # This is necessary because MPICommExecutor can only be created once per MPI process
+        logger_debug(
+            f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
+            "grey")
+        if ENABLE_MULTI_DEVICE:
+            logger_debug(
+                f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
+                "grey")
+        if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
+            if MPINodeState._global_comm_executor is None:
+                logger_debug("Creating global MPICommExecutor for COMM_WORLD\n",
+                             "yellow")
+                MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
+                MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__(
+                )
+            else:
+                logger_debug("Reusing global MPICommExecutor for COMM_WORLD\n",
+                             "yellow")
+            self.mpi_pool = MPINodeState._global_mpi_pool
+            self.owns_mpi_pool = False
+        else:
+            logger_debug(
+                f"_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
+                "grey")
+            # For non-COMM_WORLD communicators, create a new executor
+            comm_executor = MPICommExecutor(self.comm)
+            self.mpi_pool = comm_executor.__enter__()
+            self.owns_mpi_pool = True
 
     def __del__(self):
         self.shutdown_abort()
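
The hunk above is the core of the COMM_WORLD fix: the MPICommExecutor is created at most once per process, cached on MPINodeState, and later MpiCommSession instances reuse it, while owns_mpi_pool records whether an instance is allowed to shut the pool down. A minimal sketch of that create-once-and-share pattern, using a ThreadPoolExecutor as a stand-in for MPICommExecutor so it runs without an MPI launch (ExecutorCache and its members are illustrative names, not TensorRT-LLM APIs):

from concurrent.futures import ThreadPoolExecutor
from typing import Optional


class ExecutorCache:
    """Illustrative stand-in for the MPICommExecutor reuse in _start_mpi_pool."""

    # Created at most once per process, then shared by every instance.
    _global_pool: Optional[ThreadPoolExecutor] = None

    def __init__(self) -> None:
        self.pool: Optional[ThreadPoolExecutor] = None
        self.owns_pool = False  # only the owner may shut the pool down

    def start(self, shared: bool = True) -> None:
        if shared:
            # COMM_WORLD-style path: create the shared pool on first use,
            # afterwards every instance just reuses it and never owns it.
            if ExecutorCache._global_pool is None:
                ExecutorCache._global_pool = ThreadPoolExecutor(max_workers=2)
            self.pool = ExecutorCache._global_pool
            self.owns_pool = False
        else:
            # Private-communicator path: a fresh pool owned by this instance.
            self.pool = ThreadPoolExecutor(max_workers=2)
            self.owns_pool = True

    def shutdown(self) -> None:
        # Mirror of MpiCommSession.shutdown: never tear down a shared pool.
        if self.pool is not None and self.owns_pool:
            self.pool.shutdown(wait=True)
        self.pool = None


a, b = ExecutorCache(), ExecutorCache()
a.start()
b.start()
assert a.pool is b.pool  # the second instance reuses the first one's pool
a.shutdown()             # no-op for the shared pool, as in the diff

In the real diff the non-COMM_WORLD branch still creates its own executor and sets owns_mpi_pool = True, so private communicators keep their original teardown behavior.
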
@@ -264,9 +299,35 @@ class RemoteTask(NamedTuple):
 class RemoteMpiCommSessionClient(MpiSession):
     '''
     RemoteMpiCommSessionClient is a variant of MpiCommSession that is used to connect to a remote MPI pool.
+
+    Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
+    one connection at a time. Multiple LLM instances will reuse the same client connection.
     '''
+    _global_instance = None
+    _global_instance_lock = threading.Lock()
+
+    def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
+        # Implement singleton pattern to reuse the same client connection
+        # for multiple LLM instances, since PAIR sockets only support one connection
+        with cls._global_instance_lock:
+            if cls._global_instance is None or cls._global_instance.addr != addr:
+                logger_debug(
+                    f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+                instance = super().__new__(cls)
+                cls._global_instance = instance
+                instance._initialized = False
+            else:
+                logger_debug(
+                    f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
+                    "yellow")
+            return cls._global_instance
 
     def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
+        # Only initialize once
+        if self._initialized:
+            return
+
         # FIXME: this is a hack to avoid circular import, resolve later
         from tensorrt_llm.executor.ipc import ZeroMqQueue
         self.addr = addr
@@ -277,6 +338,7 @@ def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
             socket_type=zmq.PAIR,
             use_hmac_encryption=bool(hmac_key))
         self._is_shutdown = False
+        self._initialized = True
 
     def submit(self,
               task: Callable[..., T],
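
Because Python always calls __init__ on whatever object __new__ returns, the cached instance would otherwise be re-initialized (and its ZeroMQ queue recreated) on every construction; the _initialized flag set in __new__ and checked in __init__ prevents that. A stripped-down sketch of the same pattern, with no ZeroMQ involved (SingleConnectionClient and its fields are illustrative names, not part of TensorRT-LLM):

import threading
from typing import Optional


class SingleConnectionClient:
    """Illustrative sketch of the __new__/__init__ singleton used above."""

    _global_instance: Optional["SingleConnectionClient"] = None
    _global_instance_lock = threading.Lock()

    def __new__(cls, addr: str):
        with cls._global_instance_lock:
            if cls._global_instance is None or cls._global_instance.addr != addr:
                instance = super().__new__(cls)
                instance._initialized = False  # let __init__ run once for this instance
                cls._global_instance = instance
            return cls._global_instance

    def __init__(self, addr: str):
        # __init__ runs on every construction, even when __new__ returned the
        # cached instance, so this guard makes re-initialization a no-op.
        if self._initialized:
            return
        self.addr = addr
        self._initialized = True


a = SingleConnectionClient("ipc:///tmp/sock")
b = SingleConnectionClient("ipc:///tmp/sock")
assert a is b  # the same underlying connection is reused
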
@@ -329,10 +391,16 @@ def abort(self):
         self.shutdown()
 
     def shutdown(self, wait=True):
-        pass
+        # NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
+        # The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
+        # LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
+        # using it. The connection stays open for the entire lifetime of the mgmn setup.
+        logger_debug(
+            f"RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
+            "grey")
 
     def shutdown_abort(self, grace: float = 60, reason=None):
-        pass
+        self.shutdown()
 
 
 class RemoteMpiCommSessionServer():
@@ -393,7 +461,26 @@ def task_wrapper(task: Callable[..., T], *args, **kwargs) -> T:
     def serve(self):
         logger_debug(f"RemoteMpiCommSessionServer listening on {self.addr}\n",
                      "yellow")
+        pending_futures = []
         while True:
+            # Wait for any pending futures from previous tasks to complete
+            # This ensures all ranks are ready before accepting the next task
+            if pending_futures:
+                logger_debug(
+                    f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
+                    "grey")
+                for future in pending_futures:
+                    try:
+                        future.result()  # Wait for completion
+                    except Exception as e:
+                        print_colored(
+                            f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
+                            "red")
+                pending_futures.clear()
+                logger_debug(
+                    "RemoteMpiCommSessionServer all pending futures completed\n",
+                    "grey")
+
             message: Optional[RemoteTask] = self.queue.get()
             if message is None:
                 logger_debug(
@@ -410,6 +497,8 @@ def serve(self):
                 *message.args, **message.kwargs)
             self.num_results = self.session.n_workers
             assert len(futures) == self.num_results == mpi_world_size()
+            # Store futures to wait for them before the next task
+            pending_futures = list(futures)
             if message.sync:
                 for future in futures:
                     future.add_done_callback(self.mpi_future_callback)
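
The serve() loop now keeps the futures of the previous task and drains them before dequeuing the next RemoteTask, so every rank has finished the old workload before a new LLM instance submits work. A condensed sketch of that drain-then-dequeue pattern, with a ThreadPoolExecutor and a plain Queue standing in for the MPI session and the ZeroMQ queue (serve, task_queue, pool, and n_workers are illustrative names):

from concurrent.futures import Future, ThreadPoolExecutor
from queue import Queue
from typing import Callable, List, Optional


def serve(task_queue: "Queue[Optional[Callable[[], None]]]",
          pool: ThreadPoolExecutor, n_workers: int) -> None:
    """Drain the futures of the previous task before accepting the next one."""
    pending: List[Future] = []
    while True:
        # Block until every worker finished the previous task, so all ranks
        # are ready again before the next task is dequeued.
        for fut in pending:
            try:
                fut.result()
            except Exception as exc:  # keep serving even if one worker failed
                print(f"worker task failed: {exc}")
        pending.clear()

        task = task_queue.get()  # None acts as the shutdown sentinel
        if task is None:
            break
        # Fan the task out to all workers and remember the futures for the
        # next iteration, mirroring `pending_futures = list(futures)` above.
        pending = [pool.submit(task) for _ in range(n_workers)]
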

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
+  # llmapi
+  - unittest/llmapi/test_mpi_session.py::test_llmapi_launch_multiple_tasks
 - condition:
     ranges:
       system_gpu_count:

tests/unittest/llmapi/_run_multi_llm_tasks.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import os
+import sys
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi.utils import print_colored
+
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.llm_data import llm_models_root
+# isort: on
+
+model_path = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"
+
+
+def run_llm_tp2():
+    with LLM(model=model_path, tensor_parallel_size=2) as llm:
+        sampling_params = SamplingParams(max_tokens=10, end_id=-1)
+        for output in llm.generate(["Hello, my name is"], sampling_params):
+            print(output)
+
+
+def run_multi_llm_tasks():
+    for i in range(3):
+        print_colored(f"Running LLM task {i}\n", "green")
+        run_llm_tp2()
+        print_colored(f"LLM task {i} completed\n", "green")
+
+
+if __name__ == "__main__":
+    run_multi_llm_tasks()
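
This helper is what the new unit test launches via mpirun -n 2 --allow-run-as-root trtllm-llmapi-launch python3 _run_multi_llm_tasks.py; constructing three TP-2 LLM instances back to back is exactly the scenario that requires the process-wide MPICommExecutor and the reusable PAIR-socket client added in mpi_session.py above.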

tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+import os
+from typing import Literal
+
+import click
+
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
+from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
+from tensorrt_llm.llmapi.utils import print_colored
+
+
+def run_task(task_type: Literal["submit", "submit_sync"]):
+    tasks = range(10)
+    assert os.environ[
+        LlmLauncherEnvs.
+        TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
+    client = RemoteMpiCommSessionClient(
+        os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR])
+
+    for task in tasks:
+        if task_type == "submit":
+            client.submit(print_colored, f"{task}\n", "green")
+        elif task_type == "submit_sync":
+            res = client.submit_sync(print_colored, f"{task}\n", "green")
+            print(res)
+
+
+def run_multi_tasks(task_type: Literal["submit", "submit_sync"]):
+    for i in range(3):
+        print_colored(f"Running MPI comm task {i}\n", "green")
+        run_task(task_type)
+        print_colored(f"MPI comm task {i} completed\n", "green")
+
+
+@click.command()
+@click.option("--task_type",
+              type=click.Choice(["submit", "submit_sync"]),
+              default="submit")
+def main(task_type: Literal["submit", "submit_sync"]):
+    run_multi_tasks(task_type)
+
+
+if __name__ == "__main__":
+    main()
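
This second helper exercises RemoteMpiCommSessionClient directly through the TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR address that trtllm-llmapi-launch is expected to export: the submit branch fires each task at the workers without collecting results, while the submit_sync branch prints the returned per-rank results (MpiCommSession.submit_sync above gathers future.result() for every worker, one entry per rank).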

tests/unittest/llmapi/test_mpi_session.py

Lines changed: 58 additions & 0 deletions
@@ -5,13 +5,20 @@
 from subprocess import PIPE, Popen
 from typing import Literal
 
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
 import pytest
 
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.llmapi.mpi_session import (MPINodeState, MpiPoolSession,
                                              RemoteMpiCommSessionClient,
                                              split_mpi_env)
 
+# isort: off
+sys.path.append(os.path.join(cur_dir, '..'))
+from utils.util import skip_single_gpu
+# isort: on
+
 
 def task0():
     if MPINodeState.state is None:
@@ -108,3 +115,54 @@ def task1():
 def test_split_mpi_env():
     session = MpiPoolSession(n_workers=4)
     session.submit_sync(task1)
+
+
+@skip_single_gpu
+@pytest.mark.parametrize(
+    "task_script", ["_run_mpi_comm_task.py", "_run_multi_mpi_comm_tasks.py"])
+def test_llmapi_launch_multiple_tasks(task_script: str):
+    """
+    Test that the trtllm-llmapi-launch can run multiple tasks.
+    """
+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    test_file = os.path.join(cur_dir, "_run_multi_llm_tasks.py")
+    assert os.path.exists(test_file), f"Test file {test_file} does not exist"
+    command = [
+        "mpirun", "-n", "2", "--allow-run-as-root", "trtllm-llmapi-launch",
+        "python3", test_file
+    ]
+    print(' '.join(command))
+
+    with Popen(command,
+               env=os.environ,
+               stdout=PIPE,
+               stderr=PIPE,
+               bufsize=1,
+               start_new_session=True,
+               universal_newlines=True,
+               cwd=os.path.dirname(os.path.abspath(__file__))) as process:
+        # Function to read from a stream and write to output
+        def read_stream(stream, output_stream):
+            for line in stream:
+                output_stream.write(line)
+                output_stream.flush()
+
+        # Create threads to read stdout and stderr concurrently
+        stdout_thread = threading.Thread(target=read_stream,
+                                         args=(process.stdout, sys.stdout))
+        stderr_thread = threading.Thread(target=read_stream,
+                                         args=(process.stderr, sys.stderr))
+
+        # Start both threads
+        stdout_thread.start()
+        stderr_thread.start()
+
+        # Wait for the process to complete
+        return_code = process.wait()
+
+        # Wait for both threads to finish reading
+        stdout_thread.join()
+        stderr_thread.join()
+
+        if return_code != 0:
+            raise subprocess.CalledProcessError(return_code, command)
