Commit 6606b67

more setup for testing

1 parent 420bed4

1 file changed (+27, −7 lines)

test_triton_nsa.py: 27 additions & 7 deletions
@@ -2,7 +2,7 @@
 from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend
 
 import einx
-from einops import rearrange, einsum
+from einops import rearrange, einsum, repeat
 
 assert torch.cuda.is_available()
 
@@ -33,7 +33,8 @@ def regular_attend(
 
     # rest of the indices
 
-    has_sel_kv_blocks = indices.shape[-1] > 0
+    num_sel_kv_blocks = indices.shape[-1]
+    has_sel_kv_blocks = num_sel_kv_blocks > 0
 
     if has_sel_kv_blocks:
         bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
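Note: `num_sel_kv_blocks` is carried forward so the flattened selected-block similarities can be un-flattened later in the next hunk, while `has_sel_kv_blocks` guards the whole selected-block branch so the reference path also handles zero selected blocks (exercised by the test setup further down). A minimal shape sketch of the `'b (h w) n d -> b h w n d'` rearrange used here, with toy sizes rather than the test's:

import torch
from einops import rearrange

kv_heads, num_blocks, block_size, dim = 2, 8, 64, 64

# keys laid out as 'b (h w) n d': kv heads and fine blocks flattened into one axis
k = torch.randn(1, kv_heads * num_blocks, block_size, dim)

# split that axis back into (head, block) so blocks can be picked out per query block
bk = rearrange(k, 'b (h w) n d -> b h w n d', h = kv_heads)
print(bk.shape)   # torch.Size([1, 2, 8, 64, 64])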
@@ -43,14 +44,33 @@ def regular_attend(
         q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
         bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j') * scale
 
+        bsim = rearrange(bsim, 'b h (w i) (sel j) -> b h w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)
+
+        mask = rearrange(mask, 'b h (w i) sel -> b h w i sel', i = fine_block_size)
+        bsim = torch.where(mask[..., None], bsim, -torch.finfo(bsim.dtype).max)
+
+        sim = rearrange(sim, 'b (h w) i j -> b h w i 1 j', h = q_heads)
+
+        sim = torch.cat((sim, bsim), dim = -2)
+        sim = rearrange(sim, 'b h w i causal_and_sel j -> b h (w i) (causal_and_sel j)')
+
+        sel_bv = rearrange(sel_bv, 'b h (w i) j d -> b h w i j d', i = fine_block_size)
+
+        v = repeat(v, 'b (h w) j d -> b h w i j d', h = kv_heads, i = fine_block_size)
+        v = torch.cat((v, sel_bv), dim = -2)
+        v = rearrange(v, 'b h w i j d -> b h (w i) j d')
+
     # attend
 
     attn = sim.softmax(dim = -1)
 
-    out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+    if has_sel_kv_blocks:
+        out = einsum(attn, v, 'b h i j, b h i j d -> b h i d')
+    else:
+        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
 
-    if exists(block_size):
-        out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)
+    if exists(block_size):
+        out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)
 
     return out
 
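Note on the added block above: the selected-block similarities `bsim` are un-flattened into per-window, per-position, per-selected-block axes, masked via `torch.where` against the dtype's most negative value so padded selections vanish after softmax, and concatenated with the causal within-block similarities along a new `causal_and_sel` axis. The values get the matching treatment: `v` is repeated per query position and concatenated with `sel_bv`, so each query row carries its own key/value set, which is why the output einsum gains an extra `i` index on `v` when selected blocks are present. A small self-contained sketch of that per-query einsum, with toy shapes not tied to the test:

import torch
from einops import einsum

b, h, i, j, d = 1, 2, 4, 6, 8

attn = torch.softmax(torch.randn(b, h, i, j), dim = -1)

# plain attention: one shared set of j values for every query position
v_shared = torch.randn(b, h, j, d)
out_plain = einsum(attn, v_shared, 'b h i j, b h j d -> b h i d')

# per-query values, as after the cat with sel_bv: each query row i has its own j values
v_per_query = torch.randn(b, h, i, j, d)
out_sel = einsum(attn, v_per_query, 'b h i j, b h i j d -> b h i d')

print(out_plain.shape, out_sel.shape)   # both torch.Size([1, 2, 4, 8])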
@@ -62,8 +82,8 @@ def regular_attend(
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()
 
-indices = torch.zeros(1, 2, 512, 1).long().cuda()
-mask = torch.zeros(1, 2, 512, 1).bool().cuda()
+indices = torch.zeros(1, 2, 512, 0).long().cuda()
+mask = torch.zeros(1, 2, 512, 0).bool().cuda()
 
 # both regular and nsa pathways `r` and `n`
 
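Note on the test setup change: giving `indices` and `mask` a trailing dimension of 0 means zero selected kv blocks per query, so the reference `regular_attend` now exercises its plain-attention branch; the rest of the script (outside this diff) presumably still compares the regular and nsa outputs, as the `r`/`n` comment suggests. A quick sketch of what the zero-width setup implies, using only what the diff shows (without the `.cuda()` calls so it runs anywhere):

import torch

# zero selected kv blocks per query position (trailing dim of 0)
indices = torch.zeros(1, 2, 512, 0).long()
mask = torch.zeros(1, 2, 512, 0).bool()

num_sel_kv_blocks = indices.shape[-1]         # 0
has_sel_kv_blocks = num_sel_kv_blocks > 0     # False -> plain 'b h j d' value einsum path
print(num_sel_kv_blocks, has_sel_kv_blocks)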