
Commit 0be3e0e

Author: wanzihao
Commit message: merge main
2 parents: c243478 + a2a6830

25 files changed: +1463 −835 lines

lightllm/common/basemodel/basemodel.py
Lines changed: 1 addition & 1 deletion

@@ -158,7 +158,7 @@ def _init_mem_manager(self):
     def _init_kv_move_buffer(self):
         # this initialization is only needed in the p/d (prefill/decode) disaggregated inference mode
         if self.run_mode in ["prefill", "decode"]:
-            self.mem_manager.alloc_kv_move_buffer(self.max_seq_length)
+            self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size)

     def _check_mem_size(self):
         self.max_total_token_num = self.mem_manager.size
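
Note on the sizing change: the sender now batches the trailing move_kv_len tokens of many KVMoveTasks into one transfer (see the two files below), so a single move can exceed any one request's max_seq_length; the whole KV pool (self.mem_manager.size) is the safe upper bound for the staging buffer. A minimal sketch of what alloc_kv_move_buffer might allocate under that reading — the shape mirrors the views taken in receive_from_prefill_node below, but self.dtype and the exact layout are assumptions, not the repository's confirmed implementation:

```python
import torch

class MemoryManagerSketch:
    # Hypothetical sketch only: reserve one layer's worth of KV for up to
    # max_token_num tokens, so any batched move fits in one staging buffer.
    def alloc_kv_move_buffer(self, max_token_num: int) -> None:
        self.kv_move_buffer = torch.empty(
            (1, max_token_num, 2 * self.head_num, self.head_dim),
            dtype=self.dtype,
            device="cuda",
        )
```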

lightllm/common/deepseek2_mem_manager.py
Lines changed: 17 additions & 8 deletions

@@ -1,6 +1,7 @@
 import torch
 import os

+from lightllm.server.pd_io_struct import KVMoveTask
 from .mem_manager import MemoryManager
 from typing import List
 from lightllm.utils.log_utils import init_logger
@@ -32,18 +33,22 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def send_to_decode_node(
-        self, token_indexes: List[int], mem_managers: List["Deepseek2MemoryManager"], dp_size: int, dp_index: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size: int
     ):
         assert dp_size == 1
-        assert dp_index == 0

         # first stage the data into the buffer on one designated GPU, then send it
         import torch.distributed as dist

+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+
         cur_device_index = self.kv_buffer.get_device()
         cur_mem = mem_managers[cur_device_index]
         for layer_index in range(cur_mem.layer_num):
-            move_buffer = cur_mem._get_kv_move_data(token_indexes, layer_index)
+            move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index)
             dist.send(move_buffer, dst=1)
         return

@@ -56,29 +61,33 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
     ):
         assert dp_size == 1
-        assert dp_index == 0

         # first receive the data into the buffer on one designated GPU, then copy it to the other GPUs
         import torch.distributed as dist

+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+
         cur_device_index = self.kv_buffer.get_device()
-        token_num = len(token_indexes)
+        token_num = len(move_token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
             dist.recv(recive_buffer, src=0)
             for i, mem in enumerate(mem_managers):
                 if i == cur_device_index:
-                    mem._write_kv_move_data(token_indexes, recive_buffer, layer_index)
+                    mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
                     new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape)
                     from torch.cuda import comm

                     comm.broadcast(recive_buffer, out=[new_recive_buffer])
-                    mem._write_kv_move_data(token_indexes, new_recive_buffer, layer_index)
+                    mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index)
         return

     def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
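
Both managers now accept List[KVMoveTask] instead of a flat token_indexes list and gather the indexes themselves, skipping tasks with move_kv_len == 0. A hedged sketch of the KVMoveTask fields this diff actually touches (the real dataclass in lightllm/server/pd_io_struct.py likely carries more metadata), with the shared gathering pattern factored into a helper for illustration:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class KVMoveTask:
    # Sketch: KV-cache slot indexes for one request on each node; only the
    # last move_kv_len entries still need to be transferred.
    prefill_token_indexes: List[int]
    decode_token_indexes: List[int]
    move_kv_len: int

def gather_move_indexes(move_tasks: List[KVMoveTask], prefill_side: bool) -> List[int]:
    # Mirrors the loop added in send_to_decode_node / receive_from_prefill_node.
    indexes: List[int] = []
    for task in move_tasks:
        if task.move_kv_len != 0:
            side = task.prefill_token_indexes if prefill_side else task.decode_token_indexes
            indexes.extend(side[-task.move_kv_len:])
    return indexes
```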

lightllm/common/mem_manager.py
Lines changed: 20 additions & 13 deletions

@@ -3,6 +3,7 @@
 import torch
 import torch.distributed as dist
 from typing import List
+from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
@@ -79,25 +80,27 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         )
         return

-    def send_to_decode_node(
-        self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
-    ):
+    def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
         """
-        dp_size and dp_index are parameters tailored for deepseekv2-style models that can run in a mixed dp/tp mode;
+        dp_size is a parameter tailored for deepseekv2-style models that can run in a mixed dp/tp mode;
         in plain tp mode, dp_size is always 1 and dp_index is always 0; in plain mode these parameters are not
         actually used
         """
         assert dp_size == 1
-        assert dp_index == 0

         # first stage the data into the buffer on one designated GPU, then send it
         import torch.distributed as dist

+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+
         cur_device_index = self.kv_buffer.get_device()
         cur_mem = mem_managers[cur_device_index]
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                move_buffer = mem._get_kv_move_data(token_indexes, layer_index)
+                move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
                     dist.send(move_buffer, dst=1)
                 else:
@@ -118,34 +121,38 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
     ):
         """
-        dp_size and dp_index are parameters tailored for deepseekv2-style models that can run in a mixed dp/tp mode;
-        in plain tp mode, dp_size is always 1 and dp_index is always 0; in plain mode these parameters are not
+        dp_size is a parameter tailored for deepseekv2-style models that can run in a mixed dp/tp mode;
+        in plain tp mode, dp_size is always 1; in plain mode the parameter is not
         actually used
         """
         assert dp_size == 1
-        assert dp_index == 0

         # first receive the data into the buffer on one designated GPU, then copy it to the other GPUs
         import torch.distributed as dist

+        move_token_indexes = []
+        for task in move_tasks:
+            if task.move_kv_len != 0:
+                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+
         cur_device_index = self.kv_buffer.get_device()
-        token_num = len(token_indexes)
+        token_num = len(move_token_indexes)
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 dist.recv(recive_buffer, src=0)
                 if i == cur_device_index:
-                    mem._write_kv_move_data(token_indexes, recive_buffer, layer_index)
+                    mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
                     new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape)
                     from torch.cuda import comm

                     comm.broadcast(recive_buffer, out=[new_recive_buffer])
-                    mem._write_kv_move_data(token_indexes, new_recive_buffer, layer_index)
+                    mem._write_kv_move_data(move_token_indexes, new_recive_buffer, layer_index)
         return

     def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
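
In both receive paths, one GPU owns the dist.recv into its kv_move_buffer, and the slice is then mirrored to every other manager with torch.cuda.comm.broadcast before _write_kv_move_data scatters it into the per-layer KV cache; note the deepseek2 variant receives each layer once and fans out per manager, while the base manager issues a recv per (manager, layer) pair. A standalone sketch of the fan-out step with illustrative shapes (needs two GPUs; the tensor names are not from the repository):

```python
import torch
from torch.cuda import comm

# Stage a received slice on cuda:0, then mirror it to cuda:1 the same way
# receive_from_prefill_node copies into each peer's kv_move_buffer.
token_num, head_num, head_dim = 16, 8, 128
recv_buffer = torch.randn(1, token_num, 2 * head_num, head_dim, device="cuda:0")
peer_buffer = torch.empty_like(recv_buffer, device="cuda:1")
comm.broadcast(recv_buffer, out=[peer_buffer])
assert torch.equal(recv_buffer.cpu(), peer_buffer.cpu())
```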
