Commit 5f9a490

Fix bug in the pd mem manager.
1 parent: 7066b69

2 files changed (+10, -8 lines)

lightllm/common/deepseek2_mem_manager.py

Lines changed: 5 additions & 4 deletions
@@ -33,6 +33,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
             (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
         self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
+        self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2]
         return
 
     def send_to_decode_node(
@@ -58,7 +59,7 @@ def send_to_decode_node(
         return
 
     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * len(token_indexes)
+        move_size = self.token_dim_size * len(token_indexes)
         move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(
             1, len(token_indexes), self.head_num, self.head_dim
         )
@@ -82,7 +83,7 @@ def receive_from_prefill_node(
 
         cur_device_index = self.kv_buffer.get_device()
         token_num = len(move_token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        move_size = self.token_dim_size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
             nccl_comm.recv(recive_buffer, src=0)
@@ -145,7 +146,7 @@ def _get_kv_move_data_p2p(
         dp_size_in_node: int,
     ):
         move_token_num = len(token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
+        move_size = self.token_dim_size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim)
         kv_trans_v2_for_p_node(
             input_mems=self.mem_ptrs_dict[layer_index],
@@ -184,7 +185,7 @@ def receive_from_prefill_node_p2p(
         token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")
 
         token_num = len(move_token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        move_size = self.token_dim_size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
             nccl_comm.recv(recive_buffer, src=0)
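Why the new expression works: kv_buffer here appears to be laid out as (layer_num, size, head_num, head_dim), so kv_buffer.numel() // layer_num // size equals head_num * head_dim, which is exactly the token_dim_size now cached from the move buffer's last two axes. Deriving the per-token element count from the move buffer itself keeps the staging math tied to the buffer actually being sliced, and avoids recomputing numel() on every transfer. A minimal self-contained sketch of the equivalence (all shapes below are made-up illustrative values, not taken from the repo):

import torch

# Hypothetical MLA-style layout: one latent head per token.
layer_num, size, head_num, head_dim = 2, 16, 1, 512
max_req_total_len = 8

kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=torch.float16)
kv_move_buffer = torch.empty((1, max_req_total_len + 8, head_num, head_dim), dtype=torch.float16)

# Cached once at allocation time, as the patch does:
token_dim_size = kv_move_buffer.shape[-1] * kv_move_buffer.shape[-2]

token_indexes = [3, 5, 7]
old_move_size = kv_buffer.numel() // layer_num // size * len(token_indexes)
new_move_size = token_dim_size * len(token_indexes)
assert old_move_size == new_move_size  # equal only while kv_buffer keeps this per-token layout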

lightllm/common/mem_manager.py

Lines changed: 5 additions & 4 deletions
@@ -89,6 +89,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
             (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
         )
         self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
+        self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1]
         return
 
     def send_to_decode_node(
@@ -124,7 +125,7 @@ def send_to_decode_node(
         return
 
     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * len(token_indexes)
+        move_size = self.token_dim_size * len(token_indexes)
         move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(
             1, len(token_indexes), 2 * self.head_num, self.head_dim
         )
@@ -149,7 +150,7 @@ def receive_from_prefill_node(
 
         cur_device_index = self.kv_buffer.get_device()
         token_num = len(move_token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        move_size = self.token_dim_size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
@@ -196,7 +197,7 @@ def send_to_decode_node_p2p(
 
     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
         move_token_num = len(token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
+        move_size = self.token_dim_size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
         kv_trans(
             self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
@@ -222,7 +223,7 @@ def receive_from_prefill_node_p2p(
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
 
         token_num = len(move_token_indexes)
-        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
+        move_size = self.token_dim_size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
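The cached token_dim_size also feeds the flat-slice-then-view chain that carves the staging region out of the move buffer; the reshape only succeeds because move_size is an exact multiple of one token's K/V footprint. A quick self-contained check with made-up dimensions (names mirror the diff, values are illustrative):

import torch

# Hypothetical standard-attention layout: K and V stacked on the head axis (2 * head_num).
head_num, head_dim, max_req_total_len = 8, 128, 64

kv_move_buffer = torch.empty((1, max_req_total_len + 8, 2 * head_num, head_dim), dtype=torch.float16)
token_dim_size = kv_move_buffer.shape[-2] * kv_move_buffer.shape[-1]  # elements per token

token_num = 5
move_size = token_dim_size * token_num

# Slice a contiguous prefix of the flattened buffer and reshape it per token.
recive_buffer = kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * head_num, head_dim)
assert recive_buffer.shape == (1, token_num, 2 * head_num, head_dim)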
