Skip to content

Commit 420bed4

Browse files
committed
last commit for the day
1 parent 338ffb3 commit 420bed4

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

test_triton_nsa.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,36 @@ def regular_attend(
1515
mask,
1616
block_size = None,
1717
):
18+
q_heads, kv_heads = q.shape[1], k.shape[1]
19+
1820
if exists(block_size):
1921
w = q.shape[-2] // block_size
2022
q, k, v = tuple(rearrange(t, 'b h (w n) d -> b (h w) n d', n = block_size) for t in (q, k, v))
2123

2224
seq_len, device = q.shape[-2], q.device
2325
scale = q.shape[-1] ** -0.5
26+
q = q * scale
27+
28+
# block causal diagonal
2429

25-
sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale
30+
sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
2631
causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
2732
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
33+
34+
# rest of the indices
35+
36+
has_sel_kv_blocks = indices.shape[-1] > 0
37+
38+
if has_sel_kv_blocks:
39+
bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
40+
sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
41+
sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)
42+
43+
q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
44+
bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j') * scale
45+
46+
# attend
47+
2848
attn = sim.softmax(dim = -1)
2949

3050
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

0 commit comments

Comments
 (0)