
Commit fd0511e

replace mem_queue with shm
1 parent 5e7f2d9 commit fd0511e

20 files changed: 89 additions & 93 deletions

lightllm/common/mem_manager.py

Lines changed: 20 additions & 0 deletions
@@ -15,8 +15,11 @@
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.config_utils import get_num_key_value_heads
 from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io
+from lightllm.utils.shm_utils import create_or_link_shm
+from multiprocessing.reduction import ForkingPickler
 
 logger = init_logger(__name__)
+LIGHTLLM_MEM_MANAGER_SHM_SIZE = int(os.getenv("LIGHTLLM_MEM_MANAGER_SHM_SIZE", 1024 * 1024))
 
 
 class MemoryManager:
@@ -431,6 +434,23 @@ def copy_kv_from_other_dp_ranks(
             rank_in_dp=rank_in_dp,
         )
 
+    def create_shm(self):
+        obj_bytes = ForkingPickler.dumps(self)
+        shm = create_or_link_shm(
+            f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}", LIGHTLLM_MEM_MANAGER_SHM_SIZE
+        )
+        logger.info(f"create shm {shm.name} size {shm.size} obj size {len(obj_bytes)}")
+        shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little")
+        shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes
+
+    @staticmethod
+    def from_shm(rank_in_node):
+        shm = create_or_link_shm(
+            f"{get_unique_server_name()}_mem_manager_{rank_in_node}", LIGHTLLM_MEM_MANAGER_SHM_SIZE
+        )
+        bytes_len = int.from_bytes(shm.buf[0:4], "little")
+        return ForkingPickler.loads(shm.buf[4 : 4 + bytes_len])
+
 
 class ReadOnlyStaticsMemoryManager:
     """

lightllm/server/router/manager.py

Lines changed: 4 additions & 9 deletions
@@ -116,9 +116,6 @@ async def wait_to_model_ready(self):
         self.model_rpc_servers = []
         # Used for exchanging task info between the kv move manager process and the inference processes.
         self.info_queue: mp.Queue = mp.Queue()
-        self.mem_queues: List[torch.multiprocessing.Queue] = [
-            torch.multiprocessing.Queue() for _ in range(self.node_world_size)
-        ]
         self.rpc_event = multiprocessing.Event()
         self.rpc_finished_event = multiprocessing.Event()
 
@@ -137,9 +134,7 @@ async def wait_to_model_ready(self):
                     rpc_event=self.rpc_event,
                     rpc_finished_event=self.rpc_finished_event,
                     info_queue=self.info_queue,
-                    mem_queue=self.mem_queues[(rank_id % node_world_size)],
                     router_lock=self.router_lock,
-                    mem_queues=self.mem_queues,
                 )
             )
             tasks.append(task)
@@ -206,29 +201,29 @@ async def wait_to_model_ready(self):
                 start_prefill_kv_move_manager_process,
             )
 
-            start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_prefill_kv_move_manager_process(self.args, self.info_queue)
 
         if self.args.run_mode == "nixl_prefill":
             from lightllm.server.router.model_infer.mode_backend.pd_nixl.prefill_node_impl import (
                 start_prefill_kv_move_manager_process,
             )
 
-            start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_prefill_kv_move_manager_process(self.args, self.info_queue)
 
         if self.args.run_mode == "decode":
             # Start the decode kv move manager process
             from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import (
                 start_decode_kv_move_manager_process,
             )
 
-            start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_decode_kv_move_manager_process(self.args, self.info_queue)
 
         if self.args.run_mode == "nixl_decode":
             from lightllm.server.router.model_infer.mode_backend.pd_nixl.decode_node_impl import (
                 start_decode_kv_move_manager_process,
             )
 
-            start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
+            start_decode_kv_move_manager_process(self.args, self.info_queue)
 
         return
 
lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py

Lines changed: 1 addition & 2 deletions
@@ -19,10 +19,9 @@
 
 
 class DecodeNode(ChunkedPrefillBackend):
-    def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
+    def __init__(self, info_queue: mp.Queue) -> None:
         super().__init__()
         self.info_queue: mp.Queue = info_queue
-        self.mem_queue: mp.Queue = mem_queue
         self.classed_req_strict_prefill = False
 
     def init_custom(self):

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 
 
 class DPForDecodeNode(DPChunkedPrefillBackend):
-    def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
-        super().__init__(mem_queue=mem_queue)
+    def __init__(self, info_queue: mp.Queue) -> None:
+        super().__init__()
         self.info_queue: mp.Queue = info_queue
         self.classed_req_strict_prefill = False
         return

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py

Lines changed: 3 additions & 3 deletions
@@ -166,9 +166,9 @@ def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]):
         release_acquired_lock()
         return
 
-    def exposed_put_mem_manager_to_mem_queue(self):
-        self.backend.mem_queue.put(self.backend.model.mem_manager)
-        logger.info("put mem manager to info_queues ok")
+    def exposed_put_mem_manager_to_shm(self):
+        self.backend.model.mem_manager.create_shm()
+        logger.info("put mem manager to shm ok")
         return
 
     def exposed_unfrozen_time_out_reqs_tokens(self):
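
The rename keeps rpyc's naming convention: a method prefixed with exposed_ on an rpyc.Service is callable by the remote side without the prefix, which is why the kv move manager (next file) invokes obj.put_mem_manager_to_shm() on its proxies. A toy sketch of that convention, with a hypothetical service and port rather than lightllm code:

# Toy rpyc service showing the exposed_ prefix convention assumed above.
import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class DemoService(rpyc.Service):  # hypothetical stand-in for PDDecodeInferRpcServer
    def exposed_put_mem_manager_to_shm(self):
        # the real method calls self.backend.model.mem_manager.create_shm()
        return "put mem manager to shm ok"


if __name__ == "__main__":
    server = ThreadedServer(DemoService, port=18999, protocol_config={"allow_pickle": True})
    threading.Thread(target=server.start, daemon=True).start()
    time.sleep(0.5)  # give the listener a moment to come up

    conn = rpyc.connect("localhost", 18999)
    print(conn.root.put_mem_manager_to_shm())  # called without the exposed_ prefix
    conn.close()
    server.close()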

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

Lines changed: 8 additions & 9 deletions
@@ -33,7 +33,7 @@
 
 
 class DecodeKVMoveManager(rpyc.Service):
-    def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
+    def __init__(self, args, info_queue: mp.Queue):
         super().__init__()
         self.args = args
         # args.dp // args.nnodes may be 0 in cross-node tp scenarios
@@ -44,7 +44,6 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
         assert self.dp_world_size <= self.node_world_size
 
         self.info_queue = info_queue
-        self.mem_queues = mem_queues
         self.infer_rpyc_lock = threading.Lock()
         self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = []
 
@@ -87,7 +86,7 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
         # _put_kv_received_to_radix_cache
         # _fail_to_realese_forzen_tokens
         # _unfrozen_time_out_reqs_tokens
-        # _put_mem_manager_to_mem_queue
+        # _put_mem_manager_to_shm
         # The interfaces above are how the kv move manager interacts with the inference processes,
         # mainly for requesting locks on kv resources or releasing them
         # ==================================================================================
@@ -155,10 +154,10 @@ def _unfrozen_time_out_reqs_tokens(self) -> None:
         asyncio.run(self.wait_all_future_finish(futures))
         return
 
-    def _put_mem_manager_to_mem_queue(self) -> None:
+    def _put_mem_manager_to_shm(self) -> None:
         with self.infer_rpyc_lock:
             for obj in self.infer_rpyc_objs:
-                obj.put_mem_manager_to_mem_queue()
+                obj.put_mem_manager_to_shm()
         return
 
     # ==================================================================================
@@ -362,14 +361,14 @@ def remove_trans_obj(self, connect_id):
         return
 
 
-def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event):
+def _init_env(args, info_queue: mp.Queue, event: mp.Event):
     import lightllm.utils.rpyc_fix_utils as _
 
     # Register graceful shutdown handling
    graceful_registry(inspect.currentframe().f_code.co_name)
     setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::decode_kv_move_manager")
 
-    manager = DecodeKVMoveManager(args, info_queue, mem_queues)
+    manager = DecodeKVMoveManager(args, info_queue)
     t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True})
     threading.Thread(target=lambda: t.start(), daemon=True).start()
 
@@ -381,9 +380,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.
     return
 
 
-def start_decode_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
+def start_decode_kv_move_manager_process(args, info_queue: mp.Queue):
     event = mp.Event()
-    proc = mp.Process(target=_init_env, args=(args, info_queue, mem_queues, event))
+    proc = mp.Process(target=_init_env, args=(args, info_queue, event))
     proc.start()
     event.wait()
     assert proc.is_alive()
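
start_decode_kv_move_manager_process keeps the same spawn-and-handshake shape, just without mem_queues: the parent creates an mp.Event, the child sets it once its setup (manager construction, rpyc server) has finished, and the parent blocks on event.wait() before asserting the child is alive. A generic sketch of that pattern with placeholder names, not lightllm's code:

# Generic spawn-and-handshake sketch (placeholder names).
import multiprocessing as mp


def _child_main(info_queue, event):
    # ... real code would build the manager and start its rpyc server here ...
    event.set()        # signal the parent that initialization succeeded
    info_queue.get()   # placeholder for the long-running work loop


def start_manager_process(info_queue):
    event = mp.Event()
    proc = mp.Process(target=_child_main, args=(info_queue, event))
    proc.start()
    event.wait()            # block until the child reports ready
    assert proc.is_alive()  # if the child died during setup, fail loudly
    return proc


if __name__ == "__main__":
    q = mp.Queue()
    p = start_manager_process(q)
    q.put("stop")  # let the placeholder loop exit
    p.join()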

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py

Lines changed: 1 addition & 2 deletions
@@ -279,10 +279,9 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"):
                 device_id,
                 self.task_in_queue,
                 self.task_out_queue,
-                manager.mem_queues,
             )
             assert self.task_out_queue.get(timeout=30) == "proc_start"
-            manager._put_mem_manager_to_mem_queue()
+            manager._put_mem_manager_to_shm()
             assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok"
 
             return True

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py

Lines changed: 5 additions & 4 deletions
@@ -91,7 +91,7 @@ def async_connect():
         logger.warning(f"error while connect to prefill node: {e}")
 
 
-def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]):
+def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue):
     import os
 
     # os.environ["NCCL_DEBUG"] = "INFO"
@@ -111,7 +111,9 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.
     graceful_registry(inspect.currentframe().f_code.co_name)
     task_out_queue.put("proc_start")
 
-    mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]
+    # Read every rank's mem_manager from shared memory
+    node_world_size = args.tp // args.nnodes
+    mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank) for rank in range(node_world_size)]
 
     task_out_queue.put("get_mem_managers_ok")
     connect_id_to_comm: Dict[str, PyNcclCommunicator] = {}
@@ -143,9 +145,8 @@ def start_decode_trans_process(
     device_id: int,
     task_in_queue: mp.Queue,
     task_out_queue: mp.Queue,
-    mem_queues: List[mp.Queue],
 ):
-    proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues))
+    proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue))
     proc.start()
     assert proc.is_alive()
     logger.info(f"decode trans kv process for device: {device_id} start!")

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py

Lines changed: 1 addition & 2 deletions
@@ -20,11 +20,10 @@
 
 
 class ChunckedPrefillForPrefillNode(ChunkedPrefillBackend):
-    def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
+    def __init__(self, info_queue: mp.Queue) -> None:
         super().__init__()
         self.support_overlap = False
         self.info_queue: mp.Queue = info_queue
-        self.mem_queue: mp.Queue = mem_queue
         self.classed_req_no_decode = True
 
     def init_custom(self):

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 
 
 class DPChunkedForPrefillNode(DPChunkedPrefillBackend):
-    def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue, mem_queues: List[mp.Queue]) -> None:
-        super().__init__(mem_queue=mem_queue, mem_queues=mem_queues)
+    def __init__(self, info_queue: mp.Queue) -> None:
+        super().__init__()
         self.support_overlap = False
         self.info_queue: mp.Queue = info_queue
         self.classed_req_no_decode = True
