Skip to content

Commit 1009039

Browse files
authored
[feature] use config_server to init nccl (#906)
1 parent d9281ee commit 1009039

File tree

7 files changed

+221
-10
lines changed

7 files changed

+221
-10
lines changed

lightllm/server/api_cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
150150
parser.add_argument(
151151
"--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch"
152152
)
153+
parser.add_argument(
154+
"--use_config_server_to_init_nccl",
155+
action="store_true",
156+
help="""use tcp store server started by config_server to init nccl, default is False, when set to True,
157+
the --nccl_host must equal to the config_server_host, and the --nccl_port must be unique for a config_server,
158+
dont use same nccl_port for different inference node, it will be critical error""",
159+
)
160+
153161
parser.add_argument(
154162
"--mode",
155163
type=str,

lightllm/server/api_start.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import subprocess
66
import signal
77
from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
8-
from lightllm.utils.start_utils import process_manager
8+
from lightllm.utils.start_utils import process_manager, kill_recursive
99
from .metrics.manager import start_metric_manager
1010
from .embed_cache.manager import start_cache_manager
1111
from .visualserver.manager import start_visual_process
@@ -25,8 +25,8 @@ def setup_signal_handlers(http_server_process, process_manager):
2525
def signal_handler(sig, frame):
2626
if sig == signal.SIGINT:
2727
logger.info("Received SIGINT (Ctrl+C), forcing immediate exit...")
28-
if http_server_process and http_server_process.poll() is None:
29-
http_server_process.kill()
28+
if http_server_process:
29+
kill_recursive(http_server_process)
3030

3131
process_manager.terminate_all_processes()
3232
logger.info("All processes have been forcefully terminated.")
@@ -47,7 +47,7 @@ def signal_handler(sig, frame):
4747
logger.info("HTTP server has exited gracefully")
4848
else:
4949
logger.warning("HTTP server did not exit in time, killing it...")
50-
http_server_process.kill()
50+
kill_recursive(http_server_process)
5151

5252
process_manager.terminate_all_processes()
5353
logger.info("All processes have been terminated gracefully.")
@@ -82,6 +82,10 @@ def normal_or_p_d_start(args):
8282

8383
logger.info(f"use tgi api: {args.use_tgi_api}")
8484

85+
# 当使用config_server来初始化nccl时,nccl_host和config_server_host必须一致
86+
if args.use_config_server_to_init_nccl:
87+
assert args.config_server_host == args.nccl_host
88+
8589
assert (
8690
args.mem_fraction > 0 and args.mem_fraction < 1
8791
), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1."

lightllm/server/config_server/api_http.py

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import time
2+
import asyncio
3+
import base64
4+
import pickle
5+
import multiprocessing as mp
16
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
27
from threading import Lock
3-
from typing import Dict
8+
from typing import Dict, List
49
from fastapi.responses import JSONResponse
510
from lightllm.utils.log_utils import init_logger
611
from ..pd_io_struct import PD_Master_Obj
7-
import base64
8-
import pickle
9-
import os
10-
import requests
12+
from .nccl_tcp_store import start_tcp_store_server
13+
from lightllm.utils.envs_utils import get_env_start_args
14+
from lightllm.utils.process_check import start_parent_check_thread
15+
1116

1217
logger = init_logger(__name__)
1318
app = FastAPI()
@@ -112,3 +117,90 @@ async def allocate_global_unique_multimodal_id_range():
112117
end_id = global_multimodal_embedding_id
113118

114119
return {"start_id": start_id, "end_id": end_id}
120+
121+
122+
global_store_port_to_process: Dict[int, mp.Process] = {}
123+
global_store_port_to_client_states: Dict[int, List[bool]] = {}
124+
global_store_port_lock = asyncio.Lock()
125+
126+
127+
@app.get("/start_tcp_store_server")
128+
async def http_start_tcp_store_server(
129+
tcp_store_port: int = Query(...), rank_id: int = Query(...), world_size: int = Query(...)
130+
):
131+
"""
132+
Start a TCP store server for NCCL communication.
133+
134+
Args:
135+
tcp_store_port (int): The port number for the TCP store server.
136+
rank_id (int): The rank ID of inference process.
137+
world_size (int): The world size of nccl group.
138+
139+
Returns:
140+
dict: A dictionary containing the status of the server.
141+
"""
142+
global global_store_port_to_process
143+
global global_store_port_to_client_states
144+
global global_store_port_lock
145+
146+
args = get_env_start_args()
147+
148+
if rank_id == 0:
149+
async with global_store_port_lock:
150+
if tcp_store_port in global_store_port_to_client_states:
151+
logger.error(f"tcp store server {tcp_store_port} already started, rank_id 0 find client state exists")
152+
assert False, f"tcp store server {tcp_store_port} already started, rank_id 0 find client state exists"
153+
154+
if tcp_store_port in global_store_port_to_process:
155+
logger.warning(f"tcp store server {tcp_store_port} already started, kill and restart it")
156+
process = global_store_port_to_process[tcp_store_port]
157+
process.kill()
158+
process.join()
159+
160+
global_store_port_to_process[tcp_store_port] = start_tcp_store_server(
161+
args.config_server_host, tcp_store_port
162+
)
163+
164+
world_size_state = [True for _ in range(world_size)]
165+
global_store_port_to_client_states[tcp_store_port] = world_size_state
166+
167+
world_size_state[rank_id] = False
168+
169+
start_time = time.time()
170+
while any(world_size_state):
171+
await asyncio.sleep(1)
172+
if time.time() - start_time > 60 * 3:
173+
logger.error(
174+
f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} wait all quit timeout"
175+
)
176+
async with global_store_port_lock:
177+
global_store_port_to_client_states.pop(tcp_store_port, None)
178+
raise Exception(
179+
f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} wait timeout"
180+
)
181+
182+
async with global_store_port_lock:
183+
global_store_port_to_client_states.pop(tcp_store_port, None)
184+
185+
return {"status": "ok"}
186+
else:
187+
start_time = time.time()
188+
while tcp_store_port not in global_store_port_to_client_states:
189+
await asyncio.sleep(1)
190+
if time.time() - start_time > 60 * 3:
191+
logger.error(f"tcp store port {tcp_store_port} rank_id {rank_id} world_size {world_size} state timeout")
192+
raise Exception(
193+
f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} state timeout"
194+
)
195+
196+
world_size_state = global_store_port_to_client_states[tcp_store_port]
197+
198+
assert (
199+
world_size_state[rank_id] is True
200+
), f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} world_size_state error"
201+
world_size_state[rank_id] = False
202+
return {"status": "ok"}
203+
204+
205+
logger.info("config server start_parent_check_thread...")
206+
start_parent_check_thread()
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import psutil
2+
import time
3+
import torch.distributed as dist
4+
import torch.multiprocessing as mp
5+
from lightllm.utils.log_utils import init_logger
6+
from lightllm.utils.process_check import start_parent_check_thread
7+
8+
logger = init_logger(__name__)
9+
10+
11+
def start_tcp_store_server(nccl_store_host, nccl_store_port):
12+
"""
13+
start a process to run a TCPStore server.
14+
"""
15+
process = mp.Process(
16+
target=_start_tcp_store_server,
17+
args=(nccl_store_host, nccl_store_port),
18+
daemon=True,
19+
)
20+
process.start()
21+
return process
22+
23+
24+
def _start_tcp_store_server(nccl_store_host, nccl_store_port):
25+
"""
26+
start a TCPStore server.
27+
"""
28+
start_parent_check_thread()
29+
30+
try:
31+
from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
32+
33+
default_pg_nccl_timeout = _DEFAULT_PG_NCCL_TIMEOUT
34+
except ImportError:
35+
# if C++ NCCL support is not compiled, we don't have access to the default nccl value.
36+
# if anyone is actually trying to use nccl in this state, it should error.
37+
default_pg_nccl_timeout = None
38+
39+
logger.info(f"default_pg_nccl_timeout: {default_pg_nccl_timeout}")
40+
logger.info(f"[Server] TCPStore start: {nccl_store_host}:{nccl_store_port}")
41+
try:
42+
store = dist.TCPStore(
43+
host_name=nccl_store_host,
44+
port=nccl_store_port,
45+
world_size=None,
46+
is_master=True,
47+
wait_for_workers=False,
48+
timeout=default_pg_nccl_timeout,
49+
multi_tenant=True,
50+
use_libuv=True,
51+
)
52+
53+
while True:
54+
keys_num = store.num_keys()
55+
logger.info(f"[Server] TCPStore start: {nccl_store_host}:{nccl_store_port} keys num: {keys_num}")
56+
time.sleep(20)
57+
58+
except Exception as e:
59+
logger.warning(str(e))
60+
logger.info(f"TCPStore server {nccl_store_host}:{nccl_store_port} start failed, retrying ...")

