@@ -59,8 +59,8 @@ def forward_kernel(
5959 Q ,
6060 K ,
6161 V ,
62- KV_block_indices ,
63- KV_block_mask ,
62+ kv_block_indices ,
63+ kv_block_mask ,
6464 Out ,
6565 M ,
6666 Lse ,
@@ -219,14 +219,14 @@ def forward_kernel(
219219 # take care of the selected kv blocks
220220
221221 kv_block_indices_ptrs = (
222- KV_block_indices +
222+ kv_block_indices +
223223 off_b * stride_kvbl_b +
224224 off_h * stride_kvbl_h +
225225 offs_m * stride_kvbl_m
226226 )
227227
228228 kv_block_mask_ptrs = (
229- KV_block_mask +
229+ kv_block_mask +
230230 off_b * stride_kvbl_b +
231231 off_h * stride_kvbl_h +
232232 offs_m * stride_kvbl_m
@@ -503,8 +503,6 @@ def backward_kernel_one_col_block(
503503 stride_dqm ,
504504 stride_dkn ,
505505 stride_dvn ,
506- stride_kvbl_b ,
507- stride_kvbl_h ,
508506 stride_kvbl_m ,
509507 seqlen_q ,
510508 seqlen_k ,
@@ -721,6 +719,35 @@ def backward_kernel_one_col_block(
721719 sem = 'relaxed' ,
722720 )
723721
722+        # Handle the selected KV block indices with atomic adds for now. TODO: swap the dq and dk/dv loops eventually (a fairly large refactor).
723+
724+ kv_block_indices_ptrs = (
725+ kv_block_indices +
726+ offs_m * stride_kvbl_m
727+ )
728+
729+ kv_block_mask_ptrs = (
730+ kv_block_mask +
731+ offs_m * stride_kvbl_m
732+ )
733+
734+ for off_sel_kv_block in range (NUM_SEL_KV_BLOCKS ):
735+ block_indices = tl .load (kv_block_indices_ptrs + off_sel_kv_block )
736+ block_masks = tl .load (kv_block_mask_ptrs + off_sel_kv_block )
737+
738+ blocks_offs_n = block_indices [:, None ] * BLOCK + tl .arange (0 , BLOCK )[None , :]
739+
740+ block_k_ptrs = (
741+ K + blocks_offs_n [:, :, None ] * stride_kn + offs_d [None , None , :]
742+ )
743+
744+ block_v_ptrs = (
745+ V + blocks_offs_n [:, :, None ] * stride_vn + offs_d [None , None , :]
746+ )
747+
748+ block_k = tl .load (block_k_ptrs )
749+ block_v = tl .load (block_v_ptrs )
750+
724751 # # increment pointers
725752 # dq_ptrs += BLOCK * stride_dqm
726753 # q_ptrs += BLOCK * stride_qm
@@ -730,6 +757,7 @@ def backward_kernel_one_col_block(
730757
731758 dv_ptrs = DV + (offs_n [:, None ] * stride_dvn + offs_d [None , :])
732759 dk_ptrs = DK + (offs_n [:, None ] * stride_dkn + offs_d [None , :])
760+
733761 backward_store_dk_dv (
734762 dk_ptrs ,
735763 dv_ptrs ,
@@ -808,6 +836,12 @@ def backward_kernel(
808836 DQ += off_b * stride_dqb + off_h * stride_dqh
809837 DK += off_b * stride_dkb + off_h * stride_dkh
810838 DV += off_b * stride_dvb + off_h * stride_dvh
839+
840+ # offset pointers for batch/head for selected kv block related
841+
842+ kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
843+ kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
844+
811845 # pointer to row-wise quantities in value-like data
812846 D += off_hb * seqlen_q_rounded
813847 LSE += off_hb * seqlen_q_rounded
@@ -836,8 +870,6 @@ def backward_kernel(
836870 stride_dqm ,
837871 stride_dkn ,
838872 stride_dvn ,
839- stride_kvbl_b ,
840- stride_kvbl_h ,
841873 stride_kvbl_m ,
842874 seqlen_q ,
843875 seqlen_k ,
@@ -873,8 +905,6 @@ def backward_kernel(
873905 stride_dqm ,
874906 stride_dkn ,
875907 stride_dvn ,
876- stride_kvbl_b ,
877- stride_kvbl_h ,
878908 stride_kvbl_m ,
879909 seqlen_q ,
880910 seqlen_k ,