Skip to content

Commit 838df92

Browse files
authored
[https://nvbugs/5670793][fix] Solve trtllm-serve launch_disaggregated… (#9324)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
1 parent 2cde4e4 commit 838df92

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

tensorrt_llm/commands/serve.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
parse_disagg_config_file,
2626
parse_metadata_server_config_file)
2727
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
28-
from tensorrt_llm.llmapi.mpi_session import find_free_port
28+
from tensorrt_llm.llmapi.mpi_session import find_free_ipc_addr
2929
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
3030
from tensorrt_llm.logger import logger, severity_map
3131
from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer
@@ -641,10 +641,10 @@ def _launch_disaggregated_leader(sub_comm, instance_idx: int, config_file: str,
641641

642642
# This mimics the behavior of trtllm-llmapi-launch
643643
# TODO: Make the port allocation atomic
644-
free_port = find_free_port()
644+
free_ipc_addr = find_free_ipc_addr()
645645
os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS] = "1"
646-
os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR.
647-
value] = f"tcp://127.0.0.1:{free_port}"
646+
os.environ[
647+
LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR.value] = free_ipc_addr
648648
os.environ[DisaggLauncherEnvs.TLLM_DISAGG_RUN_REMOTE_MPI_SESSION_CLIENT.
649649
value] = "1"
650650
os.environ[DisaggLauncherEnvs.TLLM_DISAGG_INSTANCE_IDX] = str(instance_idx)

tensorrt_llm/llmapi/mpi_session.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -544,6 +544,13 @@ def find_free_port() -> int:
544544
return s.getsockname()[1]
545545

546546

547+
def find_free_ipc_addr() -> str:
548+
import os
549+
import tempfile
550+
import uuid
551+
return f'ipc://{os.path.join(tempfile.gettempdir(), "rpc_" + str(uuid.uuid4()))}'
552+
553+
547554
def get_mpi_world_size() -> int:
548555
# avoid cyclic import
549556
from ..executor.utils import get_spawn_proxy_process_env

0 commit comments

Comments
 (0)