@@ -55,7 +55,7 @@ def is_contiguous(x: Tensor):
5555 }
5656)
5757@triton .jit
58- def _fwd_kernel (
58+ def forward_kernel (
5959 Q ,
6060 K ,
6161 V ,
@@ -356,7 +356,7 @@ def flash_attn_forward(
356356 num_warps = 4 if dim <= 64 else 8
357357 grid = lambda META : (triton .cdiv (seqlen_q , META ["BLOCK" ]), batch * nheads )
358358
359- _fwd_kernel [grid ](
359+ forward_kernel [grid ](
360360 q ,
361361 k ,
362362 v ,
@@ -398,7 +398,7 @@ def flash_attn_forward(
398398 return o , lse
399399
400400@triton .jit
401- def _bwd_preprocess_do_o_dot (
401+ def backward_preprocess_do_o_dot (
402402 Out ,
403403 DO ,
404404 Delta ,
@@ -450,7 +450,7 @@ def _bwd_preprocess_do_o_dot(
450450 tl .store (Delta + off_hb * seqlen_q_rounded + offs_m , delta )
451451
452452@triton .jit
453- def _bwd_store_dk_dv (
453+ def backward_store_dk_dv (
454454 dk_ptrs ,
455455 dv_ptrs ,
456456 dk ,
@@ -482,7 +482,7 @@ def _bwd_store_dk_dv(
482482
483483
484484@triton .jit
485- def _bwd_kernel_one_col_block (
485+ def backward_kernel_one_col_block (
486486 start_n ,
487487 Q ,
488488 K ,
@@ -503,6 +503,9 @@ def _bwd_kernel_one_col_block(
503503 stride_dqm ,
504504 stride_dkn ,
505505 stride_dvn ,
506+ stride_kvbl_b ,
507+ stride_kvbl_h ,
508+ stride_kvbl_m ,
506509 seqlen_q ,
507510 seqlen_k ,
508511 headdim ,
@@ -538,7 +541,7 @@ def _bwd_kernel_one_col_block(
538541 if begin_m >= seqlen_q :
539542 dv_ptrs = DV + (offs_n [:, None ] * stride_dvn + offs_d [None , :])
540543 dk_ptrs = DK + (offs_n [:, None ] * stride_dkn + offs_d [None , :])
541- _bwd_store_dk_dv (
544+ backward_store_dk_dv (
542545 dk_ptrs ,
543546 dv_ptrs ,
544547 dk ,
@@ -727,7 +730,7 @@ def _bwd_kernel_one_col_block(
727730
728731 dv_ptrs = DV + (offs_n [:, None ] * stride_dvn + offs_d [None , :])
729732 dk_ptrs = DK + (offs_n [:, None ] * stride_dkn + offs_d [None , :])
730- _bwd_store_dk_dv (
733+ backward_store_dk_dv (
731734 dk_ptrs ,
732735 dv_ptrs ,
733736 dk ,
@@ -742,7 +745,7 @@ def _bwd_kernel_one_col_block(
742745 )
743746
744747@triton .jit
745- def _bwd_kernel (
748+ def backward_kernel (
746749 Q ,
747750 K ,
748751 V ,
@@ -776,6 +779,9 @@ def _bwd_kernel(
776779 stride_dvb ,
777780 stride_dvh ,
778781 stride_dvn ,
782+ stride_kvbl_b ,
783+ stride_kvbl_h ,
784+ stride_kvbl_m ,
779785 nheads ,
780786 seqlen_q ,
781787 seqlen_k ,
@@ -805,10 +811,11 @@ def _bwd_kernel(
805811 # pointer to row-wise quantities in value-like data
806812 D += off_hb * seqlen_q_rounded
807813 LSE += off_hb * seqlen_q_rounded
814+
808815 if not SEQUENCE_PARALLEL :
809816 num_block_n = tl .cdiv (seqlen_k , BLOCK )
810817 for start_n in range (0 , num_block_n ):
811- _bwd_kernel_one_col_block (
818+ backward_kernel_one_col_block (
812819 start_n ,
813820 Q ,
814821 K ,
@@ -829,6 +836,9 @@ def _bwd_kernel(
829836 stride_dqm ,
830837 stride_dkn ,
831838 stride_dvn ,
839+ stride_kvbl_b ,
840+ stride_kvbl_h ,
841+ stride_kvbl_m ,
832842 seqlen_q ,
833843 seqlen_k ,
834844 headdim ,
@@ -842,7 +852,7 @@ def _bwd_kernel(
842852 )
843853 else :
844854 start_n = tl .program_id (0 )
845- _bwd_kernel_one_col_block (
855+ backward_kernel_one_col_block (
846856 start_n ,
847857 Q ,
848858 K ,
@@ -863,6 +873,9 @@ def _bwd_kernel(
863873 stride_dqm ,
864874 stride_dkn ,
865875 stride_dvn ,
876+ stride_kvbl_b ,
877+ stride_kvbl_h ,
878+ stride_kvbl_m ,
866879 seqlen_q ,
867880 seqlen_k ,
868881 headdim ,
@@ -913,7 +926,7 @@ def flash_attn_backward(
913926 delta = torch .empty_like (lse )
914927 grid = lambda META : (triton .cdiv (seqlen_q , META ["BLOCK" ]), batch * nheads )
915928
916- _bwd_preprocess_do_o_dot [grid ](
929+ backward_preprocess_do_o_dot [grid ](
917930 o ,
918931 do ,
919932 delta ,
@@ -935,7 +948,7 @@ def flash_attn_backward(
935948 triton .cdiv (seqlen_k , META ["BLOCK" ]) if META ["SEQUENCE_PARALLEL" ] else 1 ,
936949 batch * nheads ,
937950 )
938- _bwd_kernel [grid ](
951+ backward_kernel [grid ](
939952 q ,
940953 k ,
941954 v ,
@@ -969,6 +982,9 @@ def flash_attn_backward(
969982 dv .stride (0 ),
970983 dv .stride (2 ),
971984 dv .stride (1 ),
985+ kv_block_indices .stride (0 ),
986+ kv_block_indices .stride (2 ),
987+ kv_block_indices .stride (1 ),
972988 nheads ,
973989 seqlen_q ,
974990 seqlen_k ,
0 commit comments