
Commit 2d650fd

Author: wangzaijun
Commit message: fix
1 parent 8aa2632 commit 2d650fd

File tree

4 files changed, +171 -180 lines changed


lightllm/common/mem_manager.py

Lines changed: 13 additions & 3 deletions
@@ -16,6 +16,7 @@
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.config_utils import get_num_key_value_heads
 from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io
+from lightllm.utils.device_utils import kv_trans_use_p2p
 from lightllm.utils.shm_utils import create_or_link_shm
 from multiprocessing.reduction import ForkingPickler

@@ -432,13 +433,22 @@ def copy_kv_from_other_dp_ranks(
             rank_in_dp=rank_in_dp,
         )

-    def write_to_shm(self):
+    def write_to_shm(self, req_manager):
         """
         Write the mem manager into shm so that features such as pd disaggregation can read it directly, without depending on inter-process queues.
         """
-        from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor
+        if kv_trans_use_p2p():
+            from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import reduce_tensor

-        mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__
+            mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__
+
+        from lightllm.common.req_manager import ReqManager
+
+        req_manager: ReqManager = req_manager
+
+        # This is not an elegant design, but doing it this way for now lets the dp shared kv swap module
+        # access req_to_token_indexs on the req_manager directly, avoiding unnecessary data copies and transfer overhead.
+        self.req_to_token_indexs: torch.Tensor = req_manager.req_to_token_indexs

         shm_name = f"{get_unique_server_name()}_mem_manager_{get_current_rank_in_node()}"
         obj_bytes = ForkingPickler.dumps(self).tobytes()
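
For orientation, the hand-off that write_to_shm sets up looks roughly like the sketch below: the producer pickles the manager into a named shared-memory block, and a consumer process (such as the dp shared kv swap path) opens the same name and unpickles it, which is how req_to_token_indexs now travels along with the mem manager. This is only a rough sketch against Python's standard multiprocessing.shared_memory; lightllm's create_or_link_shm helper and its use of ForkingPickler (which shares CUDA tensors via IPC handles instead of copying them) differ in detail, and publish_obj / load_obj are hypothetical names.

    import pickle
    from multiprocessing import shared_memory


    def publish_obj(obj, shm_name: str) -> None:
        # Hypothetical producer side: serialize the object and copy the bytes
        # into a named shared-memory block that other processes can open by name.
        data = pickle.dumps(obj)
        shm = shared_memory.SharedMemory(name=shm_name, create=True, size=len(data))
        shm.buf[: len(data)] = data
        shm.close()  # detach this handle; the named block itself stays alive


    def load_obj(shm_name: str):
        # Hypothetical consumer side: attach to the existing block and deserialize.
        # pickle.loads stops at the pickle STOP opcode, so trailing bytes are ignored.
        shm = shared_memory.SharedMemory(name=shm_name, create=False)
        obj = pickle.loads(bytes(shm.buf))
        shm.close()
        return obj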

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def init_model(self, kvargs):
             self.args.run_mode in ["nixl_prefill", "nixl_decode", "prefill", "decode"]
             or self.args.enable_dp_prompt_cache_fetch
         ):
-            self.model.mem_manager.write_to_shm()
+            self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager)

         # Start the infer_loop_thread: two threads are launched for inference. For scenarios that support
         # dual-batch inference overlapping, this lowers cpu overhead and greatly improves gpu utilization.
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+# This file provides the helper functions for shared kv cache trans in the dp (data parallel) inference mode.
+import numpy as np
+import dataclasses
+import torch
+from typing import List
+from lightllm.common.mem_manager import MemoryManager
+from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.utils.dist_utils import get_dp_rank_in_node
+from lightllm.server.core.objs.shm_array import ShmArray
+from ...infer_batch import InferReq
+
+
+class DPKVSharedMoudle:
+    _KV_LEN_INDEX = 0
+    _REQ_IDX_INDEX = 1
+
+    def __init__(self, max_req_num: int, max_req_seq_len: int, dp_size_in_node: int, backend):
+        from .impl import DPChunkedPrefillBackend
+
+        self.backend: DPChunkedPrefillBackend = backend
+        self.max_req_num = max_req_num
+        self.max_req_seq_len = max_req_seq_len
+
+        # index 0 stores kv_len, index 1 stores req_idx
+        self.shared_req_infos = ShmArray(
+            name=f"{get_unique_server_name()}_dp_shared_req_infos",
+            shape=(self.max_req_num, dp_size_in_node, 2),
+            dtype=np.int64,
+        )
+        self.shared_req_infos.create_shm()
+        self.dp_rank_in_node = get_dp_rank_in_node()
+        assert get_env_start_args().diverse_mode is False
+
+    def fill_reqs_info(
+        self,
+        reqs: List[InferReq],
+        req_dp_ranks: List[int],
+    ):
+        """
+        Fill the kv info of the requests into shared memory.
+        """
+        self.backend.node_nccl_group.barrier()
+        self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [
+            req.cur_kv_len for req in reqs
+        ]
+        self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._REQ_IDX_INDEX] = [
+            req.req_idx for req in reqs
+        ]
+        return
+
+    def build_shared_kv_trans_tasks(
+        self,
+        reqs: List[InferReq],
+        req_dp_ranks: List[int],
+    ) -> List["TransTask"]:
+        """
+        Build the shared kv trans (exchange) tasks.
+        """
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        self.backend.node_nccl_group.barrier()
+
+        trans_tasks: List[TransTask] = []
+        rank_max_radix_cache_lens = np.max(
+            self.shared_req_infos.arr[0 : len(reqs), :, self._KV_LEN_INDEX], axis=1, keepdims=False
+        )
+        # If another dp rank holds a longer kv (radix cache) for a request this dp_rank handles, build a trans task that pulls the missing kv over.
+        for req_index, req, max_req_radix_cache_len, req_dp_rank in zip(
+            list(range(len(reqs))), reqs, rank_max_radix_cache_lens, req_dp_ranks
+        ):
+            # the current request is handled by this dp_rank
+            is_current_dp_handle = req_dp_rank == self.dp_rank_in_node
+            trans_size = max_req_radix_cache_len - req.cur_kv_len
+
+            if is_current_dp_handle and trans_size > 0 and g_infer_context.get_can_alloc_token_num() > trans_size:
+                g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(trans_size)
+                mem_indexes = self.backend.model.mem_manager.alloc(trans_size)
+                max_kv_len_dp_rank = self.shared_req_infos.arr[req_index, :, self._KV_LEN_INDEX].argmax()
+                max_kv_len_req_idx = int(self.shared_req_infos.arr[req_index, max_kv_len_dp_rank, self._REQ_IDX_INDEX])
+                max_kv_len_mem_manager_index = (
+                    max_kv_len_dp_rank * self.backend.dp_world_size + self.backend.dp_rank_in_node
+                )
+                max_kv_len_mem_manager: MemoryManager = self.backend.mem_managers[max_kv_len_mem_manager_index]
+                max_kv_len_mem_indexes = max_kv_len_mem_manager.req_to_token_indexs[
+                    max_kv_len_req_idx, req.cur_kv_len : max_req_radix_cache_len
+                ]
+                trans_tasks.append(
+                    TransTask(
+                        req=req,
+                        mem_indexes=mem_indexes,
+                        max_kv_len_dp_rank=int(max_kv_len_dp_rank),
+                        max_kv_len_mem_manager_index=int(max_kv_len_mem_manager_index),
+                        max_kv_len_mem_indexes=max_kv_len_mem_indexes,
+                    )
+                )
+
+        return trans_tasks
+
+    def kv_trans(self, trans_tasks: List["TransTask"]):
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        self.backend.node_nccl_group.barrier()
+        # kv transfer
+
+        # move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        # token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")
+
+        # self.model.mem_manager.copy_kv_from_other_dp_ranks(
+        #     mem_managers=self.mem_managers,
+        #     move_token_indexes=move_token_indexes,
+        #     token_dp_indexes=token_dp_indexes,
+        #     mem_indexes=mem_indexes,
+        #     dp_size_in_node=self.dp_size_in_node,
+        #     rank_in_dp=self.rank_in_dp,
+        # )
+        # self.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {alloc_size}")
+
+        self.backend.node_nccl_group.barrier()
+        for trans_task in trans_tasks:
+            g_infer_context.req_manager.req_to_token_indexs[
+                trans_task.req.req_idx,
+                trans_task.req.cur_kv_len : (trans_task.req.cur_kv_len + len(trans_task.mem_indexes)),
+            ] = trans_task.mem_indexes
+            trans_task.req.cur_kv_len += len(trans_task.mem_indexes)
+            if self.backend.is_master_in_dp:
+                trans_task.req.shm_req.shm_cur_kv_len = trans_task.req.cur_kv_len
+        self.backend.node_nccl_group.barrier()
+
+
+@dataclasses.dataclass
+class TransTask:
+    req: InferReq
+    mem_indexes: torch.Tensor
+    max_kv_len_dp_rank: int
+    max_kv_len_mem_manager_index: int
+    max_kv_len_mem_indexes: torch.Tensor
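
The coordination that fill_reqs_info and build_shared_kv_trans_tasks implement on top of shared_req_infos can be pictured with a plain numpy array: every dp rank writes its cur_kv_len and req_idx into its own column, and a rank that owns a request then takes the per-request maximum across columns to find which rank holds the longest kv for it. The sketch below is illustrative only: a local numpy array stands in for the ShmArray shared memory, the request data is made up, and the nccl barriers and actual kv copy are omitted.

    import numpy as np

    KV_LEN, REQ_IDX = 0, 1          # same layout as _KV_LEN_INDEX / _REQ_IDX_INDEX
    max_req_num, dp_size_in_node = 4, 2

    # Stand-in for the ShmArray: (request slot, dp rank, {kv_len, req_idx}).
    shared = np.zeros((max_req_num, dp_size_in_node, 2), dtype=np.int64)

    # Each dp rank fills its own column (what fill_reqs_info does for its reqs).
    shared[0:2, 0] = [(5, 11), (0, 12)]   # dp rank 0: (cur_kv_len, req_idx) per request
    shared[0:2, 1] = [(9, 7), (3, 8)]     # dp rank 1 holds a longer kv for request slot 0

    # What build_shared_kv_trans_tasks computes on the rank that owns request slot 0:
    req_index, my_cur_kv_len = 0, 5
    rank_max_kv_len = shared[req_index, :, KV_LEN].max()
    trans_size = rank_max_kv_len - my_cur_kv_len              # 9 - 5 = 4 tokens to pull
    src_dp_rank = shared[req_index, :, KV_LEN].argmax()       # dp rank 1 has the longest kv
    src_req_idx = int(shared[req_index, src_dp_rank, REQ_IDX])  # its local req_idx is 7

    print(trans_size, src_dp_rank, src_req_idx)  # -> 4 1 7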
