Skip to content

Commit 4cc6d08

Browse files
Author liujiacheng committed "fix" (1 parent 5437ee7, commit 4cc6d08)

File tree

6 files changed: 53 additions & 44 deletions (+53 / -44 lines changed)

6 files changed: 53 additions & 44 deletions (+53 / -44 lines changed)

lightllm/server/api_start.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,7 @@ def normal_or_p_d_start(args):
217217

218218
node_world_size = args.tp // args.nnodes
219219
can_use_ports = alloc_can_use_network_port(
220-
num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
220+
num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
221221
)
222222
logger.info(f"alloced ports: {can_use_ports}")
223223
(
@@ -228,8 +228,9 @@ def normal_or_p_d_start(args):
228228
audio_port,
229229
cache_port,
230230
metric_port,
231-
) = can_use_ports[0:7]
232-
can_use_ports = can_use_ports[7:]
231+
multi_level_kv_cache_port,
232+
) = can_use_ports[0:8]
233+
can_use_ports = can_use_ports[8:]
233234

234235
visual_model_tp_ports = []
235236
for _ in range(args.visual_dp):
@@ -245,6 +246,7 @@ def normal_or_p_d_start(args):
245246
args.audio_port = audio_port
246247
args.cache_port = cache_port
247248
args.metric_port = metric_port
249+
args.multi_level_kv_cache_port = multi_level_kv_cache_port
248250

249251
# 申请在 p d 分离模式下,会用的端口
250252
args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]

lightllm/server/audioserver/manager.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,13 @@ def __init__(
2626
infer_batch_size=4,
2727
):
2828
context = zmq.asyncio.Context(2)
29-
self.send_to_router = context.socket(zmq.PUSH)
30-
self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}")
29+
30+
if args.enable_cpu_cache:
31+
self.send_to_next_module = context.socket(zmq.PUSH)
32+
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}")
33+
else:
34+
self.send_to_next_module = context.socket(zmq.PUSH)
35+
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}")
3136

3237
self.zmq_recv_socket = context.socket(zmq.PULL)
3338
self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.audio_port}")
@@ -87,7 +92,7 @@ async def loop_for_fwd(self):
8792
# 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理
8893
# 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了
8994
# 需要一些一致的流程来保证不出现异步问题。
90-
self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
95+
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
9196
continue
9297

9398
multimodal_params = group_req_indexes.multimodal_params
@@ -103,18 +108,20 @@ async def loop_for_fwd(self):
103108
await self.infer_audios(audios_need_infer)
104109
audios_need_infer = []
105110
for _group_req_indexes in processing_group_reqs:
106-
self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
111+
self.send_to_next_module.send_pyobj(
112+
_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL
113+
)
107114
processing_group_reqs = []
108115

109116
if len(audios_need_infer) == 0:
110-
self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
117+
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
111118
else:
112119
processing_group_reqs.append(group_req_indexes)
113120

114121
if len(audios_need_infer) > 0:
115122
await self.infer_audios(audios_need_infer)
116123
for _group_req_indexes in processing_group_reqs:
117-
self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
124+
self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
118125
processing_group_reqs = []
119126
audios_need_infer = []
120127

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,7 @@ class StartArgs:
110110
cache_port: int = field(default=None)
111111
metric_port: int = field(default=None)
112112
multinode_httpmanager_port: int = field(default=12345)
113+
multi_level_kv_cache_port: int = field(default=None)
113114
# multi_modal
114115
enable_multimodal: bool = field(default=False)
115116
enable_multimodal_audio: bool = field(default=False)

lightllm/server/httpserver/manager.py

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,9 @@ def __init__(
7979
self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
8080
self.send_to_visual = context.socket(zmq.PUSH)
8181
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}")
82+
if args.enable_cpu_cache and not self.args.enable_multimodal:
83+
self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH)
84+
self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}")
8285

8386
self.shm_req_manager = ShmReqManager()
8487

@@ -432,38 +435,33 @@ async def transfer_to_next_module(
432435
group_req_objs: Optional[GroupReqObjs] = None,
433436
):
434437