lightllm/server/core/objs/start_args_type.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Optional, Union
2+
from typing import List, Optional, Tuple
33

44
# 只是为了更好的编程提示
55

@@ -31,7 +31,9 @@ class StartArgs:
3131
tp: int = field(default=1)
3232
dp: int = field(default=1)
3333
max_req_total_len: int = field(default=2048 + 1024)
34+
nccl_host: str = field(default="127.0.0.1")
3435
nccl_port: int = field(default=28765)
36+
use_config_server_to_init_nccl: bool = field(default=False)
3537
mode: List[str] = field(default_factory=list)
3638
trust_remote_code: bool = field(default=False)
3739
disable_log_stats: bool = field(default=False)

lightllm/utils/dist_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch.distributed as dist
22
import os
33
import torch
4+
import requests
45

56
# 规范 rank 的含义,在 llm 推理的相关代码中下述的 rank 的含义如下:
67
# global_rank 全局 rank 序列id, 如两节点 8卡,会存在 0 - 15 16个global_rank
@@ -93,6 +94,7 @@ def init_distributed_env(kvargs):
9394
dp_size_in_node = max(1, get_dp_size() // nnodes)
9495
set_dp_rank_in_node(get_global_dp_rank() % dp_size_in_node)
9596

97+
_init_nccl_env()
9698
device_id = kvargs["rank_id"] % get_node_world_size()
9799
set_current_device_id(device_id)
98100
torch.cuda.set_device(device_id)
@@ -199,3 +201,33 @@ def create_new_group_for_current_dp(backend):
199201
if get_global_dp_rank() == iter_dp_rank:
200202
ans_group = device_group
201203
return ans_group
204+
205+
206+
def _init_nccl_env():
207+
from lightllm.utils.envs_utils import get_env_start_args
208+
209+
args = get_env_start_args()
210+
211+
# 配置使用外部的 tcp store server 来创建 nccl 连接
212+
if args.use_config_server_to_init_nccl:
213+
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "True"
214+
rank_id = get_global_rank()
215+
world_size = get_global_world_size()
216+
ip_port = f"{args.config_server_host}:{args.config_server_port}"
217+
params = f"tcp_store_port={args.nccl_port}&&rank_id={rank_id}&&world_size={world_size}"
218+
219+
if rank_id == 0:
220+
# 当使用外部config server 启动的tcpStore来初始化nccl时,需要保证配置了config_server_host.
221+
# 同时也需要保证config_server_host和nccl_host是同一个ip, 这个时候 rank 0 推理进程会先调用
222+
# config server的http接口来启动tcp store server, 然后再调用nccl init方法来初始化nccl.
223+
assert args.config_server_host == args.nccl_host
224+
url = f"http://{ip_port}/start_tcp_store_server?{params}"
225+
response = requests.get(url, timeout=60 * 3)
226+
assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"
227+
else:
228+
assert args.config_server_host == args.nccl_host
229+
url = f"http://{ip_port}/start_tcp_store_server?{params}"
230+
response = requests.get(url, timeout=60 * 3)
231+
assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"
232+
233+
return

lightllm/utils/start_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,17 @@ def start_submodule_processes(start_funcs=[], start_args=[]):
9898
return
9999

100100

101+
def kill_recursive(proc):
102+
try:
103+
parent = psutil.Process(proc.pid)
104+
children = parent.children(recursive=True)
105+
for child in children:
106+
logger.info(f"Killing child process {child.pid}")
107+
child.kill()
108+
logger.info(f"Killing parent process {proc.pid}")
109+
parent.kill()
110+
except psutil.NoSuchProcess:
111+
logger.warning(f"Process {proc.pid} does not exist.")
112+
113+
101114
process_manager = SubmoduleManager()

0 commit comments

Comments
 (0)