Skip to content

Commit 9077c5f

Browse files
committed
setup loading of kv block indices and mask within backward_kernel_one_col_block
1 parent e3fdf25 commit 9077c5f

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def forward_kernel(
5959
Q,
6060
K,
6161
V,
62-
KV_block_indices,
63-
KV_block_mask,
62+
kv_block_indices,
63+
kv_block_mask,
6464
Out,
6565
M,
6666
Lse,
@@ -219,14 +219,14 @@ def forward_kernel(
219219
# take care of the selected kv blocks
220220

221221
kv_block_indices_ptrs = (
222-
KV_block_indices +
222+
kv_block_indices +
223223
off_b * stride_kvbl_b +
224224
off_h * stride_kvbl_h +
225225
offs_m * stride_kvbl_m
226226
)
227227

228228
kv_block_mask_ptrs = (
229-
KV_block_mask +
229+
kv_block_mask +
230230
off_b * stride_kvbl_b +
231231
off_h * stride_kvbl_h +
232232
offs_m * stride_kvbl_m
@@ -503,8 +503,6 @@ def backward_kernel_one_col_block(
503503
stride_dqm,
504504
stride_dkn,
505505
stride_dvn,
506-
stride_kvbl_b,
507-
stride_kvbl_h,
508506
stride_kvbl_m,
509507
seqlen_q,
510508
seqlen_k,
@@ -721,6 +719,35 @@ def backward_kernel_one_col_block(
721719
sem = 'relaxed',
722720
)
723721

722+
# handle the selected kv block indices using atomic adds for starters. todo: swap the dq and dk/dv loops at some point (semi-big refactor)
723+
724+
kv_block_indices_ptrs = (
725+
kv_block_indices +
726+
offs_m * stride_kvbl_m
727+
)
728+
729+
kv_block_mask_ptrs = (
730+
kv_block_mask +
731+
offs_m * stride_kvbl_m
732+
)
733+
734+
for off_sel_kv_block in range(NUM_SEL_KV_BLOCKS):
735+
block_indices = tl.load(kv_block_indices_ptrs + off_sel_kv_block)
736+
block_masks = tl.load(kv_block_mask_ptrs + off_sel_kv_block)
737+
738+
blocks_offs_n = block_indices[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
739+
740+
block_k_ptrs = (
741+
K + blocks_offs_n[:, :, None] * stride_kn + offs_d[None, None, :]
742+
)
743+
744+
block_v_ptrs = (
745+
V + blocks_offs_n[:, :, None] * stride_vn + offs_d[None, None, :]
746+
)
747+
748+
block_k = tl.load(block_k_ptrs)
749+
block_v = tl.load(block_v_ptrs)
750+
724751
# # increment pointers
725752
# dq_ptrs += BLOCK * stride_dqm
726753
# q_ptrs += BLOCK * stride_qm
@@ -730,6 +757,7 @@ def backward_kernel_one_col_block(
730757

731758
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
732759
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
760+
733761
backward_store_dk_dv(
734762
dk_ptrs,
735763
dv_ptrs,
@@ -808,6 +836,12 @@ def backward_kernel(
808836
DQ += off_b * stride_dqb + off_h * stride_dqh
809837
DK += off_b * stride_dkb + off_h * stride_dkh
810838
DV += off_b * stride_dvb + off_h * stride_dvh
839+
840+
# offset the selected kv block indices and mask pointers by batch and head
841+
842+
kv_block_indices += off_b * stride_kvbl_b + off_h * stride_kvbl_h
843+
kv_block_mask += off_b * stride_kvbl_b + off_h * stride_kvbl_h
844+
811845
# pointer to row-wise quantities in value-like data
812846
D += off_hb * seqlen_q_rounded
813847
LSE += off_hb * seqlen_q_rounded
@@ -836,8 +870,6 @@ def backward_kernel(
836870
stride_dqm,
837871
stride_dkn,
838872
stride_dvn,
839-
stride_kvbl_b,
840-
stride_kvbl_h,
841873
stride_kvbl_m,
842874
seqlen_q,
843875
seqlen_k,
@@ -873,8 +905,6 @@ def backward_kernel(
873905
stride_dqm,
874906
stride_dkn,
875907
stride_dvn,
876-
stride_kvbl_b,
877-
stride_kvbl_h,
878908
stride_kvbl_m,
879909
seqlen_q,
880910
seqlen_k,

0 commit comments

Comments
 (0)