435-
if self.pd_mode == NodeRole.P:
438+
if self.pd_mode.is_P_or_NORMAL():
436439
if self.enable_multimodal:
437440
self.send_to_visual.send_pyobj(
438441
group_req_objs.to_group_req_index(),
439442
protocol=pickle.HIGHEST_PROTOCOL,
440443
)
441-
else:
442-
self.send_to_router.send_pyobj(
444+
return
445+
446+
if self.args.enable_cpu_cache:
447+
self.send_to_multi_level_kv_cache.send_pyobj(
443448
group_req_objs.to_group_req_index(),
444449
protocol=pickle.HIGHEST_PROTOCOL,
445450
)
446-
return
451+
return
447452

448-
if self.pd_mode == NodeRole.D:
449-
# 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可
450453
self.send_to_router.send_pyobj(
451454
group_req_objs.to_group_req_index(),
452455
protocol=pickle.HIGHEST_PROTOCOL,
453456
)
454457
return
455458

456-
if self.pd_mode == NodeRole.NORMAL:
457-
if self.enable_multimodal:
458-
self.send_to_visual.send_pyobj(
459-
group_req_objs.to_group_req_index(),
460-
protocol=pickle.HIGHEST_PROTOCOL,
461-
)
462-
else:
463-
self.send_to_router.send_pyobj(
464-
group_req_objs.to_group_req_index(),
465-
protocol=pickle.HIGHEST_PROTOCOL,
466-
)
459+
if self.pd_mode == NodeRole.D:
460+
# 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了
461+
self.send_to_router.send_pyobj(
462+
group_req_objs.to_group_req_index(),
463+
protocol=pickle.HIGHEST_PROTOCOL,
464+
)
467465
return
468466

469467
assert False, "dead code path"

lightllm/server/multi_level_kv_cache/manager.py

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import uvloop
22
import asyncio
3+
import collections
34

45
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
56
import zmq
@@ -8,7 +9,7 @@
89
import time
910
import threading
1011
import concurrent.futures
11-
from typing import List
12+
from typing import List, Deque
1213
from lightllm.server.core.objs import ShmReqManager, Req, StartArgs
1314
from lightllm.server.core.objs.io_objs import GroupReqIndexes
1415
from lightllm.utils.graceful_utils import graceful_registry
@@ -21,18 +22,16 @@
2122
class MultiLevelKVCacheManager:
2223
def __init__(
2324
self,
24-
args,
25-
detokenization_port,
26-
router_port,
25+
args: StartArgs,
2726
):
2827
self.args: StartArgs = args
2928
context = zmq.Context(2)
30-
self.recv_from_pre_module = context.socket(zmq.PULL)
31-
self.recv_from_pre_module.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")
29+
self.zmq_recv_socket = context.socket(zmq.PULL)
30+
self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}")
3231

3332
self.send_to_router = context.socket(zmq.PUSH)
34-
self.send_to_router.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")
35-
logger.info(f"pub_to_httpserver sendhwm {self.send_to_router.getsockopt(zmq.SNDHWM)}")
33+
self.send_to_router.bind(f"{args.zmq_mode}127.0.0.1:{args.router_port}")
34+
logger.info(f"send_to_router sendhwm {self.send_to_router.getsockopt(zmq.SNDHWM)}")
3635
self.cpu_cache_client = CpuKvCacheClient(init_shm_data=True)
3736
self.shm_req_manager = ShmReqManager()
3837
# 控制同时进行cpu cache 匹配操作的数量。
@@ -42,7 +41,7 @@ def __init__(
4241
self.cpu_cache_time_out = 0.3
4342
# lock 用于控制对 recv_queue 和 transfer_queue 的访问。
4443
self.queue_lock = threading.Lock()
45-
self.recv_queue: List[GroupReqIndexes] = []
44+
self.recv_queue: Deque[GroupReqIndexes] = collections.deque()
4645
self.transfer_queue: List[GroupReqIndexes] = []
4746
self.transfer_thread = threading.Thread(target=self.transfer_loop, daemon=True)
4847
self.transfer_thread.start()
@@ -58,8 +57,7 @@ def cpu_cache_hanle_loop(self):
5857
continue
5958

6059
with self.queue_lock:
61-
current_group_req = self.recv_queue[0]
62-
self.recv_queue = self.recv_queue[1:]
60+
current_group_req = self.recv_queue.popleft()
6361

6462
self.executor.submit(self._handle_group_req_cpu_cache_match, current_group_req, time.time())
6563
except BaseException as e:
@@ -146,7 +144,7 @@ def recv_loop(self):
146144
try:
147145
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
148146
for _ in range(recv_max_count):
149-
recv_obj: GroupReqIndexes = self.recv_from_pre_module.recv_pyobj(zmq.NOBLOCK)
147+
recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
150148
assert isinstance(recv_obj, GroupReqIndexes)
151149
recv_objs.append(recv_obj)
152150

@@ -166,15 +164,13 @@ def recv_loop(self):
166164
return
167165

168166

169-
def start_detokenization_process(args, detokenization_port, router_port, pipe_writer):
167+
def start_multi_level_kv_cache_manager(args, pipe_writer):
170168
# 注册graceful 退出的处理
171169
graceful_registry(inspect.currentframe().f_code.co_name)
172170

173171
try:
174172
manager = MultiLevelKVCacheManager(
175173
args=args,
176-
detokenization_port=detokenization_port,
177-
router_port=router_port,
178174
)
179175
except Exception as e:
180176
pipe_writer.send(str(e))

lightllm/server/visualserver/manager.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,17 @@ def __init__(
2828
visual_model_rpc_ports,
2929
):
3030
context = zmq.Context(2)
31+
3132
if args.enable_multimodal_audio:
32-
self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio)
33+
self.send_to_next_module = context.socket(zmq.PUSH)
3334
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}")
3435
else:
35-
self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio)
36-
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}")
36+
if args.enable_cpu_cache:
37+
self.send_to_next_module = context.socket(zmq.PUSH)
38+
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}")
39+
else:
40+
self.send_to_next_module = context.socket(zmq.PUSH)
41+
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}")
3742

3843
self.zmq_recv_socket = context.socket(zmq.PULL)
3944
self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}")

0 commit comments

Comments
 (0)