Commit b1938d0

layer into triton op

1 parent 78892b8

5 files changed: +93 -35 lines changed
lightllm/common/kv_trans_kernel/kv_trans_v2.py

Lines changed: 21 additions & 8 deletions
@@ -199,13 +199,16 @@ def _kv_trans_for_dp_kernel(
    input_stride_0,
    input_stride_1,
    input_stride_2,
+    input_stride_3,
    input_token_idx_ptr,
    input_token_dp_index_ptr,
    output_ptr,
    output_stride_0,
    output_stride_1,
    output_stride_2,
+    output_stride_3,
    output_token_idx_ptr,
+    layer_num: tl.constexpr,
    token_num: int,
    head_num: int,
    head_dim: int,
@@ -229,11 +232,20 @@ def _kv_trans_for_dp_kernel(
        mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D
        input_token_idx = tl.load(input_token_idx_ptr + tid)
        output_token_idx = tl.load(output_token_idx_ptr + tid)
-        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
-            cur_offs = block_idx * BLOCK_SIZE + offs
-            input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
-            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
-            tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim)
+
+        input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
+        for layer_idx in tl.range(0, layer_num, 1):
+            for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
+                cur_offs = block_idx * BLOCK_SIZE + offs
+                in_datas = tl.load(
+                    input_ptr + input_stride_0 * layer_idx + input_stride_1 * input_token_idx + cur_offs,
+                    mask=cur_offs < head_num_dim,
+                )
+                tl.store(
+                    output_ptr + output_stride_0 * layer_idx + output_stride_1 * output_token_idx + cur_offs,
+                    in_datas,
+                    mask=cur_offs < head_num_dim,
+                )

        tid += grid_count

@@ -250,19 +262,19 @@ def kv_trans_for_dp(
    rank_in_dp: int,
):
    """
-    input_mems is a torch.uint64 tensor; it stores the base pointers of the kv cache in the mem_manager objects currently in use.
+    input_mems is a torch.uint64 tensor of shape (mem_num,); each element stores the base pointer of the corresponding mem_manager's full multi-layer kv cache.
    """
    assert input_mems.is_contiguous()
    assert output.is_contiguous()
    assert len(input_mems.shape) == 1
-    assert len(output.shape) == 3
+    assert len(output.shape) == 4
    assert len(input_idx) == len(output_idx)
    assert len(output_idx) == len(input_dp_idx)
    assert len(input_mems) % dp_size_in_node == 0

    card_num_per_d = len(input_mems) // dp_size_in_node

-    _, head_num, head_dim = output.shape
+    layer_num, _, head_num, head_dim = output.shape
    token_num = len(output_idx)
    # Use only a small amount of resources for the data transfer to avoid occupying too many SM compute units
    grid_count = 20
@@ -278,6 +290,7 @@ def kv_trans_for_dp(
        output,
        *output.stride(),
        output_idx,
+        layer_num=layer_num,
        token_num=token_num,
        head_num=head_num,
        head_dim=head_dim,

lightllm/common/mem_manager.py

Lines changed: 20 additions & 20 deletions
@@ -414,25 +414,23 @@ def copy_kv_from_other_dp_ranks(
        dp_size_in_node: int,
        rank_in_dp: int,
    ):
-        if not hasattr(self, "mem_ptrs_dict"):
-            self.mem_ptrs_dict = {}
-            for layer_index in range(self.layer_num):
-                mems_ptr = []
-                for i in range(0, len(mem_managers)):
-                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
-                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
-                self.mem_ptrs_dict[layer_index] = mems_ptr
-
-        for layer_index in range(self.layer_num):
-            kv_trans_for_dp(
-                input_mems=self.mem_ptrs_dict[layer_index],
-                input_idx=move_token_indexes,
-                input_dp_idx=token_dp_indexes,
-                output=self.kv_buffer[layer_index],
-                output_idx=mem_indexes,
-                dp_size_in_node=dp_size_in_node,
-                rank_in_dp=rank_in_dp,
-            )
+        if not hasattr(self, "mem_ptrs_tensor"):
+            # Build a 1D uint64 tensor of length mem_num; each entry is the base pointer of that mem_manager's full multi-layer kv_buffer
+            mems_ptr_list = []
+            for i in range(0, len(mem_managers)):
+                mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr())
+            self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cuda")
+
+        # Transfer all layers in a single kernel launch
+        kv_trans_for_dp(
+            input_mems=self.mem_ptrs_tensor,
+            input_idx=move_token_indexes,
+            input_dp_idx=token_dp_indexes,
+            output=self.kv_buffer,
+            output_idx=mem_indexes,
+            dp_size_in_node=dp_size_in_node,
+            rank_in_dp=rank_in_dp,
+        )

    def create_shm(self):
        obj_bytes = ForkingPickler.dumps(self)
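Note: each mem_manager's kv_buffer is one contiguous (layer_num, token_num, head_num, head_dim) allocation, so a single base pointer per manager is enough; the kernel recovers layer offsets from input_stride_0. A small sketch of that invariant (CPU tensor, illustrative sizes, not LightLLM code):

import torch

# Hypothetical small buffer standing in for one mem_manager's kv_buffer.
layer_num, token_num, head_num, head_dim = 2, 4, 2, 8
kv_buffer = torch.zeros((layer_num, token_num, head_num, head_dim), dtype=torch.float16)

base_ptr = kv_buffer.data_ptr()
for layer_idx in range(layer_num):
    # Layer L starts at base_ptr + L * stride(0) * element_size, so no per-layer pointer table is needed.
    expected = base_ptr + layer_idx * kv_buffer.stride(0) * kv_buffer.element_size()
    assert kv_buffer[layer_idx].data_ptr() == expected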
@@ -449,7 +447,9 @@ def from_shm(rank_in_node):
            f"{get_unique_server_name()}_mem_manager_{rank_in_node}", LIGHTLLM_MEM_MANAGER_SHM_SIZE
        )
        bytes_len = int.from_bytes(shm.buf[0:4], "little")
-        return ForkingPickler.loads(shm.buf[4 : 4 + bytes_len])
+        obj_bytes = shm.buf[4 : 4 + bytes_len].tobytes()
+        shm.close()
+        return ForkingPickler.loads(obj_bytes)


class ReadOnlyStaticsMemoryManager:
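Note: from_shm now copies the pickled bytes out of the shared buffer and closes the handle before unpickling, so no memoryview slice of shm.buf outlives the mapping. A standalone sketch of the same pattern (toy payload, not LightLLM code):

import pickle
from multiprocessing import shared_memory

payload = pickle.dumps({"k": 1})
shm = shared_memory.SharedMemory(create=True, size=4 + len(payload))
shm.buf[0:4] = len(payload).to_bytes(4, "little")
shm.buf[4 : 4 + len(payload)] = payload

reader = shared_memory.SharedMemory(name=shm.name)
bytes_len = int.from_bytes(reader.buf[0:4], "little")
obj_bytes = reader.buf[4 : 4 + bytes_len].tobytes()  # copy out, detaching from the shared buffer
reader.close()  # safe: no exported memoryview of the buffer remains
assert pickle.loads(obj_bytes) == {"k": 1}

shm.close()
shm.unlink()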

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 4 additions & 5 deletions
@@ -29,8 +29,8 @@
from lightllm.common.mem_manager import MemoryManager
import torch.multiprocessing as mp

-min_trans_token_num = os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", 512)
-dp_kv_transfer_req_num = os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", 16)
+min_trans_token_num = int(os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", "512"))
+dp_kv_transfer_req_num = int(os.getenv("LIGHTLLM_DP_KV_TRANSFER_REQ_NUM", "16"))


class DPChunkedPrefillBackend(ModeBackend):
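Note: without the int() wrapping, the old defaults only behaved as intended while the environment variables were unset; once set, os.getenv returns a str and later numeric comparisons misfire. A quick sketch of the failure mode (the threshold comparison is illustrative, not taken from the backend):

import os

os.environ["LIGHTLLM_MIN_TRANS_TOKEN_NUM"] = "1024"

old_style = os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", 512)  # -> "1024" (str)
new_style = int(os.getenv("LIGHTLLM_MIN_TRANS_TOKEN_NUM", "512"))  # -> 1024 (int)

assert isinstance(old_style, str) and new_style == 1024
try:
    _ = 2048 >= old_style  # comparing an int against a str raises at runtime
except TypeError:
    print("old style breaks numeric comparisons once the variable is set")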
@@ -167,7 +167,6 @@ def _fetch_dp_prompt_cache(
            if sampling_param.disable_prompt_cache:
                continue
            shm_req.link_prompt_ids_shm_array()
-            shm_req.link_kv_indexes_shm_array()

            kv_len, value_tensor = self._match_radix_cache(shm_req)
            with g_infer_context.shm_req_manager.get_req_lock_by_index(shm_req.index_in_shm_mem):
@@ -210,7 +209,7 @@
    def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]):
        other_shm_reqs = []
        for match, index in other_match:
-            shm_req, kv_len, value_tensor = match
+            shm_req, kv_len, value_tensor, _ = match
            trans_len = kv_len - shm_req.dp_origin_kv_len
            if shm_req.dp_max_kv_rank == self.dp_rank_in_node:
                self.shared_kv_indexes.arr[index, 0:trans_len] = value_tensor[shm_req.dp_origin_kv_len : kv_len]
@@ -227,7 +226,7 @@ def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple])
        trans_info = []
        alloc_size = 0
        for match, index in my_match:
-            shm_req, kv_len, value_tensor = match
+            shm_req, kv_len, value_tensor, _ = match
            trans_len = shm_req.dp_max_kv_len - kv_len
            if trans_len > 0 and shm_req.dp_max_kv_rank != self.dp_rank_in_node:
                move_token_indexes.extend(self.shared_kv_indexes.arr[index, 0:trans_len])

lightllm/utils/log_utils.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
import time
from typing import Optional

-_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
+_FORMAT = "%(levelname)s %(asctime)s,%(msecs)03d [%(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"

_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug")
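Note: with a custom datefmt, %(asctime)s drops its default millisecond suffix, so the new format appends %(msecs)03d explicitly. A minimal sketch of the resulting output (logger name and message are illustrative):

import logging

_FORMAT = "%(levelname)s %(asctime)s,%(msecs)03d [%(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt=_FORMAT, datefmt=_DATE_FORMAT))
logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

# Prints something like: DEBUG 06-19 12:30:45,123 [demo.py:14] transfer finished
logger.debug("transfer finished")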

unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py

Lines changed: 47 additions & 1 deletion
@@ -1,7 +1,7 @@
import pytest
import torch
import random
-from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node, kv_trans_for_dp


@pytest.mark.parametrize(
@@ -73,5 +73,51 @@ def test_kv_trans_v2_for_d_node(token_num):
    return


+@pytest.mark.parametrize(
+    "token_num",
+    [token_num for token_num in range(5, 10)],
+)
+def test_kv_trans_for_dp(token_num):
+    card_num = 8
+    dp_size_in_node = 4
+    layer_num = 3
+    head_num = 2
+    head_dim = 512
+    kv_buffer_token_num = 512
+    rank_in_dp = 1
+
+    card_num_per_d = card_num // dp_size_in_node
+
+    # Create one multi-layer mem per card; each mem holds the data of all layers
+    mems = []
+    for _ in range(card_num):
+        mems.append(
+            torch.randn((layer_num, kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
+        )
+
+    input_mems = torch.tensor([e.data_ptr() for e in mems], dtype=torch.uint64, device="cuda")
+    input_idx = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)]
+    input_idx = torch.tensor(input_idx, dtype=torch.int32, device="cuda")
+    input_dp_idx = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)]
+    input_dp_idx = torch.tensor(input_dp_idx, dtype=torch.int32, device="cuda")
+
+    true_output = torch.zeros((layer_num, kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
+    test_output = torch.zeros((layer_num, kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
+    output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")
+
+    kv_trans_for_dp(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node, rank_in_dp)
+
+    # Verify the result
+    for dest_token_index, src_token_index, dp_index in zip(
+        list(range(token_num)), input_idx.cpu().numpy(), input_dp_idx.cpu().numpy()
+    ):
+        mem_index = rank_in_dp + dp_index * card_num_per_d
+        # Every layer of the destination token is read from the corresponding layer of the same source mem
+        true_output[:, dest_token_index, :, :] = mems[mem_index][:, src_token_index, :, :]
+
+    assert torch.equal(true_output, test_output), "kv_trans_for_dp output mismatch"
+    return
+
+
if __name__ == "__main__":
    pytest.main()