Skip to content

Commit b03b60e

Browse files
authored
rpyc and zmq use unix socket. (#653)
1 parent 697cb46 commit b03b60e

File tree

14 files changed

+107
-25
lines changed

14 files changed

+107
-25
lines changed

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
1313
)
1414
parser.add_argument("--host", type=str, default="127.0.0.1")
1515
parser.add_argument("--port", type=int, default=8000)
16+
parser.add_argument(
17+
"--zmq_mode",
18+
type=str,
19+
default="ipc:///tmp/",
20+
help="use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']",
21+
)
1622

1723
parser.add_argument(
1824
"--pd_master_ip",

lightllm/server/api_start.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from lightllm.server import TokenLoad
55
from .api_lightllm import lightllm_generate, lightllm_generate_stream
66
from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl
7-
from lightllm.utils.net_utils import alloc_can_use_network_port
7+
from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
88
from lightllm.utils.start_utils import start_submodule_processes
99
from .metrics.manager import start_metric_manager
1010
from .embed_cache.manager import start_cache_manager
@@ -27,6 +27,15 @@ def normal_or_p_d_start(g_objs):
2727
if args.run_mode not in ["normal", "prefill", "decode"]:
2828
return
2929

30+
assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
31+
32+
# Make sure multiple instances on a single machine do not conflict with each other
33+
if args.zmq_mode == "ipc:///tmp/":
34+
zmq_mode = f"{args.zmq_mode}_{str(args.nccl_port)}_"
35+
args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功
36+
args.zmq_mode = zmq_mode
37+
logger.info(f"zmq mode head: {args.zmq_mode}")
38+
3039
if args.use_tgi_api:
3140
g_objs.g_generate_func = tgi_generate_impl
3241
g_objs.g_generate_stream_func = tgi_generate_stream_impl
@@ -117,9 +126,18 @@ def normal_or_p_d_start(g_objs):
117126
assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
118127

119128
already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port]
129+
if args.run_mode == "decode":
130+
already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port]
131+
132+
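# Lock the ports up front: when several instances are started on one machine, a port
133+
# conflict would otherwise not be caught until the model itself starts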
134+
ports_locker = PortLocker(already_uesd_ports)
135+
ports_locker.lock_port()
136+
120137
can_use_ports = alloc_can_use_network_port(
121138
num=6 + args.tp + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
122139
)
140+
logger.info(f"alloced ports: {can_use_ports}")
123141
router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
124142
model_rpc_ports = can_use_ports[6 : 6 + args.tp]
125143
can_use_ports = can_use_ports[6 + args.tp :]
@@ -144,6 +162,8 @@ def normal_or_p_d_start(g_objs):
144162

145163
logger.info(f"all start args:{args}")
146164

165+
ports_locker.release_port()
166+
147167
if args.enable_multimodal:
148168
start_submodule_processes(
149169
start_funcs=[

lightllm/server/detokenization/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ def __init__(
3333
self.args = args
3434
context = zmq.asyncio.Context(2)
3535
self.recv_from_router = context.socket(zmq.PULL)
36-
self.recv_from_router.bind(f"tcp://127.0.0.1:{detokenization_port}")
36+
self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")
3737

3838
self.send_to_httpserver = context.socket(zmq.PUSH)
39-
self.send_to_httpserver.connect(f"tcp://127.0.0.1:{httpserver_port}")
39+
self.send_to_httpserver.connect(f"{args.zmq_mode}127.0.0.1:{httpserver_port}")
4040

4141
self.tokenizer = get_tokenizer(model_weightdir, tokenizor_mode, trust_remote_code=trust_remote_code)
4242
self.all_special_ids = set(self.tokenizer.all_special_ids)

lightllm/server/httpserver/manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ def __init__(
4545
self.args = args
4646
context = zmq.asyncio.Context(2)
4747
self.send_to_router = context.socket(zmq.PUSH)
48-
self.send_to_router.connect(f"tcp://127.0.0.1:{router_port}")
48+
self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}")
4949

5050
self.enable_multimodal = enable_multimodal
5151
if self.enable_multimodal:
5252
self.cache_client = rpyc.connect("localhost", cache_port)
5353
self.send_to_visual = context.socket(zmq.PUSH)
54-
self.send_to_visual.connect(f"tcp://127.0.0.1:{visual_port}")
54+
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
5555

5656
self.recv_from_detokenization = context.socket(zmq.PULL)
57-
self.recv_from_detokenization.bind(f"tcp://127.0.0.1:{httpserver_port}")
57+
self.recv_from_detokenization.bind(f"{args.zmq_mode}127.0.0.1:{httpserver_port}")
5858

5959
self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
6060

lightllm/server/router/manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
6666

6767
context = zmq.asyncio.Context(2)
6868
self.recv_from_httpserver = context.socket(zmq.PULL)
69-
self.recv_from_httpserver.bind(f"tcp://127.0.0.1:{router_port}")
69+
self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")
7070

7171
self.send_to_detokenization = context.socket(zmq.PUSH)
72-
self.send_to_detokenization.connect(f"tcp://127.0.0.1:{detokenization_port}")
72+
self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")
7373
self.model_rpc_ports = model_rpc_ports
7474

7575
self.is_splitfuse_mode = args.splitfuse_mode
@@ -283,14 +283,15 @@ async def _step(self):
283283
self.running_batch = new_batch
284284
await self._prefill_batch(self.running_batch)
285285
self._filter_runing_batch()
286-
self.has_wait_tokens = 0
286+
self.has_wait_tokens = self.max_wait_tokens
287287
return
288288

289289
# 有运行请求,但是已经到了可以调度新的请求合并推理的时机
290290
if self.has_wait_tokens >= self.max_wait_tokens:
291291
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
292292
self.has_wait_tokens = 0
293293
if new_mini_batch is not None:
294+
self.has_wait_tokens = self.max_wait_tokens
294295
self.stats_tool.count_prompt_tokens(new_mini_batch)
295296
await self._prefill_batch(new_mini_batch)
296297
if not new_mini_batch.is_clear():

lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import torch
23
import torch.multiprocessing as mp
34
import torch.distributed as dist
@@ -30,7 +31,13 @@ def init_custom(self):
3031
self.lock_nccl_group = dist.new_group(backend="gloo")
3132
from .decode_infer_rpyc import PDDecodeInferRpcServer
3233

33-
t = ThreadedServer(PDDecodeInferRpcServer(self), port=self.pd_rpyc_port, protocol_config={"allow_pickle": True})
34+
socket_path = f"/tmp/decode_node_infer_rpyc_{self.pd_rpyc_port}"
35+
if os.path.exists(socket_path):
36+
os.remove(socket_path)
37+
38+
t = ThreadedServer(
39+
PDDecodeInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True}
40+
)
3441
threading.Thread(target=lambda: t.start(), daemon=True).start()
3542
return
3643

lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_kv_move_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
8787
self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = []
8888
self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {}
8989
for port in self.args.pd_tp_infer_rpyc_ports:
90-
con = retry(max_attempts=20, wait_time=2)(rpyc.connect)("localhost", port, config={"allow_pickle": True})
90+
socket_path = f"/tmp/decode_node_infer_rpyc_{port}"
91+
from rpyc.utils.factory import unix_connect
92+
93+
con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True})
9194
self.infer_rpyc_objs.append(con.root)
9295
logger.info(f"rpyc connect to port: {port} ok")
9396

lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_impl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import threading
23
import torch
34
import torch.multiprocessing as mp
@@ -29,8 +30,12 @@ def init_custom(self):
2930
self.lock_nccl_group = dist.new_group(backend="gloo")
3031
from .prefill_infer_rpyc import PDPrefillInferRpcServer
3132

33+
socket_path = f"/tmp/prefill_node_infer_rpyc_{self.pd_rpyc_port}"
34+
if os.path.exists(socket_path):
35+
os.remove(socket_path)
36+
3237
t = ThreadedServer(
33-
PDPrefillInferRpcServer(self), port=self.pd_rpyc_port, protocol_config={"allow_pickle": True}
38+
PDPrefillInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True}
3439
)
3540
threading.Thread(target=lambda: t.start(), daemon=True).start()
3641
return

lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_kv_move_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
9999
self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = []
100100
self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {}
101101
for port in self.args.pd_tp_infer_rpyc_ports:
102-
con = retry(max_attempts=20, wait_time=2)(rpyc.connect)("localhost", port, config={"allow_pickle": True})
102+
socket_path = f"/tmp/prefill_node_infer_rpyc_{port}"
103+
from rpyc.utils.factory import unix_connect
104+
105+
con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True})
103106
self.infer_rpyc_objs.append(con.root)
104107
logger.info(f"rpyc connect to infer rpyc port: {port} ok")
105108
self.host_ip = get_hostname_ip()

lightllm/server/router/model_infer/model_rpc.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import rpyc
3-
import torch
3+
import tempfile
44
import torch.multiprocessing as mp
55
from datetime import timedelta
66
from typing import Dict, List, Tuple
@@ -243,7 +243,7 @@ async def get_max_total_token_num(self):
243243
return ans
244244

245245

246-
def _init_env(args, port, info_queue, mem_queue, router_lock):
246+
def _init_env(args, socket_path, info_queue, mem_queue, router_lock, success_event: mp.Event):
247247
import lightllm.utils.rpyc_fix_utils as _
248248

249249
# 注册graceful 退出的处理
@@ -259,7 +259,10 @@ def _init_env(args, port, info_queue, mem_queue, router_lock):
259259

260260
from rpyc.utils.server import ThreadedServer
261261

262-
t = ThreadedServer(ModelRpcServer(args, info_queue, mem_queue), port=port, protocol_config={"allow_pickle": True})
262+
t = ThreadedServer(
263+
ModelRpcServer(args, info_queue, mem_queue), socket_path=socket_path, protocol_config={"allow_pickle": True}
264+
)
265+
success_event.set()
263266
t.start()
264267
return
265268

@@ -271,13 +274,18 @@ async def start_model_process(args, port, world_size, info_queue: mp.Queue, mem_
271274
if world_size == 1:
272275
return ModelRpcClient(ModelRpcServer(args, info_queue, mem_queue), world_size)
273276

274-
proc = mp.Process(target=_init_env, args=(args, port, info_queue, mem_queue, router_lock))
277+
socket_path = tempfile.mktemp()
278+
success_event = mp.Event()
279+
proc = mp.Process(target=_init_env, args=(args, socket_path, info_queue, mem_queue, router_lock, success_event))
275280
proc.start()
276-
await asyncio.sleep(2)
281+
success_event.wait(timeout=40)
282+
277283
repeat_count = 0
278284
while repeat_count < 20:
279285
try:
280-
con = rpyc.connect("localhost", port, config={"allow_pickle": True})
286+
from rpyc.utils.factory import unix_connect
287+
288+
con = unix_connect(socket_path, config={"allow_pickle": True})
281289
break
282290
except BaseException:
283291
await asyncio.sleep(1)

0 commit comments

Comments
 (0)