11
11
def _flash_decoding_fwd_kernel (
12
12
Q , # [batch_size * q_len, head_num, head_dim]
13
13
KCache , # [num_blocks, num_kv_heads, block_size, head_dim]
14
- VCache , # [num_blocks, num_kv_heads, block_size, head_dim]
14
+ VCache , # [num_blocks, num_kv_heads, block_size, head_dim],
15
+ # or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided
15
16
block_tables , # [batch_size, max_blocks_per_sequence]
16
17
mid_o , # [batch_size * q_len, head_num, kv_split_num, head_dim]
17
18
mid_o_lse , # [batch_size * q_len, head_num, kv_split_num]
18
19
kv_seq_len , # [batch_size]
19
20
q_len ,
20
21
batch_size ,
22
+ kv_group_num ,
23
+ x ,
24
+ sm_scale ,
21
25
stride_qt ,
22
26
stride_qh ,
23
27
stride_qd ,
24
- stride_cacheb ,
25
- stride_cacheh ,
26
- stride_cachebs ,
27
- stride_cached ,
28
+ stride_kcb ,
29
+ stride_kch ,
30
+ stride_kcsplit_x ,
31
+ stride_kcs ,
32
+ stride_kcd ,
33
+ stride_vcb ,
34
+ stride_vch ,
35
+ stride_vcs ,
36
+ stride_vcd ,
28
37
stride_bts ,
29
38
stride_btb ,
30
39
stride_mid_ot ,
@@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel(
34
43
stride_mid_o_lset ,
35
44
stride_mid_o_lseh ,
36
45
stride_mid_o_lseb ,
37
- sm_scale ,
38
- KV_GROUPS : tl .constexpr ,
39
46
BLOCK_KV : tl .constexpr ,
40
47
BLOCK_SIZE : tl .constexpr ,
41
48
HEAD_DIM : tl .constexpr ,
@@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel(
57
64
cur_kv_seq_len = tl .load (kv_seq_len + cur_seq_idx ) + cur_token_off
58
65
if block_start_kv * BLOCK_KV >= cur_kv_seq_len :
59
66
return
60
-
61
67
offsets_dmodel = tl .arange (0 , HEAD_DIM )
62
- offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
63
- q = tl . load ( Q + offsets_q )
68
+ offsets_block = tl . arange ( 0 , BLOCK_SIZE )
69
+
64
70
# block table for the current sequence
65
71
block_table_ptr = block_tables + cur_seq_idx * stride_bts
66
72
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
@@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel(
71
77
)
72
78
tl .device_assert (cur_occupied_size >= 0 )
73
79
74
- cur_kv_head_idx = cur_head_idx // KV_GROUPS
75
- offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
76
- K_block_ptr = tl . make_block_ptr (
77
- base = KCache + offset_kvcache ,
78
- shape = ( cur_occupied_size , HEAD_DIM ),
79
- strides = ( stride_cachebs , stride_cached ),
80
- offsets = ( 0 , 0 ),
81
- block_shape = ( BLOCK_SIZE , HEAD_DIM ),
82
- order = ( 0 , 1 ),
80
+ offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
81
+ q = tl . load ( Q + offsets_q )
82
+ cur_kv_head_idx = cur_head_idx // kv_group_num
83
+ offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch
84
+ offsets_k = (
85
+ offset_kvcache
86
+ + ( offsets_dmodel [ None , :] // x ) * stride_kcsplit_x
87
+ + ( offsets_dmodel [ None , :] % x ) * stride_kcd
88
+ + offsets_block [:, None ] * stride_kcs
83
89
)
90
+ k_cur_block = tl .load (KCache + offsets_k )
84
91
V_block_ptr = tl .make_block_ptr (
85
92
base = VCache + offset_kvcache ,
86
93
shape = (cur_occupied_size , HEAD_DIM ),
87
- strides = (stride_cachebs , stride_cached ),
94
+ strides = (stride_vcs , stride_vcd ),
88
95
offsets = (0 , 0 ),
89
96
block_shape = (BLOCK_SIZE , HEAD_DIM ),
90
97
order = (0 , 1 ),
91
98
)
92
- k_cur_block = tl .load (K_block_ptr )
93
99
v_cur_block = tl .load (V_block_ptr )
94
100
acc = tl .zeros ([HEAD_DIM ], dtype = tl .float32 )
95
101
# use block size of the paged/blocked kv cache
@@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel(
100
106
# Refer to https://github.com/openai/triton/discussions/895
101
107
S_ij += tl .sum (q [None , :] * k_cur_block , 1 )
102
108
S_ij *= sm_scale
103
- S_ij += tl .where (block_start_kv * BLOCK_KV + tl . arange ( 0 , BLOCK_SIZE ) < cur_kv_seq_len , 0 , float ("-inf" ))
109
+ S_ij += tl .where (block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len , 0 , float ("-inf" ))
104
110
105
111
m = tl .max (S_ij , 0 )
106
112
S_ij -= m
@@ -324,6 +330,7 @@ def flash_decoding_attention(
324
330
sm_scale : int = None ,
325
331
kv_group_num : int = 1 ,
326
332
q_len : int = 1 , # NOTE alibi flash decoding does not support q_len > 1 at this moment.
333
+ use_new_kcache_layout : bool = False ,
327
334
):
328
335
"""
329
336
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -349,6 +356,7 @@ def flash_decoding_attention(
349
356
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
350
357
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
351
358
Defaults to 1.
359
+ use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False.
352
360
353
361
Returns:
354
362
Output tensor with shape [bsz * q_len, num_heads * head_dim]
@@ -400,13 +408,20 @@ def flash_decoding_attention(
400
408
401
409
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
402
410
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
403
- grid = (
411
+ grid = lambda META : (
404
412
triton .next_power_of_2 (bsz * q_len ),
405
413
num_heads ,
406
- triton .cdiv (triton .next_power_of_2 (max_seq_len_in_batch ), BLOCK_KV ),
414
+ triton .cdiv (triton .next_power_of_2 (max_seq_len_in_batch ), META [ " BLOCK_KV" ] ),
407
415
)
408
416
409
417
if alibi_slopes is not None :
418
+ # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
419
+ # the code (alibi kernel) will be refactored later to avoid code duplication, when
420
+ # the whole triton flow with new k cache layout has been supported and tested.
421
+ assert (
422
+ not use_new_kcache_layout
423
+ ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
424
+
410
425
_alibi_flash_decoding_fwd_kernel [grid ](
411
426
q ,
412
427
k_cache ,
@@ -441,6 +456,19 @@ def flash_decoding_attention(
441
456
HEAD_DIM = head_dim ,
442
457
)
443
458
else :
459
+ # For KCache and VCache with the same layout
460
+ x = head_dim
461
+ kcsplit_x_stride , kcs_stride , kcd_stride = 0 , k_cache .stride (2 ), k_cache .stride (3 )
462
+ # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
463
+ if use_new_kcache_layout :
464
+ assert (
465
+ k_cache .dim () == 5
466
+ and k_cache .shape [1 ] == v_cache .shape [1 ]
467
+ and k_cache .shape [2 ] * k_cache .shape [4 ] == v_cache .shape [3 ]
468
+ ), f"Invalid KCache shape { k_cache .shape } and VCache shape { v_cache .shape } "
469
+ x = k_cache .size (- 1 )
470
+ kcsplit_x_stride , kcs_stride , kcd_stride = k_cache .stride ()[- 3 :]
471
+
444
472
_flash_decoding_fwd_kernel [grid ](
445
473
q ,
446
474
k_cache ,
@@ -451,13 +479,21 @@ def flash_decoding_attention(
451
479
kv_seq_len ,
452
480
q_len ,
453
481
bsz ,
482
+ kv_group_num ,
483
+ x ,
484
+ sm_scale ,
454
485
q .stride (0 ),
455
486
q .stride (1 ),
456
487
q .stride (2 ),
457
488
k_cache .stride (0 ),
458
489
k_cache .stride (1 ),
459
- k_cache .stride (2 ),
460
- k_cache .stride (3 ),
490
+ kcsplit_x_stride ,
491
+ kcs_stride ,
492
+ kcd_stride ,
493
+ v_cache .stride (0 ),
494
+ v_cache .stride (1 ),
495
+ v_cache .stride (2 ),
496
+ v_cache .stride (3 ),
461
497
block_tables .stride (0 ),
462
498
block_tables .stride (1 ),
463
499
mid_output .stride (0 ),
@@ -467,8 +503,6 @@ def flash_decoding_attention(
467
503
mid_output_lse .stride (0 ),
468
504
mid_output_lse .stride (1 ),
469
505
mid_output_lse .stride (2 ),
470
- sm_scale ,
471
- KV_GROUPS = kv_group_num ,
472
506
BLOCK_KV = block_size ,
473
507
BLOCK_SIZE = block_size ,
474
508
HEAD_DIM = head_dim ,
0 commit comments