Commit 5bd5ffd

Author: Weichao Luo
single kv transfer process for pd.
Parent: c16e7b8

File tree: 9 files changed, +1011 −220 lines
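
This commit routes prefill→decode (PD) KV-cache transfers through a dedicated PyNcclCommunicator that callers pass into the memory-manager send/receive methods, instead of issuing dist.send/dist.recv on the default torch.distributed process group. A minimal sketch of the pattern, assuming a stand-in FakeComm in place of lightllm.distributed.pynccl.PyNcclCommunicator (whose constructor is not shown in this diff):

from typing import List
import torch

class FakeComm:
    # Stand-in for PyNcclCommunicator; only send/recv with dst/src are
    # exercised by the methods changed in this commit.
    def send(self, tensor: torch.Tensor, dst: int) -> None:
        print(f"send {tuple(tensor.shape)} -> rank {dst}")

    def recv(self, tensor: torch.Tensor, src: int) -> None:
        print(f"recv {tuple(tensor.shape)} <- rank {src}")

def send_layers(move_buffers: List[torch.Tensor], nccl_comm: FakeComm) -> None:
    # After this commit the communicator is an explicit argument, so KV
    # transfer can run on its own NCCL communicator rather than the
    # implicit default process group.
    for move_buffer in move_buffers:
        nccl_comm.send(move_buffer, dst=1)  # prefill (rank 0) -> decode (rank 1)

if __name__ == "__main__":
    send_layers([torch.zeros(4, 8)], FakeComm())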

9 files changed

+1011
-220
lines changed

lightllm/common/deepseek2_mem_manager.py

Lines changed: 13 additions & 8 deletions
@@ -7,6 +7,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
+from lightllm.distributed.pynccl import PyNcclCommunicator
 
 logger = init_logger(__name__)
 
@@ -41,7 +42,8 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return
 
     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         assert dp_size_in_node == 1
 
@@ -55,7 +57,7 @@ def send_to_decode_node(
         cur_mem = mem_managers[cur_device_index]
         for layer_index in range(cur_mem.layer_num):
             move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index)
-            dist.send(move_buffer, dst=1)
+            nccl_comm.send(move_buffer, dst=1)
         return
 
     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -67,7 +69,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer
 
     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         assert dp_size_in_node == 1
 
@@ -82,7 +85,7 @@ def receive_from_prefill_node(
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
-            dist.recv(recive_buffer, src=0)
+            nccl_comm.recv(recive_buffer, src=0)
             for i, mem in enumerate(mem_managers):
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
@@ -99,7 +102,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return
 
     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         """
         Implementation that uses the p2p triton kernel to copy and transfer the data.
@@ -126,7 +130,7 @@ def send_to_decode_node_p2p(
             move_buffer = self._get_kv_move_data_p2p(
                 move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node
             )
-            dist.send(move_buffer, dst=1)
+            nccl_comm.send(move_buffer, dst=1)
         return
 
     def _get_kv_move_data_p2p(
@@ -151,7 +155,8 @@ def _get_kv_move_data_p2p(
         return move_buffer
 
     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         if not hasattr(self, "mem_ptrs_dict"):
             self.mem_ptrs_dict = {}
@@ -176,7 +181,7 @@ def receive_from_prefill_node_p2p(
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
-            dist.recv(recive_buffer, src=0)
+            nccl_comm.recv(recive_buffer, src=0)
             self._write_kv_move_data_p2p(
                 move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node
             )
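
On the prefill side, callers of Deepseek2MemoryManager.send_to_decode_node now have to supply the communicator explicitly. A hypothetical call site follows; only the updated signature and the nccl_comm.send(..., dst=1) behavior come from the diff above, and the surrounding setup is assumed:

from lightllm.distributed.pynccl import PyNcclCommunicator

def run_prefill_kv_send(mem_manager, move_tasks, mem_managers, nccl_comm: PyNcclCommunicator):
    # dp_size_in_node must be 1 on this path (asserted inside the method).
    mem_manager.send_to_decode_node(
        move_tasks=move_tasks,
        mem_managers=mem_managers,
        dp_size_in_node=1,
        nccl_comm=nccl_comm,  # replaces the implicit default dist group
    )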

lightllm/common/mem_manager.py

Lines changed: 14 additions & 9 deletions
@@ -10,6 +10,7 @@
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.distributed.pynccl import PyNcclCommunicator
 
 
 logger = init_logger(__name__)
@@ -86,7 +87,8 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return
 
     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         assert dp_size_in_node == 1
 
@@ -103,14 +105,14 @@ def send_to_decode_node(
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
-                    dist.send(move_buffer, dst=1)
+                    nccl_comm.send(move_buffer, dst=1)
                 else:
                     move_size = move_buffer.numel()
                     new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
                     from torch.cuda import comm
 
                     comm.broadcast(move_buffer, out=[new_move_buffer])
-                    dist.send(new_move_buffer, dst=1)
+                    nccl_comm.send(new_move_buffer, dst=1)
         return
 
     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -122,7 +124,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer
 
     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -139,7 +142,7 @@ def receive_from_prefill_node(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
@@ -155,7 +158,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return
 
     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator
     ):
         """
         Implementation that uses the p2p triton kernel to copy and transfer the data.
@@ -173,7 +177,7 @@ def send_to_decode_node_p2p(
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                dist.send(move_buffer, dst=1)
+                nccl_comm.send(move_buffer, dst=1)
         return
 
     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -186,7 +190,8 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer
 
     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
    ):
         assert dp_size_in_node == 1
 
@@ -204,7 +209,7 @@ def receive_from_prefill_node_p2p(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
         return
 