
Commit e3fdf25

setup strides for triton nsa backwards
1 parent 4e39402 commit e3fdf25
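
This commit threads the strides of the `kv_block_indices` tensor through to the Triton backward kernels. Triton kernels receive raw pointers rather than tensor objects, so each tensor's per-dimension strides have to be forwarded explicitly for the kernel to compute element addresses. A minimal sketch of what `Tensor.stride()` supplies, using a stand-in tensor whose layout is assumed purely for illustration:

import torch

# stand-in for kv_block_indices; the (batch, heads, seq, selected_blocks)
# layout is an assumption for this sketch, not the library's contract
indices = torch.zeros(1, 2, 512, 1, dtype = torch.long)

# Tensor.stride(d) is the jump, in elements, between consecutive positions
# along dim d, so a kernel can address an entry as
#   base + b * indices.stride(0) + h * indices.stride(1) + m * indices.stride(2)
print(indices.stride(0), indices.stride(1), indices.stride(2))  # 1024 512 1

Note in the last hunk of the first file that the call site forwards `kv_block_indices.stride(0)`, `.stride(2)` and `.stride(1)` for the `_b`, `_h`, `_m` parameters, mirroring the `(0), (2), (1)` order already used for `dv`.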

File tree: 2 files changed (+30, -14 lines)

  native_sparse_attention_pytorch/triton_native_sparse_attention.py
  test_triton_nsa.py

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 28 additions & 12 deletions
@@ -55,7 +55,7 @@ def is_contiguous(x: Tensor):
     }
 )
 @triton.jit
-def _fwd_kernel(
+def forward_kernel(
     Q,
     K,
     V,
@@ -356,7 +356,7 @@ def flash_attn_forward(
     num_warps = 4 if dim <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)

-    _fwd_kernel[grid](
+    forward_kernel[grid](
         q,
         k,
         v,
@@ -398,7 +398,7 @@ def flash_attn_forward(
     return o, lse

 @triton.jit
-def _bwd_preprocess_do_o_dot(
+def backward_preprocess_do_o_dot(
     Out,
     DO,
     Delta,
@@ -450,7 +450,7 @@ def _bwd_preprocess_do_o_dot(
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)

 @triton.jit
-def _bwd_store_dk_dv(
+def backward_store_dk_dv(
     dk_ptrs,
     dv_ptrs,
     dk,
@@ -482,7 +482,7 @@ def _bwd_store_dk_dv(


 @triton.jit
-def _bwd_kernel_one_col_block(
+def backward_kernel_one_col_block(
     start_n,
     Q,
     K,
@@ -503,6 +503,9 @@ def _bwd_kernel_one_col_block(
     stride_dqm,
     stride_dkn,
     stride_dvn,
+    stride_kvbl_b,
+    stride_kvbl_h,
+    stride_kvbl_m,
     seqlen_q,
     seqlen_k,
     headdim,
@@ -538,7 +541,7 @@ def _bwd_kernel_one_col_block(
     if begin_m >= seqlen_q:
         dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
         dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-        _bwd_store_dk_dv(
+        backward_store_dk_dv(
             dk_ptrs,
             dv_ptrs,
             dk,
@@ -727,7 +730,7 @@ def _bwd_kernel_one_col_block(

     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-    _bwd_store_dk_dv(
+    backward_store_dk_dv(
         dk_ptrs,
         dv_ptrs,
         dk,
@@ -742,7 +745,7 @@ def _bwd_kernel_one_col_block(
     )

 @triton.jit
-def _bwd_kernel(
+def backward_kernel(
     Q,
     K,
     V,
V,
@@ -776,6 +779,9 @@ def _bwd_kernel(
776779
stride_dvb,
777780
stride_dvh,
778781
stride_dvn,
782+
stride_kvbl_b,
783+
stride_kvbl_h,
784+
stride_kvbl_m,
779785
nheads,
780786
seqlen_q,
781787
seqlen_k,
@@ -805,10 +811,11 @@ def _bwd_kernel(
     # pointer to row-wise quantities in value-like data
     D += off_hb * seqlen_q_rounded
     LSE += off_hb * seqlen_q_rounded
+
     if not SEQUENCE_PARALLEL:
         num_block_n = tl.cdiv(seqlen_k, BLOCK)
         for start_n in range(0, num_block_n):
-            _bwd_kernel_one_col_block(
+            backward_kernel_one_col_block(
                 start_n,
                 Q,
                 K,
@@ -829,6 +836,9 @@ def _bwd_kernel(
                 stride_dqm,
                 stride_dkn,
                 stride_dvn,
+                stride_kvbl_b,
+                stride_kvbl_h,
+                stride_kvbl_m,
                 seqlen_q,
                 seqlen_k,
                 headdim,
@@ -842,7 +852,7 @@ def _bwd_kernel(
             )
     else:
         start_n = tl.program_id(0)
-        _bwd_kernel_one_col_block(
+        backward_kernel_one_col_block(
             start_n,
             Q,
             K,
@@ -863,6 +873,9 @@ def _bwd_kernel(
             stride_dqm,
             stride_dkn,
             stride_dvn,
+            stride_kvbl_b,
+            stride_kvbl_h,
+            stride_kvbl_m,
             seqlen_q,
             seqlen_k,
             headdim,
@@ -913,7 +926,7 @@ def flash_attn_backward(
     delta = torch.empty_like(lse)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK"]), batch * nheads)

-    _bwd_preprocess_do_o_dot[grid](
+    backward_preprocess_do_o_dot[grid](
         o,
         do,
         delta,
@@ -935,7 +948,7 @@ def flash_attn_backward(
         triton.cdiv(seqlen_k, META["BLOCK"]) if META["SEQUENCE_PARALLEL"] else 1,
         batch * nheads,
     )
-    _bwd_kernel[grid](
+    backward_kernel[grid](
         q,
         k,
         v,
@@ -969,6 +982,9 @@ def flash_attn_backward(
         dv.stride(0),
         dv.stride(2),
         dv.stride(1),
+        kv_block_indices.stride(0),
+        kv_block_indices.stride(2),
+        kv_block_indices.stride(1),
         nheads,
         seqlen_q,
         seqlen_k,
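
The three new `stride_kvbl_*` parameters are only plumbed through at this point; none of the hunks above show them being consumed yet. For intuition, a hedged sketch of how a Triton device function could use such strides to fetch per-row block indices. The function name and arguments are hypothetical, not the repository's actual kernel body:

import triton
import triton.language as tl

@triton.jit
def load_block_indices_sketch(
    KVBL,            # raw pointer to the kv block indices tensor
    stride_kvbl_b,   # stride between batches
    stride_kvbl_h,   # stride between heads
    stride_kvbl_m,   # stride between sequence positions
    off_b,           # batch handled by this program
    off_h,           # head handled by this program
    start_m,         # first query row of this tile
    BLOCK: tl.constexpr,
):
    # offset the base pointer to this (batch, head) slice
    kvbl_ptr = KVBL + off_b * stride_kvbl_b + off_h * stride_kvbl_h
    # gather one selected kv-block index per query row in the tile
    offs_m = start_m * BLOCK + tl.arange(0, BLOCK)
    return tl.load(kvbl_ptr + offs_m * stride_kvbl_m)

Passing strides instead of assuming contiguity lets the same kernel handle permuted or sliced index tensors without forcing a copy.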

test_triton_nsa.py

Lines changed: 2 additions & 2 deletions
@@ -82,8 +82,8 @@ def regular_attend(
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()

-indices = torch.zeros(1, 2, 512, 2).long().cuda()
-mask = torch.ones(1, 2, 512, 2).bool().cuda()
+indices = torch.zeros(1, 2, 512, 1).long().cuda()
+mask = torch.ones(1, 2, 512, 1).bool().cuda()

 # both regular and nsa pathways `r` and `n`

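In the test, `indices` and `mask` appear to follow `q`, `k` and `v` in using a `(batch, heads, seq_len, num_selected_blocks)` layout; that reading is inferred from the surrounding shapes rather than stated in the diff. Shrinking the last dimension from 2 to 1 exercises the simplest case, a single selected kv block per query position:

# assumed semantics: all-zero indices select kv block 0 for every query row,
# and the all-True mask keeps each selection active
indices = torch.zeros(1, 2, 512, 1).long().cuda()  # (batch, heads, seq, selected blocks)
mask = torch.ones(1, 2, 512, 1).bool().cuda()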