Commit 1fc2163

start handling gqa, saving some compute during compression
1 parent ba24fbe commit 1fc2163

File tree

  native_sparse_attention_pytorch/native_sparse_attention.py
  pyproject.toml
  tests/test_sparse_attn.py

3 files changed: +54 -13 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 49 additions & 11 deletions
@@ -15,7 +15,7 @@
 # einstein notation
 
 import einx
-from einops import einsum, repeat, rearrange
+from einops import einsum, repeat, rearrange, reduce
 from einops.layers.torch import Rearrange
 
 # b - batch
@@ -66,6 +66,9 @@ def round_down_mult(n, mult):
 def round_up_mult(n, mult):
     return ceil(n / mult) * mult
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
@@ -83,6 +86,7 @@ def __init__(
         compress_block_size,
         selection_block_size,
         num_selected_blocks,
+        num_kv_heads = None,
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
@@ -91,12 +95,25 @@
         strategy_combine_mlp: Module | None = None
     ):
         super().__init__()
+
+        # attention heads
+        # handling gqa if `num_kv_heads` is set
+
+        num_kv_heads = default(num_kv_heads, heads)
+        assert num_kv_heads <= heads and divisible_by(heads, num_kv_heads)
+
         self.heads = heads
+        self.num_kv_heads = num_kv_heads
+        self.num_grouped_queries = heads // num_kv_heads
+
+        # scale
+
         self.scale = dim_head ** -0.5
 
         assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'
 
         dim_inner = dim_head * heads
+        dim_kv_inner = dim_head * num_kv_heads
 
         self.norm = nn.RMSNorm(dim) if norm else nn.Identity()
 
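As a rough illustration of the grouped-query bookkeeping the constructor now does, here is a standalone sketch of the head counts and projection widths; dim_head = 64 and heads = 8 come from the test below, while num_kv_heads = 2 is an assumed example value, not something taken from this commit:

# illustration only: head grouping and projection widths as set up in __init__ above
dim_head, heads, num_kv_heads = 64, 8, 2        # num_kv_heads = 2 is an assumed example

assert num_kv_heads <= heads and heads % num_kv_heads == 0
num_grouped_queries = heads // num_kv_heads     # 4 query heads share each kv head

dim_inner = dim_head * heads                    # 512, query projection width
dim_kv_inner = dim_head * num_kv_heads          # 128, key / value projection width

qkv_split = (dim_inner, dim_kv_inner, dim_kv_inner)
print(sum(qkv_split))                           # 768 output features, vs 3 * 512 = 1536 before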
@@ -106,7 +123,11 @@ def __init__(
 
         # qkv
 
-        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
+        qkv_split = (dim_inner, dim_kv_inner, dim_kv_inner)
+
+        self.to_qkv = nn.Linear(dim, sum(qkv_split), bias = False)
+
+        self.qkv_split = qkv_split
 
         # sliding window strategy
 
@@ -129,10 +150,10 @@ def __init__(
 
         self.split_compress_window = Rearrange('b h (w n) d -> b h w n d', n = compress_block_size)
 
-        self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
+        self.compress_mem_kv = nn.Parameter(torch.zeros(2, num_kv_heads, num_compressed_mem_kv, dim_head))
 
-        self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
-        self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
+        self.k_intrablock_positions = nn.Parameter(torch.zeros(num_kv_heads, compress_block_size, dim_head))
+        self.v_intrablock_positions = nn.Parameter(torch.zeros(num_kv_heads, compress_block_size, dim_head))
 
         if not exists(compress_mlp):
             compress_dim = compress_block_size * dim_head
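Because the compressed memory kv, the intra-block positions, and (after the projection change above) the compressed keys and values now carry only num_kv_heads heads instead of heads, the compression path processes heads // num_kv_heads times less data, which is presumably the compute saving mentioned in the commit message. A rough sketch with assumed shapes, not the module's actual variables:

import torch

# illustration only: tensors entering compression, before vs. after this commit (assumed sizes)
batch, seq_len, dim_head = 2, 128, 64
heads, num_kv_heads, compress_block_size = 8, 2, 4
num_windows = seq_len // compress_block_size

k_before = torch.randn(batch, heads, num_windows, compress_block_size, dim_head)
k_after = torch.randn(batch, num_kv_heads, num_windows, compress_block_size, dim_head)

print(k_before.numel() // k_after.numel())      # 4, i.e. heads // num_kv_heads fewer elements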
@@ -168,7 +189,7 @@ def __init__(
 
         # split and merging heads
 
-        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
         # combining heads
@@ -194,7 +215,7 @@ def forward(
 
         # queries, keys, values
 
-        q, k, v = self.to_qkv(inp).chunk(3, dim = -1)
+        q, k, v = self.to_qkv(inp).split(self.qkv_split, dim = -1)
 
         q, k, v = map(self.split_heads, (q, k, v))
 
@@ -218,6 +239,8 @@ def forward(
         ck = cat((mem_ck, ck), dim = -2)
         cv = cat((mem_cv, cv), dim = -2)
 
+        ck, cv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (ck, cv))
+
         csim = einsum(q, ck, 'b h i d, b h j d -> b h i j') * self.scale
 
         cq_seq = arange(seq_len, device = device)
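The repeat above broadcasts each compressed key/value head across its group of query heads, so the einsum sees matching head counts. A self-contained sketch of that expansion, with assumed shapes:

import torch
from einops import einsum, repeat

batch, seq_len, num_compressed, dim_head = 2, 16, 8, 64
heads, num_kv_heads = 8, 2
num_grouped_queries = heads // num_kv_heads

q = torch.randn(batch, heads, seq_len, dim_head)
ck = torch.randn(batch, num_kv_heads, num_compressed, dim_head)

# expand kv heads to match the query heads, as in the hunk above
ck = repeat(ck, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = num_grouped_queries)

csim = einsum(q, ck, 'b h i d, b h j d -> b h i j')
print(csim.shape)       # torch.Size([2, 8, 16, 8])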
@@ -241,8 +264,13 @@ def forward(
 
         # 2. fine attention over selected based on compressed attention logits
 
+
         importance_scores = cattn[..., num_mem_compress_kv:]
 
+        # for gqa, we will average the compressed attention across each grouped queries (per key / values)
+
+        importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+
         num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])
 
         fq = rotated_q
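And the inverse step: the compressed-attention importance scores are computed per query head, then averaged back down to one score per kv head, so every query in a group selects the same fine blocks. A minimal sketch of that reduction, again with assumed shapes:

import torch
from einops import reduce

batch, heads, num_kv_heads, seq_len, num_blocks = 2, 8, 2, 16, 4
grouped_queries = heads // num_kv_heads

importance_scores = torch.randn(batch, heads, seq_len, num_blocks)

# mean over the query heads that share a kv head, as in the hunk above
importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = grouped_queries)
print(importance_scores.shape)      # torch.Size([2, 2, 16, 4])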
@@ -273,13 +301,13 @@ def forward(
             # handle block causal diagonal in the diagram, but run experiments without to see
 
             fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
+            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.num_kv_heads)
             selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
 
             fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
 
             causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)
+            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.num_kv_heads)
 
             fmask = cat((fmask, causal_mask), dim = -2)
             fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
@@ -312,6 +340,8 @@ def forward(
 
             # fine attention
 
+            fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+
             fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
 
             fsim = fsim.masked_fill(~fmask, mask_value)
@@ -327,6 +357,8 @@ def forward(
             seq_len = fk.shape[-2]
             fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()
 
+            fk, fv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))
+
             fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale
 
             fsim = fsim.masked_fill(~fmask, mask_value)
@@ -337,10 +369,16 @@ def forward(
 
         # 3. overlapping sliding window, this is unsurprising and expected
 
+        sq = rotated_q
+        sk = rotated_k
+        sv = v
+
+        sk, sv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
+
         if exists(sliding_window_flex_mask):
-            sliding_window_attn_out = flex_attention(rotated_q, rotated_k, v, block_mask = sliding_window_flex_mask)
+            sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask)
         else:
-            sliding_window_attn_out = self.sliding_window(rotated_q, rotated_k, v)
+            sliding_window_attn_out = self.sliding_window(sq, sk, sv)
 
         # combine strategies
 
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.17"
+version = "0.0.18"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_sparse_attn.py

Lines changed: 4 additions & 1 deletion
@@ -8,14 +8,17 @@
 
 @pytest.mark.parametrize('use_diff_topk', (False, True))
 @pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
+@pytest.mark.parametrize('num_kv_heads', (8, 4))
 def test_sparse_attn(
     use_diff_topk,
-    seq_len
+    seq_len,
+    num_kv_heads
 ):
     attn = SparseAttention(
         dim = 512,
         dim_head = 64,
         heads = 8,
+        num_kv_heads = num_kv_heads,
         sliding_window_size = 2,
         compress_block_size = 4,
         selection_block_size = 4,
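For completeness, a usage sketch mirroring the updated test: query heads outnumber kv heads, everything else as in the test. The import path assumes the package exports SparseAttention at the top level, and num_selected_blocks = 2 plus the input length are assumed values, since the rest of the test is not shown here:

import torch
from native_sparse_attention_pytorch import SparseAttention   # assumed top-level export

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_kv_heads = 4,              # gqa: 2 query heads per kv head
    sliding_window_size = 2,
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2        # assumed value, not shown in the test above
)

tokens = torch.randn(1, 31, 512)
out = attn(tokens)                 # expected to come back as (1, 31, 512)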
