@@ -59,6 +59,8 @@ def _fwd_kernel(
     Q,
     K,
     V,
+    KV_block_indices,
+    KV_block_mask,
     Out,
     M,
     Lse,
@@ -87,6 +89,7 @@ def _fwd_kernel(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
@@ -243,15 +246,18 @@ def flash_attn_forward(
     q,
     k,
     v,
-    indices,
-    mask,
+    kv_block_indices,
+    kv_block_mask,
     block_size = 128
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
 
+    num_selected_fine_blocks = kv_block_indices.shape[-1]
+    assert kv_block_indices.shape == kv_block_mask.shape
+
     assert k.shape == (batch, seqlen_k, nheads, dim)
     assert v.shape == (batch, seqlen_k, nheads, dim)
     assert dim <= 128, "only support head dimensions up to 128"
@@ -277,6 +283,8 @@ def flash_attn_forward(
         q,
         k,
         v,
+        kv_block_indices,
+        kv_block_mask,
         o,
         m,
         lse,
@@ -302,6 +310,7 @@ def flash_attn_forward(
         seqlen_k // 32,
         BLOCK_HEADDIM,
         BLOCK = block_size,
+        NUM_SEL_KV_BLOCKS = num_selected_fine_blocks,
         num_warps = num_warps,
         num_stages = 1,
     )
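
For orientation, here is a hedged sketch of how the updated `flash_attn_forward` entry point might be invoked once this change lands. The exact layout of `kv_block_indices` / `kv_block_mask` is an assumption: the hunks above only assert that the two tensors share a shape and take the number of selected fine blocks from the last dimension, so the (batch, heads, seq, num_sel) ordering below is illustrative, as are the tensor sizes and the device/dtype choices.

```python
import torch

# hypothetical shapes -- only the last dimension (number of selected fine blocks)
# and the shape equality between indices and mask are asserted by flash_attn_forward
batch, seqlen, nheads, dim = 1, 1024, 8, 64
block_size = 128
num_sel = 4

q, k, v = (
    torch.randn(batch, seqlen, nheads, dim, device = 'cuda', dtype = torch.float16)
    for _ in range(3)
)

# one set of selected KV block indices per query position (assumed layout),
# plus a boolean mask marking which of those selections are valid
kv_block_indices = torch.randint(0, seqlen // block_size, (batch, nheads, seqlen, num_sel), device = 'cuda')
kv_block_mask = torch.ones_like(kv_block_indices).bool()

# the return signature is not visible in these hunks; o / m / lse are allocated
# inside flash_attn_forward and handed to the Triton kernel
out = flash_attn_forward(
    q, k, v,
    kv_block_indices,
    kv_block_mask,
    block_size = block_size
)
```
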
@@ -398,6 +407,8 @@ def _bwd_kernel_one_col_block(
     Q,
     K,
     V,
+    kv_block_indices,
+    kv_block_mask,
     DO,
     DQ,
     DK,
@@ -421,6 +432,7 @@ def _bwd_kernel_one_col_block(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
     begin_m = ((start_n * BLOCK) // BLOCK) * BLOCK
@@ -654,6 +666,8 @@ def _bwd_kernel(
     Q,
     K,
     V,
+    kv_block_indices,
+    kv_block_mask,
     DO,
     DQ,
     DK,
@@ -695,6 +709,7 @@ def _bwd_kernel(
     EVEN_N: tl.constexpr,
     EVEN_HEADDIM: tl.constexpr,
     BLOCK: tl.constexpr,
+    NUM_SEL_KV_BLOCKS: tl.constexpr
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // nheads
@@ -718,6 +733,8 @@ def _bwd_kernel(
                 Q,
                 K,
                 V,
+                kv_block_indices,
+                kv_block_mask,
                 DO,
                 DQ,
                 DK,
@@ -735,12 +752,13 @@ def _bwd_kernel(
                 seqlen_q,
                 seqlen_k,
                 headdim,
-                ATOMIC_ADD = False,
-                BLOCK_HEADDIM = BLOCK_HEADDIM,
-                EVEN_M = EVEN_M,
-                EVEN_N = EVEN_N,
-                EVEN_HEADDIM = EVEN_HEADDIM,
-                BLOCK = BLOCK,
+                ATOMIC_ADD = False,
+                BLOCK_HEADDIM = BLOCK_HEADDIM,
+                EVEN_M = EVEN_M,
+                EVEN_N = EVEN_N,
+                EVEN_HEADDIM = EVEN_HEADDIM,
+                BLOCK = BLOCK,
+                NUM_SEL_KV_BLOCKS = NUM_SEL_KV_BLOCKS
             )
     else:
         start_n = tl.program_id(0)
@@ -749,6 +767,8 @@ def _bwd_kernel(
             Q,
             K,
             V,
+            kv_block_indices,
+            kv_block_mask,
             DO,
             DQ,
             DK,
@@ -766,19 +786,20 @@ def _bwd_kernel(
             seqlen_q,
             seqlen_k,
             headdim,
-            ATOMIC_ADD = True,
-            BLOCK_HEADDIM = BLOCK_HEADDIM,
-            EVEN_M = EVEN_M,
-            EVEN_N = EVEN_N,
-            EVEN_HEADDIM = EVEN_HEADDIM,
-            BLOCK = BLOCK,
+            ATOMIC_ADD = True,
+            BLOCK_HEADDIM = BLOCK_HEADDIM,
+            EVEN_M = EVEN_M,
+            EVEN_N = EVEN_N,
+            EVEN_HEADDIM = EVEN_HEADDIM,
+            BLOCK = BLOCK,
+            NUM_SEL_KV_BLOCKS = NUM_SEL_KV_BLOCKS
         )
 
 def flash_attn_backward(
     do,
     q, k, v,
-    indices,
-    mask,
+    kv_block_indices,
+    kv_block_mask,
     o,
     lse,
     dq, dk, dv,
@@ -790,6 +811,10 @@ def flash_attn_backward(
 
     batch, seqlen_q, nheads, dim = q.shape
     _, seqlen_k, _, _ = k.shape
+
+    num_sel_fine_blocks = kv_block_indices.shape[-1]
+    assert kv_block_indices.shape == kv_block_mask.shape
+
     # assert d in {16, 32, 64, 128}
     assert dim <= 128
     seqlen_q_rounded = round_up_multiple(seqlen_q, TRITON_BLOCK_SIZE)
@@ -834,6 +859,8 @@ def flash_attn_backward(
         q,
         k,
         v,
+        kv_block_indices,
+        kv_block_mask,
         do,
         dq_accum,
         dk,
@@ -873,6 +900,7 @@ def flash_attn_backward(
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
         BLOCK_HEADDIM,
         BLOCK = block_size,
+        NUM_SEL_KV_BLOCKS = num_sel_fine_blocks,
         SEQUENCE_PARALLEL = False,
         EVEN_M = (seqlen_q % block_size) == 0,
         EVEN_N = (seqlen_k % block_size) == 0,
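
To make the semantics of the new tensors concrete, below is a rough, unoptimized PyTorch reference of what the selected-KV-block path presumably computes: each query attends only to the fine blocks named by `kv_block_indices`, and `kv_block_mask` knocks out invalid selections. The helper name, the (batch, heads, seq, num_sel) index layout, and the omission of any causal / compressed / sliding-window components are assumptions for illustration; the kernel-internal changes themselves are not shown in these hunks.

```python
import torch

def selected_block_attention_reference(
    q, k, v,               # (batch, seqlen, heads, dim), matching flash_attn_forward
    kv_block_indices,      # (batch, heads, seqlen, num_sel) -- assumed layout
    kv_block_mask,         # same shape, bool; which selections are valid
    block_size = 128
):
    batch, seqlen, heads, dim = q.shape
    assert kv_block_indices.shape == kv_block_mask.shape
    assert seqlen % block_size == 0

    scale = dim ** -0.5
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))         # (b, h, n, d)

    # split keys / values into fine blocks of size `block_size`
    kb = k.unflatten(2, (-1, block_size))                    # (b, h, nb, bs, d)
    vb = v.unflatten(2, (-1, block_size))

    # gather, for every query position, its num_sel selected blocks
    idx = kv_block_indices[..., None, None].expand(-1, -1, -1, -1, block_size, dim)
    kb = kb.unsqueeze(2).expand(-1, -1, seqlen, -1, -1, -1)  # (b, h, n, nb, bs, d)
    vb = vb.unsqueeze(2).expand(-1, -1, seqlen, -1, -1, -1)
    sel_k = kb.gather(3, idx)                                # (b, h, n, sel, bs, d)
    sel_v = vb.gather(3, idx)

    # attention restricted to the selected blocks, invalid selections masked out
    sim = torch.einsum('b h i d, b h i s j d -> b h i s j', q * scale, sel_k)
    sim = sim.masked_fill(~kv_block_mask[..., None], torch.finfo(sim.dtype).min)
    attn = sim.flatten(-2).softmax(dim = -1)                 # (b, h, n, sel * bs)
    out = torch.einsum('b h i j, b h i j d -> b h i d', attn, sel_v.flatten(-3, -2))
    return out.transpose(1, 2)                               # back to (b, n, h, d)
```
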