
Commit 5e4073e

fix
1 parent 0795d72 commit 5e4073e

Showing 14 files changed with 20 additions and 58 deletions.


lightllm/common/mem_manager.py

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@
 import os
 import torch
 import torch.distributed as dist
+import torch.multiprocessing as mp
 from typing import List, Union
 from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
 from lightllm.server.pd_io_struct import KVMoveTask
@@ -435,6 +436,10 @@ def write_to_shm(self):
         """
         Write the mem manager into shm so that features such as PD disaggregation can read it directly, without relying on inter-process queues.
         """
+        from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor
+
+        mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__
+
         shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
         obj_bytes = ForkingPickler.dumps(self).tobytes()
         shm = create_or_link_shm(name=shm_name, expected_size=len(obj_bytes) + 4, force_mode="create")
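
The hunk above patches mp.reductions.reduce_tensor by overwriting its __code__ object right before ForkingPickler.dumps(self) serializes the manager, so any code that already holds a reference to the original function also gets the fixed reducer. A minimal sketch of that __code__-swapping technique, using hypothetical stand-in functions rather than the lightllm p2p_fix implementation:

# Sketch: patch a function in place by replacing its __code__ object.
# `original` and `patched` stand in for mp.reductions.reduce_tensor and
# p2p_fix.reduce_tensor; they are illustrative only.

def original(x):
    return x + 1

def patched(x):
    return x * 2

saved_ref = original  # simulates a module that imported the function earlier

# Mutating __code__ reroutes every existing reference, which is why the diff
# patches the function object itself instead of rebinding the module attribute.
original.__code__ = patched.__code__

assert original(3) == 6
assert saved_ref(3) == 6  # the old reference sees the patched behavior too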

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py

Lines changed: 0 additions & 5 deletions
@@ -166,11 +166,6 @@ def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]):
         release_acquired_lock()
         return

-    def exposed_put_mem_manager_to_shm(self):
-        self.backend.model.mem_manager.create_shm()
-        logger.info("put mem manager to shm ok")
-        return
-
     def exposed_unfrozen_time_out_reqs_tokens(self):
         acquire_lock_until_ready(self.backend.lock_nccl_group)
         if self.backend.dp_world_size == 1:

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

Lines changed: 0 additions & 7 deletions
@@ -86,7 +86,6 @@ def __init__(self, args, info_queue: mp.Queue):
         # _put_kv_received_to_radix_cache
         # _fail_to_realese_forzen_tokens
         # _unfrozen_time_out_reqs_tokens
-        # _put_mem_manager_to_shm
         # The interfaces above are how the kv move manager interacts with the inference processes,
         # mainly for requesting to lock kv resources or for releasing them
         # ==================================================================================
@@ -154,12 +153,6 @@ def _unfrozen_time_out_reqs_tokens(self) -> None:
         asyncio.run(self.wait_all_future_finish(futures))
         return

-    def _put_mem_manager_to_shm(self) -> None:
-        with self.infer_rpyc_lock:
-            for obj in self.infer_rpyc_objs:
-                obj.put_mem_manager_to_shm()
-        return
-
     # ==================================================================================
     # put_to_fail_release_task_queue puts requests that failed for some reason and need to release their
     # locked kv resources into the matching queue; handle_fail_release_task_loop is a looping thread that handles these failed requests

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py

Lines changed: 0 additions & 5 deletions
@@ -281,11 +281,6 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"):
             self.task_out_queue,
         )
         assert self.task_out_queue.get(timeout=30) == "proc_start"
-        # Make sure the main process has written mem_manager to shared memory before the child process reads it
-        if self.device_id == 0:
-            manager._put_mem_manager_to_shm()
-        # Notify the child process that it can read mem_manager from shared memory
-        self.task_in_queue.put("mem_managers_ready")
         assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok"

         return True

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py

Lines changed: 3 additions & 4 deletions
@@ -111,12 +111,11 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.
     graceful_registry(inspect.currentframe().f_code.co_name)
     task_out_queue.put("proc_start")

-    # Wait for the signal that the main process has written mem_manager into shared memory
-    assert task_in_queue.get(timeout=60) == "mem_managers_ready"
-
     # Read every rank's mem_manager from shared memory
     node_world_size = args.tp // args.nnodes
-    mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)]
+    mem_managers: List[MemoryManager] = [
+        MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size)
+    ]

     task_out_queue.put("get_mem_managers_ok")
     connect_id_to_comm: Dict[str, PyNcclCommunicator] = {}
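
With the rpyc-based _put_mem_manager_to_shm path removed, the transfer subprocess no longer waits for a "mem_managers_ready" message; the manager is already in shared memory by the time the process starts, so the handshake shrinks to "proc_start" followed by "get_mem_managers_ok". A rough sketch of that simplified queue protocol (the shared-memory load itself is stubbed out with a comment, and the function and variable names here are illustrative):

# Sketch of the two-message startup handshake after this change.
import multiprocessing as mp

def trans_child(task_out_queue: mp.Queue) -> None:
    task_out_queue.put("proc_start")
    # ... the real process builds its mem_managers from shared memory here ...
    task_out_queue.put("get_mem_managers_ok")

if __name__ == "__main__":
    out_q = mp.Queue()
    proc = mp.Process(target=trans_child, args=(out_q,))
    proc.start()
    assert out_q.get(timeout=30) == "proc_start"
    # No "mem_managers_ready" round trip from the parent anymore.
    assert out_q.get(timeout=60) == "get_mem_managers_ok"
    proc.join()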

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py

Lines changed: 0 additions & 5 deletions
@@ -45,8 +45,3 @@ def exposed_remove_req_refs_from_prompt_cache(self, group_req_ids: List[int]):
         )
         release_acquired_lock()
         return
-
-    def exposed_put_mem_manager_to_shm(self):
-        self.backend.model.mem_manager.create_shm()
-        logger.info("put mem manager to shm ok")
-        return

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py

Lines changed: 1 addition & 8 deletions
@@ -142,8 +142,7 @@ def check_trans_process_loop(self):
             raise e

     # ==================================================================================
-    # Interfaces to the inference processes: _remove_req_refs_from_prompt_cache and
-    # _put_mem_manager_to_shm both talk to the inference processes over rpyc
+    # Interfaces to the inference processes: _remove_req_refs_from_prompt_cache
     # ==================================================================================

     def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]):
@@ -163,12 +162,6 @@ def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]):
         asyncio.run(self.wait_all_future_finish(futures))
         return

-    def _put_mem_manager_to_shm(self):
-        with self.infer_rpyc_lock:
-            for obj in self.infer_rpyc_objs:
-                obj.put_mem_manager_to_shm()
-        return
-
     async def wait_all_future_finish(self, futures: List[AsyncResult]):
         await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures])
         return

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py

Lines changed: 0 additions & 3 deletions
@@ -355,9 +355,6 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"):
             self.task_out_queue,
         )
         assert self.task_out_queue.get(timeout=30) == "proc_start"
-        if self.device_id == 0:
-            manager._put_mem_manager_to_shm()
-        self.task_in_queue.put("mem_managers_ready")
         assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok"

         return True

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py

Lines changed: 3 additions & 4 deletions
@@ -116,12 +116,11 @@ def _init_env(
     )
     task_out_queue.put("proc_start")

-    # Wait for the signal that the main process has written mem_manager into shared memory
-    assert task_in_queue.get(timeout=60) == "mem_managers_ready"
-
     # Read every rank's mem_manager from shared memory
     node_world_size = args.tp // args.nnodes
-    mem_managers: List[MemoryManager] = [MemoryManager.from_shm(rank, device_id) for rank in range(node_world_size)]
+    mem_managers: List[MemoryManager] = [
+        MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size)
+    ]
     task_out_queue.put("get_mem_managers_ok")
     connect_id_to_comm: Dict[str, PyNcclCommunicator] = {}

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 1 addition & 10 deletions
@@ -27,7 +27,6 @@
 from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids
 from .control_state import DPControlState
 from lightllm.common.mem_manager import MemoryManager
-import torch.multiprocessing as mp

 min_trans_token_num = int(os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", "512"))
 dp_kv_transfer_req_num = int(os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", "16"))
@@ -74,15 +73,9 @@ def init_custom(self):
         if self.enable_dp_prompt_cache_fetch:
             torch.cuda.set_device(get_current_device_id())

-            from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor
             from lightllm.server.core.objs.shm_array import ShmArray
             from lightllm.utils.envs_utils import get_unique_server_name

-            mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__
-
-            # Create shared memory for mem_manager
-            self.model.mem_manager.create_shm(use_for_pd_trans=False)
-
             # Create shared ShmArray for kv_indexes transfer
             # Use a small buffer to save shared memory
             self.dp_kv_transfer_req_num = dp_kv_transfer_req_num
@@ -101,9 +94,7 @@ def init_custom(self):
             self.mem_managers = []
             for rank_idx in range(self.node_world_size):
                 if rank_idx != self.rank_in_node:
-                    self.mem_managers.append(
-                        MemoryManager.from_shm(rank_idx, self.rank_in_node, use_for_pd_trans=False)
-                    )
+                    self.mem_managers.append(MemoryManager.loads_from_shm(self.rank_in_node))
                 else:
                     self.mem_managers.append(self.model.mem_manager)

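For reference, write_to_shm together with MemoryManager.loads_from_shm boils down to pickling an object into a named shared-memory block and unpickling it from another process, which is what lets the old from_shm call and the rpyc handshake go away. A stripped-down sketch of that round trip using only the standard library (plain pickle instead of ForkingPickler; the shm name, payload, and 4-byte little-endian length prefix are illustrative assumptions, not the lightllm layout):

# Sketch: serialize an object into named shared memory and read it back,
# mirroring the length-prefixed pattern suggested by write_to_shm above.
import pickle
from multiprocessing import shared_memory

def dump_to_shm(obj, name: str) -> shared_memory.SharedMemory:
    obj_bytes = pickle.dumps(obj)
    shm = shared_memory.SharedMemory(name=name, create=True, size=len(obj_bytes) + 4)
    shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little")  # length prefix
    shm.buf[4:4 + len(obj_bytes)] = obj_bytes
    return shm

def load_from_shm(name: str):
    shm = shared_memory.SharedMemory(name=name)
    size = int.from_bytes(shm.buf[0:4], "little")
    obj = pickle.loads(bytes(shm.buf[4:4 + size]))  # copy out before closing
    shm.close()
    return obj

if __name__ == "__main__":
    writer = dump_to_shm({"kv_buffer_size": 1024}, name="demo_mem_manager_0")
    print(load_from_shm("demo_mem_manager_0"))
    writer.close()
    writer.unlink()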