
Commit 537a3cb

[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout
* revise tests for new kcache layout
* naive triton flash decoding - new kcache layout
* rotary triton kernel - new kcache layout
* remove redundancy - triton decoding
* remove redundancy - triton kvcache copy
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
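For context on what the new K-cache layout means, here is a minimal sketch (sizes and variable names are illustrative only; the real cache is populated by the copy kernels in this PR, not by a host-side reshape): the head dimension is split into head_dim // x contiguous chunks of width x, and block_size becomes the second-to-last axis.

import torch

# Hypothetical sizes for illustration only.
num_blocks, num_kv_heads, block_size, head_dim, x = 4, 2, 16, 128, 8

# Old layout: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache_old = torch.randn(num_blocks, num_kv_heads, block_size, head_dim)

# New layout: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
k_cache_new = (
    k_cache_old.view(num_blocks, num_kv_heads, block_size, head_dim // x, x)
    .permute(0, 1, 3, 2, 4)
    .contiguous()
)

# Element (blk, h, s, d) of the old layout lives at (blk, h, d // x, s, d % x) in the new one.
blk, h, s, d = 1, 0, 5, 37
assert torch.equal(k_cache_old[blk, h, s, d], k_cache_new[blk, h, d // x, s, d % x])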
1 parent 9df016f commit 537a3cb

File tree

10 files changed: +428, -206 lines


colossalai/kernel/triton/context_attn_unpad.py

Lines changed: 2 additions & 2 deletions
@@ -338,8 +338,8 @@ def _fwd_context_paged_attention_kernel_v2(
     X_range = tl.arange(0, KCACHE_X)
     # unroll the loop aggressively
     for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
-        offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
-        offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt
+        offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
+        offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt
         k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0)
         # HACK: KCache must be contiguous in order to apply the following offsets calculation
         offsets_kcache = (

colossalai/kernel/triton/flash_decoding.py

Lines changed: 62 additions & 28 deletions
@@ -11,20 +11,29 @@
 def _flash_decoding_fwd_kernel(
     Q,  # [batch_size * q_len, head_num, head_dim]
     KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
-    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
+    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim],
+    # or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided
     block_tables,  # [batch_size, max_blocks_per_sequence]
     mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
     mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
     q_len,
     batch_size,
+    kv_group_num,
+    x,
+    sm_scale,
     stride_qt,
     stride_qh,
     stride_qd,
-    stride_cacheb,
-    stride_cacheh,
-    stride_cachebs,
-    stride_cached,
+    stride_kcb,
+    stride_kch,
+    stride_kcsplit_x,
+    stride_kcs,
+    stride_kcd,
+    stride_vcb,
+    stride_vch,
+    stride_vcs,
+    stride_vcd,
     stride_bts,
     stride_btb,
     stride_mid_ot,
@@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel(
     stride_mid_o_lset,
     stride_mid_o_lseh,
     stride_mid_o_lseb,
-    sm_scale,
-    KV_GROUPS: tl.constexpr,
     BLOCK_KV: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HEAD_DIM: tl.constexpr,
@@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel(
     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
         return
-
     offsets_dmodel = tl.arange(0, HEAD_DIM)
-    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
-    q = tl.load(Q + offsets_q)
+    offsets_block = tl.arange(0, BLOCK_SIZE)
+
     # block table for the current sequence
     block_table_ptr = block_tables + cur_seq_idx * stride_bts
     # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
@@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel(
     )
     tl.device_assert(cur_occupied_size >= 0)

-    cur_kv_head_idx = cur_head_idx // KV_GROUPS
-    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
-    K_block_ptr = tl.make_block_ptr(
-        base=KCache + offset_kvcache,
-        shape=(cur_occupied_size, HEAD_DIM),
-        strides=(stride_cachebs, stride_cached),
-        offsets=(0, 0),
-        block_shape=(BLOCK_SIZE, HEAD_DIM),
-        order=(0, 1),
+    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
+    q = tl.load(Q + offsets_q)
+    cur_kv_head_idx = cur_head_idx // kv_group_num
+    offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch
+    offsets_k = (
+        offset_kvcache
+        + (offsets_dmodel[None, :] // x) * stride_kcsplit_x
+        + (offsets_dmodel[None, :] % x) * stride_kcd
+        + offsets_block[:, None] * stride_kcs
     )
+    k_cur_block = tl.load(KCache + offsets_k)
     V_block_ptr = tl.make_block_ptr(
         base=VCache + offset_kvcache,
         shape=(cur_occupied_size, HEAD_DIM),
-        strides=(stride_cachebs, stride_cached),
+        strides=(stride_vcs, stride_vcd),
         offsets=(0, 0),
         block_shape=(BLOCK_SIZE, HEAD_DIM),
         order=(0, 1),
     )
-    k_cur_block = tl.load(K_block_ptr)
     v_cur_block = tl.load(V_block_ptr)
     acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
     # use block size of the paged/blocked kv cache
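To sanity-check the pointer arithmetic in the hunk above: with the last three K-cache strides, element (s, d) of a block is reached via (d // x) * stride_kcsplit_x + (d % x) * stride_kcd + s * stride_kcs, and the old layout falls out of the same formula when x equals head_dim and stride_kcsplit_x is 0. A small host-side sketch (illustrative names and sizes, not part of the kernel):

import torch

block_size, head_dim, x = 16, 128, 8  # illustrative sizes

# One K block in the new layout: [head_dim // x, block_size, x]
k_block = torch.randn(head_dim // x, block_size, x)
stride_split_x, stride_s, stride_d = k_block.stride()

s, d = 3, 37
offset = (d // x) * stride_split_x + (d % x) * stride_d + s * stride_s
assert torch.equal(k_block.reshape(-1)[offset], k_block[d // x, s, d % x])

# Old layout [block_size, head_dim]: same formula with x = head_dim and a zero split-x stride.
k_block_old = torch.randn(block_size, head_dim)
stride_s_old, stride_d_old = k_block_old.stride()
offset_old = (d % head_dim) * stride_d_old + s * stride_s_old
assert torch.equal(k_block_old.reshape(-1)[offset_old], k_block_old[s, d])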
@@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel(
     # Refer to https://github.com/openai/triton/discussions/895
     S_ij += tl.sum(q[None, :] * k_cur_block, 1)
     S_ij *= sm_scale
-    S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
+    S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf"))

     m = tl.max(S_ij, 0)
     S_ij -= m
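The tl.where line above adds 0 to in-range positions and -inf to positions past the end of the sequence, so they drop out after the exp in the softmax. A toy illustration of the same masking trick (plain PyTorch, not the kernel itself):

import torch

scores = torch.tensor([2.0, 1.0, 3.0, 0.5])  # raw scores for one KV block
positions = torch.arange(4)                  # absolute KV positions of this block
kv_seq_len = 3                               # only the first 3 positions are valid

mask = torch.where(positions < kv_seq_len, torch.tensor(0.0), torch.tensor(float("-inf")))
weights = torch.softmax(scores + mask, dim=0)
# weights[3] == 0: the out-of-range position contributes nothing to the output.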
@@ -324,6 +330,7 @@ def flash_decoding_attention(
     sm_scale: int = None,
     kv_group_num: int = 1,
     q_len: int = 1,  # NOTE alibi flash decoding does not support q_len > 1 at this moment.
+    use_new_kcache_layout: bool = False,
 ):
     """
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -349,6 +356,7 @@ def flash_decoding_attention(
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
         q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
             Defaults to 1.
+        use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False.

     Returns:
         Output tensor with shape [bsz * q_len, num_heads * head_dim]
@@ -400,13 +408,20 @@ def flash_decoding_attention(

     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
-    grid = (
+    grid = lambda META: (
         triton.next_power_of_2(bsz * q_len),
         num_heads,
-        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
+        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]),
     )

     if alibi_slopes is not None:
+        # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
+        # the code (alibi kernel) will be refactored later to avoid code duplication, when
+        # the whole triton flow with new k cache layout has been supported and tested.
+        assert (
+            not use_new_kcache_layout
+        ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
+
         _alibi_flash_decoding_fwd_kernel[grid](
             q,
             k_cache,
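Turning grid into a callable lets the launch geometry be computed from the kernel's meta-parameters (here META["BLOCK_KV"]) rather than a value fixed on the host. A minimal sketch of that Triton idiom with a placeholder kernel (not one of the kernels in this file):

import torch
import triton
import triton.language as tl

@triton.jit
def _add_one_kernel(x_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    tl.store(x_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) + 1.0, mask=mask)

x = torch.zeros(1000, device="cuda")
n = x.numel()
# The grid callable is evaluated with the launch meta-parameters, so the number
# of program instances always matches whichever BLOCK the launch actually uses.
grid = lambda META: (triton.cdiv(n, META["BLOCK"]),)
_add_one_kernel[grid](x, n, BLOCK=128)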
@@ -441,6 +456,19 @@ def flash_decoding_attention(
             HEAD_DIM=head_dim,
         )
     else:
+        # For KCache and VCache with the same layout
+        x = head_dim
+        kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
+        # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
+        if use_new_kcache_layout:
+            assert (
+                k_cache.dim() == 5
+                and k_cache.shape[1] == v_cache.shape[1]
+                and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
+            ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
+            x = k_cache.size(-1)
+            kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]
+
         _flash_decoding_fwd_kernel[grid](
             q,
             k_cache,
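Concretely, for contiguous caches the strides unpacked above work out to (0, head_dim, 1) in the shared layout and (block_size * x, x, 1) in the new layout, which is what the kernel then consumes. A small sketch with illustrative shapes (not code from this PR):

import torch

num_blocks, num_kv_heads, block_size, head_dim, x = 4, 2, 16, 128, 8

# Shared/old layout: x degenerates to head_dim and the split-x stride is unused (0).
k_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_dim)
old_strides = (0, k_cache.stride(2), k_cache.stride(3))       # (0, 128, 1)

# New layout: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
k_cache_new = torch.randn(num_blocks, num_kv_heads, head_dim // x, block_size, x)
new_strides = tuple(k_cache_new.stride()[-3:])                # (128, 8, 1)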
@@ -451,13 +479,21 @@ def flash_decoding_attention(
             kv_seq_len,
             q_len,
             bsz,
+            kv_group_num,
+            x,
+            sm_scale,
             q.stride(0),
             q.stride(1),
             q.stride(2),
             k_cache.stride(0),
             k_cache.stride(1),
-            k_cache.stride(2),
-            k_cache.stride(3),
+            kcsplit_x_stride,
+            kcs_stride,
+            kcd_stride,
+            v_cache.stride(0),
+            v_cache.stride(1),
+            v_cache.stride(2),
+            v_cache.stride(3),
             block_tables.stride(0),
             block_tables.stride(1),
             mid_output.stride(0),
@@ -467,8 +503,6 @@ def flash_decoding_attention(
             mid_output_lse.stride(0),
             mid_output_lse.stride(1),
             mid_output_lse.stride(2),
-            sm_scale,
-            KV_GROUPS=kv_group_num,
             BLOCK_KV=block_size,
             BLOCK_SIZE=block_size,
             HEAD_DIM=head_dim,
