Skip to content

Commit 3d3afe7

Browse files
authored
pd mode. batch kv trans. (#667)
1 parent 5cd3845 commit 3d3afe7

File tree

20 files changed

+403
-290
lines changed

20 files changed

+403
-290
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _init_mem_manager(self):
158158
def _init_kv_move_buffer(self):
159159
# p d 分离的推理模式下才需要做这一步初始化
160160
if self.run_mode in ["prefill", "decode"]:
161-
self.mem_manager.alloc_kv_move_buffer(self.max_seq_length)
161+
self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size)
162162

163163
def _check_mem_size(self):
164164
self.max_total_token_num = self.mem_manager.size

lightllm/common/deepseek2_mem_manager.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import os
33

4+
from lightllm.server.pd_io_struct import KVMoveTask
45
from .mem_manager import MemoryManager
56
from typing import List
67

@@ -24,18 +25,22 @@ def alloc_kv_move_buffer(self, max_req_total_len):
2425
return
2526

2627
def send_to_decode_node(
27-
self, token_indexes: List[int], mem_managers: List["Deepseek2MemoryManager"], dp_size: int, dp_index: int
28+
self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size: int
2829
):
2930
assert dp_size == 1
30-
assert dp_index == 0
3131

3232
# 先将数据发送到指定的一张卡上的buffer,再发送。
3333
import torch.distributed as dist
3434

35+
move_token_indexes = []
36+
for task in move_tasks:
37+
if task.move_kv_len != 0:
38+
move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
39+
3540
cur_device_index = self.kv_buffer.get_device()
3641
cur_mem = mem_managers[cur_device_index]
3742
for layer_index in range(cur_mem.layer_num):
38-
move_buffer = cur_mem._get_kv_move_data(token_indexes, layer_index)
43+
move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index)
3944
dist.send(move_buffer, dst=1)
4045
return
4146

@@ -48,29 +53,33 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
4853
return move_buffer
4954

5055
def receive_from_prefill_node(
51-
self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
56+
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
5257
):
5358
assert dp_size == 1
54-
assert dp_index == 0
5559

5660
# 先将数据接收到指定的一张卡上的buffer,再复制到其他的卡上。
5761
import torch.distributed as dist
5862

63+
move_token_indexes = []
64+
for task in move_tasks:
65+
if task.move_kv_len != 0:
66+
move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
67+
5968
cur_device_index = self.kv_buffer.get_device()
60-
token_num = len(token_indexes)
69+
token_num = len(move_token_indexes)
6170
move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
6271
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
6372
for layer_index in range(self.layer_num):
6473
dist.recv(recive_buffer, src=0)
6574
for i, mem in enumerate(mem_managers):
6675
if i == cur_device_index:
67-
mem._write_kv_move_data(token_indexes, recive_buffer, layer_index)
76+
mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
6877
else:
6978
new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape)
7079
from torch.cuda import comm
7180

7281
comm.broadcast(recive_buffer, out=[new_recive_buffer])
73-
mem._write_kv_move_data(token_indexes, new_recive_buffer, layer_index)
82+
mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index)
7483
return
7584

7685
def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):

lightllm/common/mem_manager.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torch.distributed as dist
55
from typing import List
6+
from lightllm.server.pd_io_struct import KVMoveTask
67
from lightllm.utils.log_utils import init_logger
78
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
89
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
@@ -79,25 +80,27 @@ def alloc_kv_move_buffer(self, max_req_total_len):
7980
)
8081
return
8182

82-
def send_to_decode_node(
83-
self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
84-
):
83+
def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
8584
"""
86-
dp_size 和 dp_index 是为 deepseekv2 类型,可以 dp 和 tp 混合模式运行的模型定制的参数,
85+
dp_size 是为 deepseekv2 类型,可以 dp 和 tp 混合模式运行的模型定制的参数,
8786
普通tp模式下, dp_size 一定等于 1, 同时普通模式下, 该参数并不会
8887
被真正使用
8988
"""
9089
assert dp_size == 1
91-
assert dp_index == 0
9290

9391
# 先将数据发送到指定的一张卡上的buffer,再发送。
9492
import torch.distributed as dist
9593

