
Commit 40e092b

lfr-0531 authored and heyuhhh committed
add paged kt cache (1st commit).
Signed-off-by: Fanrong Li <[email protected]>
minor fix.
Signed-off-by: Fanrong Li <[email protected]>
fix _single_request_update_kt_cache for vanilla RocketKV.
Signed-off-by: Fanrong Li <[email protected]>
add paged kt cache to rocketkv trtllm.
Signed-off-by: Fanrong Li <[email protected]>
fix _single_request_update_kt_cache for trtllm RocketKV.
Signed-off-by: Fanrong Li <[email protected]>
fix k_snap length.
Signed-off-by: Fanrong Li <[email protected]>
fix memory issue when using paged kt cache.
Signed-off-by: Fanrong Li <[email protected]>
1 parent 61b6a1f commit 40e092b


7 files changed: +462 -200 lines changed


examples/llm-api/rocket.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ def parse_arguments():
                         help="The maximum sequence length.")
     parser.add_argument("--max_batch_size",
                         type=int,
-                        default=4,
+                        default=256,
                         help="The maximum batch size.")
     parser.add_argument("--max_new_tokens",
                         type=int,
@@ -59,7 +59,7 @@ def parse_arguments():
     # KV cache
     parser.add_argument('--kv_cache_dtype', type=str, default='auto')
 
-    parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
+    parser.add_argument("--kv_cache_fraction", type=float, default=None)
 
     parser.add_argument('--num_samples', type=int, default=1)
 
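As an illustrative aside (the command below is an assumed invocation, not content from this diff), the new defaults let the example run with a 256-request batch ceiling while the runtime chooses the KV-cache fraction itself, using only flags visible in this hunk:

python examples/llm-api/rocket.py --max_batch_size 256 --kv_cache_dtype auto --num_samples 1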

examples/longbench/eval_longbench_v1.py

Lines changed: 1 addition & 0 deletions
@@ -313,6 +313,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
         model=args.model_path,
         backend=args.backend,
         kv_cache_config=kv_cache_config,
+        max_batch_size=args.max_batch_size,
         attn_backend=args.attention_backend,
         sparse_attention_config=sparse_attention_config,
         tensor_parallel_size=args.tensor_parallel_size,

tensorrt_llm/_torch/attention_backend/sparse/kernel.py

Lines changed: 238 additions & 0 deletions
@@ -68,3 +68,241 @@ def triton_index_gather(input, indices):
                                dim_size,
                                BLOCK_SIZE=1024)
     return output
+
+
+@triton.jit
+def _update_kt_cache_ctx_kernel(k_ptr, cache_ptr, block_offsets_ptr,
+                                cum_seq_lens_ptr, cum_kt_seq_lens_ptr,
+                                token_to_batch_map_ptr, num_kv_heads, dim_size,
+                                kt_page_size, tokens_per_block,
+                                max_kt_blocks_per_seq,
+                                BLOCK_SIZE: tl.constexpr):
+    # get program id
+    kt_token_idx = tl.program_id(0)
+
+    # get params
+    batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
+    kv_start_idx = tl.load(cum_seq_lens_ptr + batch_idx)
+    kv_end_idx = tl.load(cum_seq_lens_ptr + batch_idx + 1)
+    kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
+    local_kt_token_idx = kt_token_idx - kt_start_idx
+    global_kv_token_idx = kv_start_idx + local_kt_token_idx * kt_page_size
+
+    # get offsets
+    hidden_size = num_kv_heads * dim_size
+    k_base = k_ptr + global_kv_token_idx * hidden_size
+    block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
+    block_idx = tl.load(block_offsets_ptr + block_offset)
+    token_idx_in_block = local_kt_token_idx % tokens_per_block
+    cache_base = cache_ptr + (block_idx * tokens_per_block +
+                              token_idx_in_block) * hidden_size * 2
+
+    # compute min/max and store kt
+    for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
+        head_idx = hidden_indices // dim_size
+        dim_idx = hidden_indices % dim_size
+        dim_mask = hidden_indices < hidden_size
+
+        # get k_min and k_max
+        k_min = tl.full((BLOCK_SIZE, ), float('inf'), dtype=tl.float32)
+        k_max = tl.full((BLOCK_SIZE, ), float('-inf'), dtype=tl.float32)
+        for i in range(kt_page_size):
+            if global_kv_token_idx + i < kv_end_idx:
+                k = tl.load(k_base + i * hidden_size + hidden_indices,
+                            mask=dim_mask,
+                            other=0.0)
+                k_min = tl.minimum(k_min, k)
+                k_max = tl.maximum(k_max, k)
+        k_min = k_min.to(cache_ptr.dtype.element_ty)
+        k_max = k_max.to(cache_ptr.dtype.element_ty)
+
+        # store k_min and k_max to cache
+        k_min_offset = cache_base + head_idx * dim_size * 2 + dim_idx
+        k_max_offset = k_min_offset + dim_size
+        tl.store(k_min_offset, k_min, mask=dim_mask)
+        tl.store(k_max_offset, k_max, mask=dim_mask)
+
+
+@triton.jit
+def _update_kt_cache_gen_kernel(k_ptr, cache_ptr, block_offsets_ptr,
+                                seq_lens_ptr, num_kv_heads, dim_size,
+                                kt_page_size, tokens_per_block,
+                                max_kt_blocks_per_seq,
+                                BLOCK_SIZE: tl.constexpr):
+    # get program id
+    batch_idx = tl.program_id(0)
+    head_idx = tl.program_id(1)
+
+    # get params
+    past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1
+    kt_token_idx = past_key_value_length // kt_page_size
+    kt_token_idx_in_page = past_key_value_length % kt_page_size
+
+    # get offsets
+    hidden_size = num_kv_heads * dim_size
+    k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size
+    block_offset = batch_idx * max_kt_blocks_per_seq + kt_token_idx // tokens_per_block
+    block_idx = tl.load(block_offsets_ptr + block_offset)
+    kt_token_idx_in_block = kt_token_idx % tokens_per_block
+    cache_base = cache_ptr + (block_idx * tokens_per_block +
+                              kt_token_idx_in_block) * hidden_size * 2
+    cache_base += head_idx * dim_size * 2
+
+    # update kt cache
+    for hidden_start in tl.range(0, dim_size, BLOCK_SIZE):
+        hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
+        dim_mask = hidden_indices < dim_size
+
+        # load k
+        k = tl.load(k_base + hidden_indices, mask=dim_mask, other=0.0)
+
+        # load kt cache
+        kt_mask = dim_mask & (kt_token_idx_in_page > 0)
+        k_min = tl.load(cache_base + hidden_indices,
+                        mask=kt_mask,
+                        other=float('inf'))
+        k_max = tl.load(cache_base + hidden_indices + dim_size,
+                        mask=kt_mask,
+                        other=float('-inf'))
+        k_min = tl.minimum(k_min, k)
+        k_max = tl.maximum(k_max, k)
+        k_min = k_min.to(cache_ptr.dtype.element_ty)
+        k_max = k_max.to(cache_ptr.dtype.element_ty)
+
+        # store k_min and k_max to cache
+        tl.store(cache_base + hidden_indices, k_min, mask=dim_mask)
+        tl.store(cache_base + hidden_indices + dim_size, k_max, mask=dim_mask)
+
+
+@triton.jit
+def _load_kt_cache_kernel(kt_states_ptr, cache_ptr, block_offsets_ptr,
+                          cum_kt_seq_lens_ptr, token_to_batch_map_ptr,
+                          num_kv_heads, dim_size, tokens_per_block,
+                          max_kt_blocks_per_seq, BLOCK_SIZE: tl.constexpr):
+    # get program id
+    kt_token_idx = tl.program_id(0)
+
+    # get params
+    batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
+    kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
+    local_kt_token_idx = kt_token_idx - kt_start_idx
+
+    # get offsets
+    hidden_size = num_kv_heads * dim_size * 2
+    kt_states_base = kt_states_ptr + kt_token_idx * hidden_size
+    block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
+    block_idx = tl.load(block_offsets_ptr + block_offset)
+    token_idx_in_block = local_kt_token_idx % tokens_per_block
+    cache_base = cache_ptr + (block_idx * tokens_per_block +
+                              token_idx_in_block) * hidden_size
+
+    # load kt cache
+    for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
+        mask = hidden_indices < hidden_size
+        # load kt cache
+        kt = tl.load(cache_base + hidden_indices, mask=mask, other=0.0)
+        # store kt to kt_states
+        tl.store(kt_states_base + hidden_indices, kt, mask=mask)
+
+
+def triton_update_kt_cache(k,
+                           kt_cache_tensor,
+                           kt_cache_block_offsets,
+                           seq_lens,
+                           kt_page_size,
+                           tokens_per_block,
+                           max_kt_blocks_per_seq,
+                           update=True):
+    # inputs:
+    # k: (total_seq_len, num_kv_heads, head_dim)
+    # kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim)
+    # kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq)
+    # seq_lens: (batch_size)
+    # kt_page_size: int
+    # update: bool
+
+    # outputs:
+    # kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim)
+
+    # params
+    batch_size = seq_lens.size(0)
+    num_kv_heads = k.size(1)
+    head_dim = k.size(2)
+    tokens_per_block = kt_cache_tensor.size(1)
+    num_kt_tokens = (seq_lens + kt_page_size - 1) // kt_page_size
+
+    # context
+    if not update:
+        total_num_kt_tokens = num_kt_tokens.sum().item()
+        cum_seq_lens = torch.cumsum(torch.cat([
+            torch.zeros(1, device='cuda', dtype=torch.long),
+            seq_lens.to(torch.long)
+        ]),
+                                    dim=0)
+        cum_kt_seq_lens = torch.cumsum(torch.cat([
+            torch.zeros(1, device='cuda', dtype=torch.long),
+            num_kt_tokens.to(torch.long)
+        ]),
+                                       dim=0)
+
+        token_to_batch_map = torch.repeat_interleave(
+            torch.arange(batch_size,
+                         device='cuda'), repeats=num_kt_tokens).to(torch.long)
+        grid = (total_num_kt_tokens, )
+        _update_kt_cache_ctx_kernel[grid](k,
+                                          kt_cache_tensor,
+                                          kt_cache_block_offsets,
+                                          cum_seq_lens,
+                                          cum_kt_seq_lens,
+                                          token_to_batch_map,
+                                          num_kv_heads,
+                                          head_dim,
+                                          kt_page_size,
+                                          tokens_per_block,
+                                          max_kt_blocks_per_seq,
+                                          BLOCK_SIZE=1024)
+        return
+    else:
+        # generation
+        # update kt cache
+        grid = (batch_size, num_kv_heads)
+        _update_kt_cache_gen_kernel[grid](k,
+                                          kt_cache_tensor,
+                                          kt_cache_block_offsets,
+                                          seq_lens,
+                                          num_kv_heads,
+                                          head_dim,
+                                          kt_page_size,
+                                          tokens_per_block,
+                                          max_kt_blocks_per_seq,
+                                          BLOCK_SIZE=1024)
+
+        # load kt cache
+        total_num_kt_tokens = num_kt_tokens.sum().item()
+        kt_states = torch.empty(
+            (total_num_kt_tokens, num_kv_heads, 2 * head_dim),
+            device='cuda',
+            dtype=k.dtype)
+        token_to_batch_map = torch.repeat_interleave(
+            torch.arange(batch_size,
+                         device='cuda'), repeats=num_kt_tokens).to(torch.long)
+        cum_kt_seq_lens = torch.cumsum(torch.cat([
+            torch.zeros(1, device='cuda', dtype=torch.long),
+            num_kt_tokens.to(torch.long)
+        ]),
+                                       dim=0)
+        grid = (total_num_kt_tokens, )
+        _load_kt_cache_kernel[grid](kt_states,
+                                    kt_cache_tensor,
+                                    kt_cache_block_offsets,
+                                    cum_kt_seq_lens,
+                                    token_to_batch_map,
+                                    num_kv_heads,
+                                    head_dim,
+                                    tokens_per_block,
+                                    max_kt_blocks_per_seq,
+                                    BLOCK_SIZE=1024)
+
+        return kt_states
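For context, here is a minimal sketch of how the new triton_update_kt_cache wrapper might be exercised on its own. It is illustrative only and not part of the commit: the shapes, page sizes, block-table contents, and dtype below are assumptions (chosen so each sequence fits in a single KT-cache block), and it presumes a CUDA device with Triton available.

# Illustrative driver for triton_update_kt_cache; all sizes below are assumptions.
import torch

from tensorrt_llm._torch.attention_backend.sparse.kernel import \
    triton_update_kt_cache

num_kv_heads, head_dim = 4, 64
kt_page_size, tokens_per_block, max_kt_blocks_per_seq = 8, 32, 4
num_blocks = 16

# Paged KT cache: each cache slot holds per-head (k_min, k_max) for one page of keys.
kt_cache_tensor = torch.zeros(num_blocks, tokens_per_block, num_kv_heads,
                              2 * head_dim, device='cuda', dtype=torch.float16)
# Block table mapping (batch_idx, kt_block_idx) -> physical block index.
kt_cache_block_offsets = torch.arange(
    2 * max_kt_blocks_per_seq, device='cuda',
    dtype=torch.int32).reshape(2, max_kt_blocks_per_seq)

# Context phase (update=False): keys of two sequences packed along dim 0;
# the per-page min/max summaries are written into the cache and nothing is returned.
seq_lens = torch.tensor([100, 60], device='cuda', dtype=torch.int32)
k = torch.randn(int(seq_lens.sum()), num_kv_heads, head_dim,
                device='cuda', dtype=torch.float16)
triton_update_kt_cache(k, kt_cache_tensor, kt_cache_block_offsets, seq_lens,
                       kt_page_size, tokens_per_block, max_kt_blocks_per_seq,
                       update=False)

# Generation phase (update=True): one new key per sequence updates the tail page,
# and the packed kt_states of shape (total_kt_tokens, num_kv_heads, 2 * head_dim)
# is gathered back from the paged cache.
seq_lens = seq_lens + 1
k_new = torch.randn(2, num_kv_heads, head_dim, device='cuda', dtype=torch.float16)
kt_states = triton_update_kt_cache(k_new, kt_cache_tensor,
                                   kt_cache_block_offsets, seq_lens,
                                   kt_page_size, tokens_per_block,
                                   max_kt_blocks_per_seq, update=True)

In this layout a KT token summarizes one page of kt_page_size keys by their per-head element-wise minimum and maximum: the context kernel computes those summaries, the generation kernel folds the newest key into the current tail page, and the load kernel gathers the summaries back into the packed kt_states tensor.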
