
Commit dbcf080

prepare for gqa with nsa

1 parent 8ac5b7a commit dbcf080

4 files changed, +60 -39 lines changed
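
This commit prepares the Triton native sparse attention (NSA) path for grouped query attention (GQA): the autograd wrapper now repeats the key/value tensors (along with the per-head block indices and mask) across query-head groups on the way in, and sums the key/value gradients back over those groups on the way out, so callers no longer pass a num_grouped_queries argument. A minimal, self-contained sketch of that repeat/reduce pairing using einops (shapes and variable names here are illustrative, not taken from the library):

import torch
from einops import repeat, reduce

# illustrative GQA shapes: 4 query heads share 2 kv heads -> 2 query heads per kv head
q = torch.randn(1, 4, 512, 64)
k = torch.randn(1, 2, 512, 64)

head_groups = q.shape[1] // k.shape[1]

# forward: expand kv heads so every query head has a matching kv head
k_expanded = repeat(k, 'b h ... -> b (h g) ...', g = head_groups)            # (1, 4, 512, 64)

# backward: gradients of the expanded copies are summed back per shared kv head
dk_expanded = torch.randn_like(k_expanded)                                   # stand-in for a gradient
dk = reduce(dk_expanded, 'b (h g) ... -> b h ...', 'sum', g = head_groups)

assert dk.shape == k.shape                                                   # (1, 2, 512, 64)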
Lines changed: 4 additions & 0 deletions

@@ -1,3 +1,7 @@
 from native_sparse_attention_pytorch.native_sparse_attention import (
     SparseAttention
 )
+
+from native_sparse_attention_pytorch.triton_native_sparse_attention import (
+    native_sparse_attend
+)
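
This first hunk re-exports native_sparse_attend alongside SparseAttention at the package level, so the Triton attend function can be imported directly. A one-line usage sketch, assuming the hunk above belongs to the package's __init__ module (the filename is not shown in this excerpt):

# assumes the hunk above is the package __init__; the filename is not confirmed by this excerpt
from native_sparse_attention_pytorch import SparseAttention, native_sparse_attend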

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 1 addition & 2 deletions

@@ -451,8 +451,7 @@ def forward(
                 fq, fk, fv,
                 self.selection_block_size,
                 selected_block_indices,
-                fmask,
-                fine_num_grouped_queries
+                fmask
             )

         elif exists(fine_selection_flex_mask):

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 25 additions & 12 deletions

@@ -8,7 +8,7 @@
 import torch
 from torch import Tensor

-from einops import repeat, rearrange
+from einops import repeat, rearrange, reduce

 def exists(v):
     return v is not None

@@ -1044,12 +1044,12 @@ def flash_attn_backward(
         dq_accum.stride(0),
         dq_accum.stride(1),
         dq_accum.stride(2),
-        dk.stride(0),
-        dk.stride(1),
-        dk.stride(2),
-        dv.stride(0),
-        dv.stride(1),
-        dv.stride(2),
+        dk_accum.stride(0),
+        dk_accum.stride(1),
+        dk_accum.stride(2),
+        dv_accum.stride(0),
+        dv_accum.stride(1),
+        dv_accum.stride(2),
         kv_block_indices.stride(0),
         kv_block_indices.stride(1),
         kv_block_indices.stride(2),

@@ -1094,10 +1094,15 @@ def forward(
         block_size,
         selected_block_indices,
         fmask,
-        num_grouped_queries
     ):
         dtype = fq.dtype

+        q_heads, kv_heads = fq.shape[1], fk.shape[1]
+        assert divisible_by(q_heads, kv_heads)
+        head_groups = q_heads // kv_heads
+
+        fk, fv, selected_block_indices, fmask = tuple(repeat(t, 'b h ... -> b (h g) ...', g = head_groups) for t in (fk, fv, selected_block_indices, fmask))
+
         fq, fk, fv = tuple(t.half() for t in (fq, fk, fv))

         out, lse = flash_attn_forward(

@@ -1108,23 +1113,29 @@ def forward(
         )

         ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
-        ctx._saved_variables = (block_size,)
+
+        ctx._saved_variables = (
+            block_size,
+            head_groups
+        )

         return out.type(dtype)

     @classmethod
     def backward(self, ctx, do):
+        device = do.device

         q, k, v, sel_block_indices, mask, out, lse = ctx.saved_tensors

         (
             block_size,
+            head_groups
         ) = ctx._saved_variables

         do = do.half()
-        dq = torch.zeros_like(q)
-        dk = torch.zeros_like(k)
-        dv = torch.zeros_like(v)
+        dq = torch.zeros(q.shape, dtype = torch.float32, device = device)
+        dk = torch.zeros(k.shape, dtype = torch.float32, device = device)
+        dv = torch.zeros(v.shape, dtype = torch.float32, device = device)

         flash_attn_backward(
             do, q, k, v,

@@ -1133,6 +1144,8 @@ def backward(self, ctx, do):
             block_size = block_size
         )

+        dk, dv = tuple(reduce(t, 'b (h g) ... -> b h ...', 'sum', g = head_groups) for t in (dk, dv))
+
         return dq, dk, dv, None, None, None, None

 native_sparse_attend = NSA.apply
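
The backward change sums dk and dv over the query-head groups because repeating a kv head across g query heads is a broadcast, and the gradient of a broadcast is the sum over its copies. A small self-contained check of that identity with plain PyTorch and einops, independent of the Triton kernel (tensor names here are illustrative):

import torch
from einops import repeat, reduce

k = torch.randn(1, 2, 8, 4, requires_grad = True)    # 2 kv heads
g = 2                                                 # query heads per kv head

# forward-style expansion, as done for fk / fv (and indices / mask) in NSA.forward
k_rep = repeat(k, 'b h n d -> b (h g) n d', g = g)

loss = (k_rep ** 2).sum()
loss.backward()

# manual reduction over the group axis, as done for dk / dv in NSA.backward
grad_rep = 2 * k_rep.detach()                         # analytic d(loss) / d(k_rep)
dk_manual = reduce(grad_rep, 'b (h g) n d -> b h n d', 'sum', g = g)

assert torch.allclose(k.grad, dk_manual)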

test_triton_nsa.py

Lines changed: 30 additions & 25 deletions

@@ -9,26 +9,34 @@
 def exists(v):
     return v is not None

+def divisible_by(num, den):
+    return (num % den) == 0
+
 def regular_attend(
     q, k, v,
     indices,
     mask,
-    block_size = None,
+    block_size,
 ):
-    q_heads, kv_heads = q.shape[1], k.shape[1]
+    q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
+    assert divisible_by(q_heads, kv_heads)
+
+    g = q_heads // kv_heads # `g` stands for `g`roups of query heads per kv head
+
+    assert divisible_by(seq_len, block_size)
+    w = seq_len // block_size

-    if exists(block_size):
-        w = q.shape[-2] // block_size
-        q, k, v = tuple(rearrange(t, 'b h (w n) d -> b (h w) n d', n = block_size) for t in (q, k, v))
+    q, k, v = tuple(rearrange(t, 'b h (w n) d -> b h w n d', n = block_size) for t in (q, k, v))

-    seq_len, device = q.shape[-2], q.device
     scale = q.shape[-1] ** -0.5
     q = q * scale

+    q = rearrange(q, 'b (h g) ... -> b h g ...', g = g)
+
     # block causal diagonal

-    sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
-    causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
+    sim = einsum(q, k, 'b h g w i d, b h w j d -> b h g w i j')
+    causal_mask = torch.ones((block_size, block_size), device = device, dtype = torch.bool).triu(1)
     sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

     # rest of the indices

@@ -37,48 +45,45 @@ def regular_attend(
     has_sel_kv_blocks = num_sel_kv_blocks > 0

     if has_sel_kv_blocks:
-        bk, bv = tuple(rearrange(t, 'b (h w) n d -> b h w n d', h = kv_heads) for t in (k, v))
+        bk, bv = k, v
         sel_bk = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bk, indices)
         sel_bv = einx.get_at('b h [w] n d, b h i sel -> b h i (sel n) d', bv, indices)

-        q = rearrange(q, 'b (h w) n d -> b h (w n) d', h = q_heads)
-        bsim = einsum(q, sel_bk, 'b h i d, b h i j d -> b h i j')
+        q = rearrange(q, 'b h g w n d -> b h g (w n) d')
+        bsim = einsum(q, sel_bk, 'b h g i d, b h i j d -> b h g i j')

-        bsim = rearrange(bsim, 'b h (w i) (sel j) -> b h w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)
+        bsim = rearrange(bsim, 'b h g (w i) (sel j) -> b h g w i sel j', sel = num_sel_kv_blocks, i = fine_block_size)

-        mask = rearrange(mask, 'b h (w i) sel -> b h w i sel', i = fine_block_size)
+        mask = rearrange(mask, 'b h (w i) sel -> b h 1 w i sel', i = fine_block_size)
         bsim = torch.where(mask[..., None], bsim, -torch.finfo(bsim.dtype).max)

-        sim = rearrange(sim, 'b (h w) i j -> b h w i 1 j', h = q_heads)
+        sim = rearrange(sim, 'b h g w i j -> b h g w i 1 j')

         sim = torch.cat((sim, bsim), dim = -2)
-        sim = rearrange(sim, 'b h w i causal_and_sel j -> b h (w i) (causal_and_sel j)')
+        sim = rearrange(sim, 'b h g w i causal_and_sel j -> b h g w i (causal_and_sel j)')

         sel_bv = rearrange(sel_bv, 'b h (w i) j d -> b h w i j d', i = fine_block_size)

-        v = repeat(v, 'b (h w) j d -> b h w i j d', h = kv_heads, i = fine_block_size)
+        v = repeat(v, 'b h w j d -> b h w i j d', i = fine_block_size)
         v = torch.cat((v, sel_bv), dim = -2)
-        v = rearrange(v, 'b h w i j d -> b h (w i) j d')
+        v = rearrange(v, 'b h w i j d -> b h w i j d')

     # attend

     attn = sim.softmax(dim = -1)

     if has_sel_kv_blocks:
-        out = einsum(attn, v, 'b h i j, b h i j d -> b h i d')
+        out = einsum(attn, v, 'b h g w i j, b h w i j d -> b h g w i d')
     else:
-        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
-
-        if exists(block_size):
-            out = rearrange(out, 'b (h w) n d -> b h (w n) d', w = w)
+        out = einsum(attn, v, 'b h g w i j, b h j d -> b h g w i d')

-    return out
+    return rearrange(out, 'b h g w n d -> b (h g) (w n) d')

 # mock inputs

 fine_block_size = 16

-q = torch.randn(1, 2, 512, 64).cuda()
+q = torch.randn(1, 4, 512, 64).cuda()
 k = torch.randn(1, 2, 512, 64).cuda()
 v = torch.randn(1, 2, 512, 64).cuda()

@@ -97,7 +102,7 @@ def regular_attend(

 # triton nsa forwards and backwards

-nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask, 1)
+nsa_out = native_sparse_attend(nq, nk, nv, fine_block_size, indices, mask)
 nsa_out.sum().backward()

 # asserts
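
The reference regular_attend now partitions the sequence into fine blocks and groups the query heads under their shared kv head before the block-causal einsum, which is why the test can drive 4 query heads against 2 kv heads. A standalone shape walk-through of that grouping (dimensions mirror the mock inputs above; nothing here calls the Triton kernel):

import torch
from einops import rearrange, einsum

b, q_heads, kv_heads, seq_len, dim, block = 1, 4, 2, 512, 64, 16
g = q_heads // kv_heads            # grouped query heads per kv head
w = seq_len // block               # number of fine blocks along the sequence

q = torch.randn(b, q_heads, seq_len, dim)
k = torch.randn(b, kv_heads, seq_len, dim)

# partition into blocks, then fold the query-head group axis out of the head axis
q = rearrange(q, 'b h (w n) d -> b h w n d', n = block)
k = rearrange(k, 'b h (w n) d -> b h w n d', n = block)
q = rearrange(q, 'b (h g) ... -> b h g ...', g = g)

# block-diagonal similarities: each grouped query head attends within its own block
sim = einsum(q, k, 'b h g w i d, b h w j d -> b h g w i j')

assert sim.shape == (b, kv_heads, g, w, block, block)   # (1, 2, 2, 32, 16, 16)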
