Commit 1911b34

Author: wangzaijun
Commit message: fix
1 parent da76a05 commit 1911b34
File tree: 2 files changed, 176 additions & 40 deletions

lightllm/common/basemodel/triton_kernel/kv_cache_offload.py

Lines changed: 171 additions & 39 deletions
@@ -2,7 +2,6 @@
 
 import triton
 import triton.language as tl
-from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 
 
 @triton.jit
@@ -22,9 +21,16 @@ def _offload_gpu_kv_to_cpu(
     page_indexes_ptr,
     page_readies_ptr,
     layer_num,
-    head_all_dim,
-    cpu_head_offset,
-    BLOCK_HEAD_ALL_DIM: tl.constexpr,
+    head_dim,
+    cpu_k_start_head_index: tl.constexpr,
+    cpu_k_head_num: tl.constexpr,
+    gpu_k_start_head_index: tl.constexpr,
+    gpu_k_head_num: tl.constexpr,
+    cpu_v_start_head_index: tl.constexpr,
+    cpu_v_head_num: tl.constexpr,
+    gpu_v_start_head_index: tl.constexpr,
+    gpu_v_head_num: tl.constexpr,
+    BLOCK_HEAD_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
     block_index = tl.program_id(0)
@@ -38,28 +44,61 @@ def _offload_gpu_kv_to_cpu(
 
     token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
     token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
-    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
+    head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
+    head_dim_mask = head_dim_range < head_dim
 
     for layer_index in range(layer_num):
-        gpu_ptr = (
-            gpu_kv_cache_ptr
-            + layer_index.to(tl.int64) * gpu_stride0
-            + token_indexes[:, None] * gpu_stride1
-            + head_all_dim_range[None, :]
-        )
-        gpu_data = tl.load(gpu_ptr, mask=(head_all_dim_range[None, :] < head_all_dim), other=0.0)
-        cpu_ptr = (
-            cpu_kv_cache_ptr
-            + cpu_page_index * cpu_stride0
-            + layer_index * cpu_stride1
-            + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
-            + (cpu_head_offset + head_all_dim_range[None, :])
-        )
-        tl.store(
-            cpu_ptr,
-            gpu_data,
-            mask=(head_all_dim_range[None, :] < head_all_dim),
-        )
+        for k_head_index in range(gpu_k_head_num):
+            gpu_k_head_index = k_head_index + gpu_k_start_head_index
+            cpu_k_head_index = k_head_index + cpu_k_start_head_index
+
+            gpu_ptr = (
+                gpu_kv_cache_ptr
+                + layer_index.to(tl.int64) * gpu_stride0
+                + token_indexes[:, None] * gpu_stride1
+                + gpu_k_head_index.to(tl.int64) * gpu_stride2
+                + head_dim_range[None, :]
+            )
+            gpu_data = tl.load(gpu_ptr, mask=head_dim_mask[None, :], other=0.0)
+            cpu_ptr = (
+                cpu_kv_cache_ptr
+                + cpu_page_index * cpu_stride0
+                + layer_index * cpu_stride1
+                + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+                + cpu_k_head_index * cpu_stride3
+                + head_dim_range[None, :]
+            )
+            tl.store(
+                cpu_ptr,
+                gpu_data,
+                mask=head_dim_mask[None, :],
+            )
+
+        for v_head_index in range(gpu_v_head_num):
+            gpu_v_head_index = v_head_index + gpu_v_start_head_index
+            cpu_v_head_index = v_head_index + cpu_v_start_head_index
+
+            gpu_ptr = (
+                gpu_kv_cache_ptr
+                + layer_index.to(tl.int64) * gpu_stride0
+                + token_indexes[:, None] * gpu_stride1
+                + gpu_v_head_index.to(tl.int64) * gpu_stride2
+                + head_dim_range[None, :]
+            )
+            gpu_data = tl.load(gpu_ptr, mask=head_dim_mask[None, :], other=0.0)
+            cpu_ptr = (
+                cpu_kv_cache_ptr
+                + cpu_page_index * cpu_stride0
+                + layer_index * cpu_stride1
+                + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+                + cpu_v_head_index * cpu_stride3
+                + head_dim_range[None, :]
+            )
+            tl.store(
+                cpu_ptr,
+                gpu_data,
+                mask=head_dim_mask[None, :],
+            )
     return
 
 
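For readers following the kernel change: the rewrite replaces the single flat head_all_dim copy with separate per-head loops for K and V, so a GPU head range can land in a different CPU head range. Below is a rough PyTorch sketch of the per-page copy (illustration only, not part of the commit; the function name and the explicit head-range arguments are invented for the sketch, and the per-page page_readies check done inside the real kernel is omitted).

import torch

def offload_reference(token_indexes, gpu_kv_cache, cpu_kv_cache, page_indexes,
                      cpu_k_start, gpu_k_start, k_head_num,
                      cpu_v_start, gpu_v_start, v_head_num):
    # gpu_kv_cache: (layer_num, total_token_num, gpu_heads, head_dim)
    # cpu_kv_cache: (page_num, layer_num, token_block_size, cpu_heads, head_dim)
    token_block = cpu_kv_cache.shape[2]
    for i, page in enumerate(page_indexes.tolist()):
        toks = token_indexes[i * token_block:(i + 1) * token_block].long()
        # K heads: a contiguous GPU head range is written into a contiguous CPU head range
        cpu_kv_cache[page, :, :, cpu_k_start:cpu_k_start + k_head_num, :] = (
            gpu_kv_cache[:, toks, gpu_k_start:gpu_k_start + k_head_num, :].cpu()
        )
        # V heads: same pattern with the V offsets
        cpu_kv_cache[page, :, :, cpu_v_start:cpu_v_start + v_head_num, :] = (
            gpu_kv_cache[:, toks, gpu_v_start:gpu_v_start + v_head_num, :].cpu()
        )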
@@ -70,6 +109,9 @@ def offload_gpu_kv_to_cpu(
     cpu_kv_cache: torch.Tensor,
     page_indexes: torch.Tensor,
     page_readies: torch.Tensor,
+    tp_index: int,
+    tp_world_size: int,
+    _cache_data={},
 ):
     """
     this function is used to offload GPU KV cache to CPU KV cache.
@@ -81,25 +123,108 @@ def offload_gpu_kv_to_cpu(
     page_readies: (page_num,)
     """
     token_block_size = cpu_kv_cache.shape[2]
-    token_num = page_indexes.shape[0] * token_block_size
-    assert token_indexes.shape[0] >= token_num
+    token_num = token_indexes.shape[0]
+    assert token_num == page_indexes.shape[0] * token_block_size
     assert page_indexes.shape == page_readies.shape
-    page_num = page_indexes.shape[0]
-    head_all_dim = gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2]
-    BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
 
-    # Calculate head offset for tensor parallelism
-    tp_rank = get_current_rank_in_dp()
-    tp_num = get_dp_world_size()
     gpu_heads = gpu_kv_cache.shape[2]
     gpu_head_dim = gpu_kv_cache.shape[3]
     cpu_heads = cpu_kv_cache.shape[3]
-    factor = (tp_num * gpu_heads) // cpu_heads
-    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
-    if tp_rank % factor != 0:
-        # redundant kv does not need to offload
+    cpu_head_dim = cpu_kv_cache.shape[4]
+    assert gpu_head_dim == cpu_head_dim
+    head_dim = gpu_head_dim
+    scale_size = (tp_world_size * gpu_heads) // cpu_heads
+
+    # compute the mapping between the GPU and CPU head indexes that need to be copied
+    if (gpu_heads, cpu_heads, tp_index, tp_world_size) in _cache_data:
+        need_offload, head_info_tuple = _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)]
+    else:
+        if cpu_heads > 1:
+            assert (tp_world_size * gpu_heads) % cpu_heads == 0
+            assert cpu_heads % 2 == 0
+
+            cpu_heads_index = (
+                torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32)
+                .view(cpu_heads, 1)
+                .tile((1, scale_size))
+                .view(2, tp_world_size, -1)
+            )
+            # k
+            k_cpu_heads_index = cpu_heads_index[0][tp_index]
+            # v
+            v_cpu_heads_index = cpu_heads_index[1][tp_index]
+
+            cpu_heads_index = torch.cat([k_cpu_heads_index, v_cpu_heads_index], dim=0).view(2, -1).numpy()
+            gpu_heads_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1).numpy()
+
+            need_offload = tp_index % scale_size == 0
+
+            cpu_k_start_head_index = cpu_heads_index[0, 0]
+            cpu_k_head_num = len(cpu_heads_index[0])
+            gpu_k_start_head_index = gpu_heads_index[0, 0]
+            gpu_k_head_num = len(gpu_heads_index[0])
+            assert cpu_k_head_num == gpu_k_head_num
+            cpu_v_start_head_index = cpu_heads_index[1, 0]
+            cpu_v_head_num = len(cpu_heads_index[1])
+            gpu_v_start_head_index = gpu_heads_index[1, 0]
+            gpu_v_head_num = len(gpu_heads_index[1])
+            assert cpu_v_head_num == gpu_v_head_num
+
+            head_info_tuple = (
+                cpu_k_start_head_index,
+                cpu_k_head_num,
+                gpu_k_start_head_index,
+                gpu_k_head_num,
+                cpu_v_start_head_index,
+                cpu_v_head_num,
+                gpu_v_start_head_index,
+                gpu_v_head_num,
+            )
+
+        else:
+            assert gpu_heads == 1
+            assert cpu_heads == 1
+
+            need_offload = tp_index == 0
+
+            cpu_k_start_head_index = 0
+            cpu_k_head_num = 1
+            gpu_k_start_head_index = 0
+            gpu_k_head_num = 1
+            cpu_v_start_head_index = 0
+            cpu_v_head_num = 0
+            gpu_v_start_head_index = 0
+            gpu_v_head_num = 0
+            head_info_tuple = (
+                cpu_k_start_head_index,
+                cpu_k_head_num,
+                gpu_k_start_head_index,
+                gpu_k_head_num,
+                cpu_v_start_head_index,
+                cpu_v_head_num,
+                gpu_v_start_head_index,
+                gpu_v_head_num,
+            )
+
+        _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] = (need_offload, head_info_tuple)
+
+    (
+        cpu_k_start_head_index,
+        cpu_k_head_num,
+        gpu_k_start_head_index,
+        gpu_k_head_num,
+        cpu_v_start_head_index,
+        cpu_v_head_num,
+        gpu_v_start_head_index,
+        gpu_v_head_num,
+    ) = head_info_tuple
+
+    if not need_offload:
         return
 
+    assert token_block_size == triton.next_power_of_2(token_block_size)
+    page_num = page_indexes.shape[0]
+
     grid = (page_num,)
     num_warps = 4
 
@@ -119,9 +244,16 @@ def offload_gpu_kv_to_cpu(
         page_indexes_ptr=page_indexes,
         page_readies_ptr=page_readies,
         layer_num=gpu_kv_cache.shape[0],
-        head_all_dim=head_all_dim,
-        cpu_head_offset=cpu_head_offset,
-        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
+        head_dim=head_dim,
+        cpu_k_start_head_index=cpu_k_start_head_index,
+        cpu_k_head_num=cpu_k_head_num,
+        gpu_k_start_head_index=gpu_k_start_head_index,
+        gpu_k_head_num=gpu_k_head_num,
+        cpu_v_start_head_index=cpu_v_start_head_index,
+        cpu_v_head_num=cpu_v_head_num,
+        gpu_v_start_head_index=gpu_v_start_head_index,
+        gpu_v_head_num=gpu_v_head_num,
+        BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim),
         TOKEN_BLOCK=token_block_size,
         num_warps=num_warps,
         num_stages=1,
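
To make the new head bookkeeping in offload_gpu_kv_to_cpu above easier to follow, here is a toy walk-through of the mapping cached in _cache_data, with example sizes that are not taken from the commit: 8 TP ranks, 2 local GPU heads per rank (1 K + 1 V), and 8 CPU heads (4 K + 4 V).

import torch

tp_world_size, gpu_heads, cpu_heads = 8, 2, 8           # example sizes only
scale_size = (tp_world_size * gpu_heads) // cpu_heads   # = 2 ranks share each CPU head

cpu_heads_index = (
    torch.arange(0, cpu_heads, dtype=torch.int32)
    .view(cpu_heads, 1)
    .tile((1, scale_size))
    .view(2, tp_world_size, -1)                          # row 0: K targets, row 1: V targets
)
for tp_index in range(tp_world_size):
    need_offload = tp_index % scale_size == 0
    k_dst = cpu_heads_index[0][tp_index].tolist()        # CPU K head(s) this rank would fill
    v_dst = cpu_heads_index[1][tp_index].tolist()        # CPU V head(s) this rank would fill
    print(tp_index, need_offload, k_dst, v_dst)
# rank 0 fills CPU K head [0] / V head [4], rank 2 fills [1]/[5], rank 4 fills [2]/[6],
# rank 6 fills [3]/[7]; the odd ranks hold redundant copies and skip the offload.

Caching the resulting head_info_tuple keyed by (gpu_heads, cpu_heads, tp_index, tp_world_size) means this small index computation runs once per process instead of on every offload call.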

lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 5 additions & 1 deletion
@@ -190,13 +190,17 @@ def _start_kv_cache_offload_task(
 
            page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True)
            page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True)
-           token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0 : req.cur_kv_len]
+           move_token_num = item_size * self.args.cpu_cache_token_page_size
+           assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size
+           token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num]
            offload_gpu_kv_to_cpu(
                token_indexes=token_indexes,
                gpu_kv_cache=self.backend.model.mem_manager.kv_buffer,
                cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor,
                page_indexes=page_indexes,
                page_readies=page_readies,
+               tp_index=self.backend.rank_in_dp,
+               tp_world_size=self.backend.dp_world_size,
            )
 
            sync_event = torch.cuda.Event()
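
A minimal call sketch for the updated interface, assuming made-up sizes (32 layers, head_dim 128, 256-token CPU pages); in the server the tensors come from req_manager, mem_manager, and cpu_cache_client exactly as in the diff above, and tp_index / tp_world_size come from self.backend.rank_in_dp / self.backend.dp_world_size.

import torch
from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu

layer_num, total_tokens, gpu_heads, head_dim = 32, 4096, 2, 128   # illustrative sizes
cpu_page_num, token_page_size, cpu_heads = 64, 256, 8             # token_page_size must be a power of 2

gpu_kv_cache = torch.randn(layer_num, total_tokens, gpu_heads, head_dim,
                           dtype=torch.float16, device="cuda")
cpu_kv_cache = torch.zeros(cpu_page_num, layer_num, token_page_size, cpu_heads, head_dim,
                           dtype=torch.float16, pin_memory=True)

item_size = 2                                     # number of CPU pages being filled
move_token_num = item_size * token_page_size      # must not exceed req.cur_kv_len
token_indexes = torch.arange(move_token_num, dtype=torch.int32, device="cuda")
page_indexes = torch.tensor([10, 11], dtype=torch.int32, pin_memory=True)
page_readies = torch.tensor([False, False], dtype=torch.bool, pin_memory=True)

offload_gpu_kv_to_cpu(
    token_indexes=token_indexes,
    gpu_kv_cache=gpu_kv_cache,
    cpu_kv_cache=cpu_kv_cache,
    page_indexes=page_indexes,
    page_readies=page_readies,
    tp_index=0,          # self.backend.rank_in_dp in the server backend
    tp_world_size=8,     # self.backend.dp_world_size
)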
