
Commit 2c3ad83

prepare test file for fusing sliding window with NSA
1 parent 571080b

2 files changed: 50 additions & 9 deletions
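The commit message says this prepares the test for fusing a sliding window branch with NSA: in the test diff below, the reference regular_attend optionally computes a sliding-window attention output alongside its usual output and returns both, and the test sums the two tensors before the backward pass (out = sum(out)). The following is a toy, self-contained sketch of that combination in plain PyTorch; the names, masks, and sizes are illustrative stand-ins, not the repo's code.

import torch

def naive_attend(q, k, v, attn_mask):
    # simple masked softmax attention, used only to illustrate the two branches
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale
    sim = sim.masked_fill(~attn_mask, float('-inf'))
    return torch.einsum('bhij,bhjd->bhid', sim.softmax(dim = -1), v)

seq_len, window_size = 32, 8

q = torch.randn(1, 2, seq_len, 16, requires_grad = True)
k = torch.randn(1, 2, seq_len, 16)
v = torch.randn(1, 2, seq_len, 16)

i = torch.arange(seq_len)[:, None]
j = torch.arange(seq_len)[None, :]

causal_mask = j <= i                                    # stand-in for the NSA branch's mask
sliding_mask = causal_mask & ((i - j) <= window_size)   # causal sliding window

branch_out = naive_attend(q, k, v, causal_mask)         # stand-in for the NSA reference output
sliding_out = naive_attend(q, k, v, sliding_mask)       # sliding window branch

out = branch_out + sliding_out                          # "fused" output, mirroring sum(out) in the test
out.sum().backward()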

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 16 additions & 6 deletions
@@ -640,9 +640,16 @@ def backward_preprocess_do_o_dot(
     # load
 
     o = tl.load(
-        Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
-        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-        other=0.0,
+        Out +
+        off_b * stride_ob +
+        off_h * stride_oh +
+        offs_m[:, None] * stride_om +
+        offs_d[None, :],
+        mask = (
+            (offs_m[:, None] < seqlen_q) &
+            (offs_d[None, :] < headdim)
+        ),
+        other = 0.0,
     ).to(tl.float32)
 
     do = tl.load(
@@ -651,7 +658,10 @@ def backward_preprocess_do_o_dot(
         + off_h * stride_doh
         + offs_m[:, None] * stride_dom
        + offs_d[None, :],
-        mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+        mask = (
+            offs_m[:, None] < seqlen_q) &
+            (offs_d[None, :] < headdim
+        ),
         other = 0.0,
     ).to(tl.float32)
 
@@ -1189,8 +1199,8 @@ def backward_kernel_one_col_block_causal(
     # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
     do = tl.load(
         do_ptrs,
-        mask=(offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
-        other=0.0,
+        mask = (offs_m[None, :, None] < seqlen_q) & (offs_d[None, None, :] < headdim),
+        other = 0.0,
     )
 
     do = do.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK_HEADDIM)
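The reformatted tl.load calls above only change layout; the mask/other pair still guards the tail tiles where offs_m runs past seqlen_q or offs_d runs past headdim. Below is a small PyTorch sketch of that boundary-masking semantics, with illustrative tile sizes; it is not the Triton kernel itself.

import torch

# illustrative tile sizes; in the kernel these are constexpr block dimensions
BLOCK, BLOCK_HEADDIM = 16, 64
seqlen_q, headdim = 10, 48          # deliberately shorter than the tile

offs_m = torch.arange(BLOCK)
offs_d = torch.arange(BLOCK_HEADDIM)

# same predicate as the kernel's mask argument
mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)

tile = torch.randn(BLOCK, BLOCK_HEADDIM)                 # stands in for the memory the pointers cover
o = torch.where(mask, tile, torch.zeros_like(tile))      # tl.load(..., mask = mask, other = 0.0)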

test_triton_nsa.py

Lines changed: 34 additions & 3 deletions
@@ -1,6 +1,16 @@
 from math import ceil
 import torch
-from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend, round_up_multiple, pad_to_multiple
+
+from native_sparse_attention_pytorch.native_sparse_attention import (
+    create_sliding_mask,
+    flex_attention
+)
+
+from native_sparse_attention_pytorch.triton_native_sparse_attention import (
+    native_sparse_attend,
+    round_up_multiple,
+    pad_to_multiple,
+)
 
 import einx
 from einops import rearrange, einsum, repeat
@@ -10,6 +20,9 @@
 def exists(v):
     return v is not None
 
+def default(v, d):
+    return v if exists(v) else d
+
 def abs_diff(x, y):
     return (x - y).abs().amax()
 
@@ -21,12 +34,22 @@ def regular_attend(
     indices,
     mask,
     block_size,
+    sliding_window_size = None,
     sel_scale = None,
-    return_lse = False
+    return_lse = False,
+    return_sliding_window_out = False
 ):
     q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
     assert divisible_by(q_heads, kv_heads)
 
+    if return_sliding_window_out:
+        kv_seq_len = k.shape[-2]
+        assert seq_len == kv_seq_len
+
+        sliding_window_size = default(sliding_window_size, block_size)
+        sliding_mask = create_sliding_mask(kv_seq_len, sliding_window_size)
+        sliding_out = flex_attention(q, k, v, block_mask = sliding_mask, enable_gqa = True)
+
     q, k, v = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (q, k, v))
 
     if exists(sel_scale):
@@ -97,6 +120,9 @@ def regular_attend(
 
     out = out[..., :seq_len, :]
 
+    if return_sliding_window_out:
+        out = (out, sliding_out)
+
     if not return_lse:
         return out
 
@@ -114,6 +140,7 @@ def regular_attend(
 kv_heads = 2
 fine_block_size = 16
 num_sel = 6
+fused_sliding_window = False
 
 q = torch.randn(batch, q_heads, seq_len, 64).cuda()
 k = torch.randn(batch, kv_heads, seq_len, 64).cuda()
@@ -130,7 +157,11 @@ def regular_attend(
 
 # regular forwards and backwards
 
-out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, sel_scale = rsel_scale, return_lse = True)
+out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, sel_scale = rsel_scale, return_lse = True, return_sliding_window_out = fused_sliding_window)
+
+if fused_sliding_window:
+    out = sum(out)
+
 out.sum().backward()
 
 # triton nsa forwards and backwards
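The test builds its sliding-window reference with create_sliding_mask and flex_attention imported from native_sparse_attention_pytorch.native_sparse_attention. Below is a plausible sketch of such a helper on top of PyTorch's flex attention utilities, under the assumption that it wraps create_block_mask with a causal, windowed mask_mod; this is an illustration, not the package's actual source.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def create_sliding_mask(seq_len, window_size):
    # assumed behavior: causal attention restricted to a trailing window of keys
    def sliding_mask(_, __, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        within_window = (q_idx - kv_idx) <= window_size
        return causal & within_window

    return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len)

# usage mirroring the diff above (q, k, v as in the test, grouped-query heads allowed)
# sliding_out = flex_attention(q, k, v, block_mask = create_sliding_mask(k.shape[-2], 16), enable_gqa = True)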
