Commit 4ac117a

pass in the fine attention mask to the triton nsa
1 parent 19d0e66 commit 4ac117a

File tree

3 files changed (+12, -15 lines)

native_sparse_attention_pytorch/native_sparse_attention.py
native_sparse_attention_pytorch/triton_native_sparse_attention.py
test_triton_nsa.py

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 3 additions & 0 deletions
@@ -444,10 +444,13 @@ def forward(
         if self.use_triton_kernel:
             from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
 
+            fmask = selected_importance_values > 1e-10
+
             fine_attn_out = native_sparse_attend(
                 fq, fk, fv,
                 self.selection_block_size,
                 selected_block_indices,
+                fmask,
                 fine_num_grouped_queries
             )

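For context, a minimal sketch (not the repository's code) of what the new fmask amounts to: thresholding the per-block importance scores at 1e-10 yields a boolean mask marking which selected blocks actually carry weight, so zero-importance (e.g. padded) selections can be skipped by the kernel. The shapes below are assumptions for illustration only.

import torch

batch, heads, seq_len, num_selected = 1, 2, 8, 4

# per-query importance of each selected block (assumed shape, zeros ~ unused slots)
selected_importance_values = torch.rand(batch, heads, seq_len, num_selected)
selected_importance_values[..., -1] = 0.  # pretend the last selection slot was padding

# blocks with (near-)zero importance are excluded from the fine attention
fmask = selected_importance_values > 1e-10

print(fmask.shape)           # torch.Size([1, 2, 8, 4])
print(fmask[..., -1].any())  # tensor(False) -> padded slots masked out
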
native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 8 additions & 14 deletions
@@ -1,7 +1,5 @@
 # taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
 # with fixes for triton 2.3
-# forward is modified to return unnormalized accumulation, row maxes, row lse - reduced over passed rings
-# both forwards and backwards is modified to allow for masking out the diagonal for striped ring attention
 
 from functools import partial
 import math
@@ -24,6 +22,8 @@ def round_up_multiple(n, mult):
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1
 
+TRITON_BLOCK_SIZE = 128 # some block size that allows triton not to break, at least half a year ago
+
 INSTALL_COMMAND = 'pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly'
 
 # make sure triton 2.1+ is installed
@@ -45,10 +45,6 @@ def is_contiguous(x: Tensor):
 import triton.language as tl
 from triton.language.extra import libdevice
 
-# constants
-
-TRITON_BLOCK_SIZE = 128
-
 # kernels
 
 @triton.heuristics(
@@ -784,8 +780,8 @@ def flash_attn_backward(
     assert all([is_contiguous(t) for t in (q, k, v, o, dq, dk, dv)])
 
     softmax_scale = dim ** -0.5
-    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
-    dq_accum = torch.empty_like(q, dtype=torch.float32)
+
+    dq_accum = torch.empty_like(q, dtype = torch.float32)
 
     # delta = torch.zeros_like(lse)
 
@@ -812,9 +808,6 @@ def flash_attn_backward(
         BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    # BLOCK_M = 128
-    # BLOCK_N = 64
-    # num_warps = 4
     grid = lambda META: (
         triton.cdiv(seqlen_k, META["BLOCK"]) if META["SEQUENCE_PARALLEL"] else 1,
         batch * nheads,
@@ -887,6 +880,7 @@ def forward(
         fq, fk, fv,
         block_size,
         selected_block_indices,
+        fmask,
         num_grouped_queries
     ):
         fq, fk, fv = tuple(rearrange(t, 'b h n d -> b n h d') for t in (fq, fk, fv))
@@ -897,7 +891,7 @@ def forward(
 
         out, lse = flash_attn_forward(fq, fk, fv, block_size = block_size)
 
-        ctx.save_for_backward(fq, fk, fv, out, lse)
+        ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
         ctx._saved_variables = (block_size,)
 
         out = rearrange(out, 'b n h d -> b h n d')
@@ -907,7 +901,7 @@ def forward(
     def backward(self, ctx, do):
         do = rearrange(do, 'b h n d -> b n h d')
 
-        q, k, v, out, lse = ctx.saved_tensors
+        q, k, v, kv_indices, mask, out, lse = ctx.saved_tensors
 
         (
             block_size,
@@ -921,6 +915,6 @@ def backward(self, ctx, do):
         flash_attn_backward(do, q, k, v, out, lse, dq, dk, dv, block_size = block_size)
 
         dq, dk, dv = tuple(rearrange(t, 'b n h d -> b h n d') for t in (dq, dk, dv))
-        return dq, dk, dv, None, None, None
+        return dq, dk, dv, None, None, None, None
 
 native_sparse_attend = NSA.apply

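The autograd-level changes above follow the standard torch.autograd.Function contract: any tensor the backward pass will need (here the block indices and the new fmask) goes through ctx.save_for_backward, and backward must return exactly one value per forward input, with None for non-differentiable arguments. A minimal toy sketch of that contract, not the repository's kernel; the shapes and the toy computation are placeholders:

import torch
from torch.autograd import Function

class ToyAttend(Function):
    @staticmethod
    def forward(ctx, q, k, v, block_size, indices, mask, num_grouped_queries):
        out = q + k + v  # stand-in for the fused attention computation
        # save tensors the backward pass will read, including the newly added mask
        ctx.save_for_backward(q, k, v, indices, mask)
        ctx._saved_variables = (block_size,)
        return out

    @staticmethod
    def backward(ctx, do):
        q, k, v, indices, mask = ctx.saved_tensors
        (block_size,) = ctx._saved_variables
        # one return value per forward input: grads for q, k, v,
        # then None for block_size, indices, mask, num_grouped_queries
        return do, do, do, None, None, None, None

q, k, v = (torch.randn(2, 4, requires_grad = True) for _ in range(3))
indices = torch.zeros(2, 4, dtype = torch.long)
mask = torch.ones(2, 4, dtype = torch.bool)

ToyAttend.apply(q, k, v, 64, indices, mask, 1).sum().backward()
print(q.grad.shape)  # torch.Size([2, 4])
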
test_triton_nsa.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def regular_attend(q, k, v, block_size = None):
 out = regular_attend(rq, rk, rv, block_size = 64)
 out.sum().backward()
 
-nsa_out = native_sparse_attend(nq, nk, nv, 64, None, 1)
+nsa_out = native_sparse_attend(nq, nk, nv, 64, None, None, 1)
 nsa_out.sum().backward()
 
 assert torch.allclose(out, nsa_out, atol = 1e-2)

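The test keeps its existing structure: run a plain PyTorch reference, run the kernel path (now with an extra None for the new fmask argument), call .sum().backward() on both, and compare with a loose tolerance. A generic sketch of that comparison pattern, using F.scaled_dot_product_attention as a stand-in for both paths rather than the repository's kernel:

import torch
import torch.nn.functional as F

# reference inputs, plus an independent copy for the candidate implementation
q1, k1, v1 = (torch.randn(1, 2, 64, 16, requires_grad = True) for _ in range(3))
q2, k2, v2 = (t.detach().clone().requires_grad_() for t in (q1, k1, v1))

ref  = F.scaled_dot_product_attention(q1, k1, v1)  # reference path
cand = F.scaled_dot_product_attention(q2, k2, v2)  # stand-in for native_sparse_attend

ref.sum().backward()
cand.sum().backward()

assert torch.allclose(ref, cand, atol = 1e-2)
assert torch.allclose(q1.grad, q2.grad, atol = 1e-2)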