94+
move_token_indexes = []
95+
for task in move_tasks:
96+
if task.move_kv_len != 0:
97+
move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
98+
9699
cur_device_index = self.kv_buffer.get_device()
97100
cur_mem = mem_managers[cur_device_index]
98101
for i, mem in enumerate(mem_managers):
99102
for layer_index in range(mem.layer_num):
100-
move_buffer = mem._get_kv_move_data(token_indexes, layer_index)
103+
move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
101104
if i == cur_device_index:
102105
dist.send(move_buffer, dst=1)
103106
else:
@@ -118,34 +121,38 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
118121
return move_buffer
119122

120123
def receive_from_prefill_node(
121-
self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
124+
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
122125
):
123126
"""
124-
dp_size 和 dp_index 是为 deepseekv2 类型,可以 dp 和 tp 混合模式运行的模型定制的参数,
125-
普通tp模式下, dp_size 一定等于 1, dp_index 一定等于 0, 同时普通模式下, 这两个参数并不会
127+
dp_size 是为 deepseekv2 类型,可以 dp 和 tp 混合模式运行的模型定制的参数,
128+
普通tp模式下, dp_size 一定等于 1, 同时普通模式下, 该参数并不会
126129
被真正使用
127130
"""
128131
assert dp_size == 1
129-
assert dp_index == 0
130132

131133
# 先将数据接收到指定的一张卡上的buffer,再复制到其他的卡上。
132134
import torch.distributed as dist
133135

136+
move_token_indexes = []
137+
for task in move_tasks:
138+
if task.move_kv_len != 0:
139+
move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
140+
134141
cur_device_index = self.kv_buffer.get_device()
135-
token_num = len(token_indexes)
142+
token_num = len(move_token_indexes)
136143
move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
137144
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
138145
for i, mem in enumerate(mem_managers):
139146
for layer_index in range(mem.layer_num):
140147
dist.recv(recive_buffer, src=0)
141148
if i == cur_device_index:
142-
mem._write_kv_move_data(token_indexes, recive_buffer, layer_index)
149+
mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
143150
else:
144151
new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape)
145152
from torch.cuda import comm
146153

147154
comm.broadcast(recive_buffer, out=[new_recive_buffer])
148-
mem._write_kv_move_data(token_indexes, new_recive_buffer, layer_index)
155+
mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index)
149156
return
150157

151158
def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):

lightllm/server/router/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ async def wait_to_model_ready(self):
153153

154154
if self.args.run_mode == "prefill":
155155
# 启动 prefill kv move 管理进程
156-
from lightllm.server.router.model_infer.mode_backend.continues_batch.prefill_node_impl import (
156+
from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.prefill_node_impl import (
157157
start_prefill_kv_move_manager_process,
158158
)
159159

160160
start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues)
161161

162162
if self.args.run_mode == "decode":
163163
# 启动 decode kv move 管理进程
164-
from lightllm.server.router.model_infer.mode_backend.continues_batch.decode_node_impl import (
164+
from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import (
165165
start_decode_kv_move_manager_process,
166166
)
167167

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
from .continues_batch.impl_for_token_healing import TokenHealingBackend
88
from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
99
from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
10-
from .continues_batch.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode
11-
from .continues_batch.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
10+
from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode
11+
from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/__init__.PY

Whitespace-only changes.

lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/__init__.py renamed to lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/__init__.py

File renamed without changes.

lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py renamed to lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from lightllm.server.io_struct import ReqRunStatus, FinishStatus
1212
from lightllm.server.pd_io_struct import UpKVStatus
1313
from lightllm.utils.log_utils import init_logger
14-
from ..pre_process import prepare_prefill_inputs, prepare_decode_inputs
15-
from ..post_process import sample
14+
from ...pre_process import prepare_prefill_inputs, prepare_decode_inputs
15+
from ...post_process import sample
1616
from .up_status import UpStatusManager
1717
from rpyc.utils.server import ThreadedServer
1818
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, g_router_lock

lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_infer_rpyc.py renamed to lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py

File renamed without changes.

0 commit comments

Comments
 (0)