
Commit de9b058

fix
1 parent 2359e13 commit de9b058

File tree

3 files changed, +36 -38 lines changed


lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 24 additions & 2 deletions
@@ -39,7 +39,6 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
-from .dp_backend.dp_shared_kv_trans import init_dp_kv_shared


 class ModeBackend:
@@ -215,11 +214,12 @@ def init_model(self, kvargs):
             or self.args.enable_dp_prompt_cache_fetch
         ):
             self.model.mem_manager.write_to_shm(req_manager=self.model.req_manager)
+            dist.barrier(group=self.node_nccl_group)

         self.init_custom()

         if self.args.enable_dp_prompt_cache_fetch:
-            init_dp_kv_shared(self)
+            self.init_dp_kv_shared()

         self.shm_reqs_io_buffer = ShmObjsIOBuffer()
         # Only used in nixl pd mode, to report whether a chunked transfer task succeeded.
@@ -243,6 +243,28 @@ def init_model(self, kvargs):
     def init_custom(self):
         pass

+    def init_dp_kv_shared(self):
+        from lightllm.server.router.model_infer.mode_backend.dp_backend.dp_shared_kv_trans import DPKVSharedMoudle
+        from lightllm.common.mem_manager import MemoryManager
+
+        torch.cuda.set_device(get_current_device_id())
+
+        self.dp_kv_shared_module = DPKVSharedMoudle(
+            max_req_num=self.args.running_max_req_size,
+            max_req_seq_len=self.args.max_req_total_len + 8,
+            dp_size_in_node=self.dp_size_in_node,
+            backend=self,
+        )
+
+        # Collect mem_managers from all ranks
+        self.mem_managers = []
+        for rank_idx in range(self.node_world_size):
+            if rank_idx != self.rank_in_node:
+                self.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, self.rank_in_node))
+            else:
+                self.mem_managers.append(self.model.mem_manager)
+        return
+
     def get_max_total_token_num(self):
         return self.model.mem_manager.size
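The new init_dp_kv_shared method rebuilds every peer rank's MemoryManager from shared memory, which is presumably why the commit also adds a dist.barrier right after write_to_shm: no rank should start loading until all ranks have finished publishing. A minimal sketch of that publish-then-barrier-then-load ordering; collect_peer_state, publish, and load are hypothetical stand-ins for the lightllm calls, not repo APIs:

import torch.distributed as dist

# Sketch only: publish/load stand in for MemoryManager.write_to_shm /
# MemoryManager.loads_from_shm in the real code.
def collect_peer_state(rank, world_size, group, local_state, publish, load):
    # 1) Every rank publishes its own state to shared memory.
    publish(rank, local_state)
    # 2) Barrier on the node group: without it, a fast rank could try to
    #    load a slot that a slower rank has not finished writing yet.
    dist.barrier(group=group)
    # 3) Assemble the per-rank list, reusing the in-process object for this
    #    rank instead of re-loading it from shared memory.
    return [local_state if r == rank else load(r) for r in range(world_size)]

dist.barrier(group=...) is also what the existing node_nccl_group.barrier() calls are switched to in dp_shared_kv_trans.py below.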

lightllm/server/router/model_infer/mode_backend/dp_backend/dp_shared_kv_trans.py

Lines changed: 9 additions & 33 deletions
@@ -1,4 +1,5 @@
 # This file provides the helper functions for shared kv cache transfer in data-parallel (dp) inference mode
+import time
 import numpy as np
 import dataclasses
 import torch
@@ -9,6 +10,8 @@
 from lightllm.server.core.objs.shm_array import ShmArray
 from ...infer_batch import InferReq
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
+import torch.distributed as dist


 class DPKVSharedMoudle:
@@ -36,7 +39,7 @@ def fill_reqs_info(self, reqs: List[InferReq]):
         """
         Fill the requests' kv info into shared memory
         """
-        self.backend.node_nccl_group.barrier()
+        dist.barrier(group=self.backend.node_nccl_group)
         if self.backend.is_master_in_dp:
             self.shared_req_infos.arr[0 : len(reqs), self.dp_rank_in_node, self._KV_LEN_INDEX] = [
                 req.cur_kv_len for req in reqs
@@ -54,9 +57,7 @@ def build_shared_kv_trans_tasks(
         """
         Build the shared kv transfer info
         """
-        from lightllm.server.router.model_infer.infer_batch import g_infer_context
-
-        self.backend.node_nccl_group.barrier()
+        dist.barrier(group=self.backend.node_nccl_group)

         trans_tasks: List[TransTask] = []
         rank_max_radix_cache_lens = np.max(
@@ -96,22 +97,20 @@ def kv_trans(self, trans_tasks: List["TransTask"]):
     def kv_trans(self, trans_tasks: List["TransTask"]):
         from lightllm.server.router.model_infer.infer_batch import g_infer_context

-        self.backend.node_nccl_group.barrier()
         # kv transfer
         if len(trans_tasks) > 0:
             max_kv_len_mem_indexes = []
             max_kv_len_dp_ranks = []
             mem_indexes = []

             for i, trans_task in enumerate(trans_tasks):
-                max_kv_len_mem_indexes.extend(trans_task.max_kv_len_mem_indexes)
+                max_kv_len_mem_indexes.append(trans_task.max_kv_len_mem_indexes)
                 max_kv_len_dp_ranks.extend([trans_task.max_kv_len_dp_rank] * len(trans_task.max_kv_len_mem_indexes))
-                mem_indexes.extend(trans_task.mem_indexes)
+                mem_indexes.append(trans_task.mem_indexes)

-            max_kv_len_mem_indexes_tensor = torch.tensor(max_kv_len_mem_indexes, dtype=torch.int64, device="cuda")
+            max_kv_len_mem_indexes_tensor = torch.cat(max_kv_len_mem_indexes).to(dtype=torch.int64, device="cuda")
             max_kv_len_dp_ranks_tensor = torch.tensor(max_kv_len_dp_ranks, dtype=torch.int32, device="cuda")
-            mem_indexes_tensor = torch.tensor(mem_indexes, dtype=torch.int64, device="cuda")
-
+            mem_indexes_tensor = torch.cat(mem_indexes).to(dtype=torch.int64, device="cuda")
             self.backend.model.mem_manager.copy_kv_from_other_dp_ranks(
                 mem_managers=self.backend.mem_managers,
                 move_token_indexes=max_kv_len_mem_indexes_tensor,
@@ -122,7 +121,6 @@ def kv_trans(self, trans_tasks: List["TransTask"]):
             )
             self.backend.logger.info(f"dp_i {self.dp_rank_in_node} transfer kv tokens num: {len(mem_indexes_tensor)}")

-        self.backend.node_nccl_group.barrier()
         for trans_task in trans_tasks:
             g_infer_context.req_manager.req_to_token_indexs[
                 trans_task.req.req_idx,
@@ -131,7 +129,6 @@ def kv_trans(self, trans_tasks: List["TransTask"]):
             trans_task.req.cur_kv_len += len(trans_task.mem_indexes)
             if self.backend.is_master_in_dp:
                 trans_task.req.shm_req.shm_cur_kv_len = trans_task.req.cur_kv_len
-        self.backend.node_nccl_group.barrier()


 @dataclasses.dataclass
@@ -141,24 +138,3 @@ class TransTask:
     max_kv_len_dp_rank: int
     max_kv_len_mem_manager_index: int
     max_kv_len_mem_indexes: torch.Tensor
-
-
-def init_dp_kv_shared(backend):
-    torch.cuda.set_device(get_current_device_id())
-
-    backend.dp_kv_shared_moudle = DPKVSharedMoudle(
-        max_req_num=backend.args.running_max_req_size,
-        max_req_seq_len=backend.args.max_req_total_len + 8,
-        dp_size_in_node=backend.dp_size_in_node,
-        backend=backend,
-    )
-    backend.node_nccl_group.barrier()
-
-    # Collect mem_managers from all ranks
-    backend.mem_managers = []
-    for rank_idx in range(backend.node_world_size):
-        if rank_idx != backend.rank_in_node:
-            backend.mem_managers.append(MemoryManager.loads_from_shm(rank_idx, backend.rank_in_node))
-        else:
-            backend.mem_managers.append(backend.model.mem_manager)
-    return
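The kv_trans change collects whole index tensors and joins them with a single torch.cat instead of flattening them element by element into Python lists and rebuilding tensors with torch.tensor. A small CPU-only sketch of the two styles, assuming each chunk plays the role of one task's mem_indexes tensor (the variable names here are illustrative only):

import torch

# Hypothetical per-task index tensors, standing in for trans_task.mem_indexes.
chunks = [torch.arange(0, 4), torch.arange(10, 13)]

# Element-wise style: iterate each tensor, convert every element to a Python
# scalar, then rebuild one tensor from the scalar list.
flat = []
for c in chunks:
    flat.extend(int(x) for x in c)
elementwise = torch.tensor(flat, dtype=torch.int64)

# Concatenation style (what the commit moves to): keep the chunks as tensors
# and join them in a single torch.cat call.
concatenated = torch.cat(chunks).to(dtype=torch.int64)

assert torch.equal(elementwise, concatenated)

When the chunks live on the GPU, as they do in kv_trans, the element-wise path additionally pays a device-to-host transfer per element, which the single torch.cat avoids.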

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 3 additions & 3 deletions
@@ -77,9 +77,9 @@ def _init_reqs(self, reqs: List[Tuple]):

         infer_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=True)
         req_dp_ranks = [req[3] for req in reqs]
-        self.dp_kv_shared_moudle.fill_reqs_info(reqs=infer_reqs)
-        trans_taskes = self.dp_kv_shared_moudle.build_shared_kv_trans_tasks(reqs=infer_reqs, req_dp_ranks=req_dp_ranks)
-        self.dp_kv_shared_moudle.kv_trans(trans_tasks=trans_taskes)
+        self.dp_kv_shared_module.fill_reqs_info(reqs=infer_reqs)
+        trans_taskes = self.dp_kv_shared_module.build_shared_kv_trans_tasks(reqs=infer_reqs, req_dp_ranks=req_dp_ranks)
+        self.dp_kv_shared_module.kv_trans(trans_tasks=trans_taskes)

         g_infer_context._filter(finished_request_ids=[req[0] for req in other_dp_reqs])
         g_infer_state_lock.release()
