Skip to content

Commit 0e13fb9

Browse files
author
wangzaijun
committed
fix
1 parent 1fe6cfa commit 0e13fb9

File tree

6 files changed

+34
-23
lines changed

6 files changed

+34
-23
lines changed

lightllm/common/kv_cache_mem_manager/mem_manager.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from lightllm.utils.device_utils import kv_trans_use_p2p
2020
from lightllm.utils.shm_utils import create_or_link_shm
2121
from multiprocessing.reduction import ForkingPickler
22+
from filelock import FileLock
2223

2324
logger = init_logger(__name__)
2425

@@ -450,25 +451,39 @@ def write_to_shm(self, req_manager):
450451
# 避免过多无用的数据复制和传输开销。
451452
self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs
452453

453-
shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
454-
for rank_in_node in range(0, get_node_world_size() * 2):
455-
obj_bytes = ForkingPickler.dumps(self).tobytes()
454+
lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
455+
with lock:
456+
node_world_size = get_node_world_size()
457+
shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
458+
obj_bytes_array = [ForkingPickler.dumps(self).tobytes() for _ in range(node_world_size * 2)]
459+
obj_size = len(obj_bytes_array[0])
456460
shm = create_or_link_shm(
457-
name=f"{shm_name}_{rank_in_node}", expected_size=len(obj_bytes) + 4, force_mode="create"
461+
name=shm_name, expected_size=obj_size * (node_world_size * 2) + 4 + 4, force_mode="create"
458462
)
459463
logger.info(f"create shm {shm.name} size {shm.size} for mem manger shared buffer")
460-
shm.buf[0:4] = len(obj_bytes).to_bytes(4, "little")
461-
shm.buf[4 : 4 + len(obj_bytes)] = obj_bytes
464+
shm.buf[0:4] = (node_world_size * 2).to_bytes(4, "little")
465+
shm.buf[4:8] = obj_size.to_bytes(4, "little")
466+
start_index = 8
467+
for obj_bytes in obj_bytes_array:
468+
shm.buf[start_index : start_index + obj_size] = obj_bytes
469+
start_index += obj_size
462470

463471
@staticmethod
464-
def loads_from_shm(rank_in_node: int, current_rank_in_node: int) -> "MemoryManager":
465-
shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}_{current_rank_in_node}"
472+
def loads_from_shm(rank_in_node: int) -> "MemoryManager":
473+
shm_name = f"{get_unique_server_name()}_mem_manager_{rank_in_node}"
474+
lock = FileLock(f"/tmp/{get_unique_server_name()}_mem_manager_lock")
466475
logger.info(f"get memmanager from shm {shm_name}")
467-
shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link")
468-
bytes_len = int.from_bytes(shm.buf[0:4], "little")
469-
obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes()
470-
shm.close()
471-
return ForkingPickler.loads(obj_bytes)
476+
with lock:
477+
shm = create_or_link_shm(name=shm_name, expected_size=-1, force_mode="link")
478+
left_num = int.from_bytes(shm.buf[0:4], "little")
479+
obj_size = int.from_bytes(shm.buf[4:8], "little")
480+
assert left_num > 0
481+
end_index = 8 + left_num * obj_size
482+
start_index = 8 + (left_num - 1) * obj_size
483+
obj_bytes = shm.buf[start_index:end_index].tobytes()
484+
shm.buf[0:4] = (left_num - 1).to_bytes(4, byteorder="little")
485+
shm.close()
486+
return ForkingPickler.loads(obj_bytes)
472487

473488

474489
class ReadOnlyStaticsMemoryManager:

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def init_custom(self):
245245

246246
def init_dp_kv_shared(self):
247247
from lightllm.server.router.model_infer.mode_backend.dp_backend.dp_shared_kv_trans import DPKVSharedMoudle
248-
from lightllm.common.mem_manager import MemoryManager
248+
from lightllm.common.kv_cache_mem_manager import MemoryManager
249249

250250
torch.cuda.set_device(get_current_device_id())
251251

@@ -260,7 +260,7 @@ def init_dp_kv_shared(self):
260260
self.mem_managers = []
261261
for rank_idx in range(self.node_world_size):
262262
if rank_idx != self.rank_in_node:
263-
self.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, self.rank_in_node))
263+
self.mem_managers.append(MemoryManager.loads_from_shm(rank_idx))
264264
else:
265265
self.mem_managers.append(self.model.mem_manager)
266266
return

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,7 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.
114114
# 从共享内存读取所有rank的mem_manager
115115
node_world_size = args.tp // args.nnodes
116116
mem_managers: List[MemoryManager] = [
117-
MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size)
118-
for rank in range(node_world_size)
117+
MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size)
119118
]
120119

121120
task_out_queue.put("get_mem_managers_ok")

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ def _init_env(
119119
# 从共享内存读取所有rank的mem_manager
120120
node_world_size = args.tp // args.nnodes
121121
mem_managers: List[MemoryManager] = [
122-
MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size)
123-
for rank in range(node_world_size)
122+
MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size)
124123
]
125124
task_out_queue.put("get_mem_managers_ok")
126125
connect_id_to_comm: Dict[str, PyNcclCommunicator] = {}

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ def _init_env(
5858
# 从共享内存读取所有rank的mem_manager
5959
node_world_size = args.tp // args.nnodes
6060
mem_managers: List[MemoryManager] = [
61-
MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size)
62-
for rank in range(node_world_size)
61+
MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size)
6362
]
6463

6564
task_out_queue.put("get_mem_managers_ok")

lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ def _init_env(
5050
# 从共享内存读取所有rank的mem_manager
5151
node_world_size = args.tp // args.nnodes
5252
mem_managers: List[MemoryManager] = [
53-
MemoryManager.loads_from_shm(rank_in_node=rank, current_rank_in_node=device_id + node_world_size)
54-
for rank in range(node_world_size)
53+
MemoryManager.loads_from_shm(rank_in_node=rank) for rank in range(node_world_size)
5554
]
5655

5756
task_out_queue.put("get_mem_managers_ok")

0 commit comments

Comments
 (0)