
Commit e9476ec

ready to be compared with full attention.

1 parent: f32ed38

4 files changed: +31 −14 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 21 additions & 9 deletions
@@ -189,6 +189,7 @@ def __init__(
         norm = True,
         use_diff_topk = False,
         interpolated_importance_score = False,
+        query_heads_share_selected_kv = True, # if set to True, importance score is averaged across query heads to select top-n buckets of kv per kv head - but can be set to False for each query head within a group to look at different sets of kv buckets. will be more memory and compute of course
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,
         strategy_combine_mlp: Module | None = None
@@ -272,6 +273,8 @@ def __init__(

         self.interpolated_importance_score = interpolated_importance_score # in the case fine block size < compressed block size, will weigh space better when selecting

+        self.query_heads_share_selected_kv = query_heads_share_selected_kv
+
         self.selection_block_size = selection_block_size

         assert num_selected_blocks > 0
@@ -363,9 +366,14 @@ def forward(

         importance_scores = cattn[..., num_mem_compress_kv:]

-        # for gqa, we will average the compressed attention across each grouped queries (per key / values)
+        # maybe average the compressed attention across each grouped queries (per key / values)
+
+        if self.query_heads_share_selected_kv:
+            importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)

-        importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+            fine_num_grouped_queries = self.num_grouped_queries
+        else:
+            fine_num_grouped_queries = 1

         # handle if compress block size does not equal to the fine block size
         # cannot parse their equation, so will just improvise
@@ -400,12 +408,12 @@ def forward(
         if self.use_diff_topk:
             gates = straight_through(selected_importance_values, 1.)
             gates = gates.cumprod(dim = -1)[..., -1]
-            gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = self.num_grouped_queries)
+            gates = repeat(gates, 'b h ... -> b (h qh) ...', qh = fine_num_grouped_queries)

         if exists(fine_selection_flex_mask):
             # flex attention for the selection for fine attention

-            fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = self.num_grouped_queries)
+            fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = fine_num_grouped_queries)

             fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask, enable_gqa = True)

@@ -428,13 +436,13 @@ def forward(
            # handle block causal diagonal in the diagram, but run experiments without to see

            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-           fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.kv_heads)
+           fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = selected_block_indices.shape[1])
            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-           causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.kv_heads)
+           causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = fmask.shape[1])

            fmask = cat((fmask, causal_mask), dim = -2)
            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
@@ -446,8 +454,12 @@ def forward(

            # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)

-           fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
-           fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+           if self.query_heads_share_selected_kv:
+               fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+               fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+           else:
+               fk = repeat(fk, 'b h w j d -> b (h qh) i w j d', i = selected_block_indices.shape[2], qh = self.num_grouped_queries)
+               fv = repeat(fv, 'b h w j d -> b (h qh) i w j d', i = selected_block_indices.shape[2], qh = self.num_grouped_queries)

            selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])

@@ -460,7 +472,7 @@ def forward(

            fmask = rearrange(fmask, 'b h ... -> b h 1 ...')

-           fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = self.num_grouped_queries)
+           fq = rearrange(fq, 'b (h qh) ... -> b h qh ...', qh = fine_num_grouped_queries)

            fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
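To make the change above concrete, here is a small standalone sketch (not part of the commit, with made-up shapes) of what query_heads_share_selected_kv toggles during block selection: when True, the importance scores of the grouped query heads are mean-pooled per kv head before the top-n pick, so every query head in a group attends to the same selected blocks; when False, each query head keeps its own scores and its own selection, which is why the fine keys / values later need the extra repeat across query heads (more memory and compute, as the new keyword argument's comment notes).

import torch
from einops import reduce

# toy sizes, purely illustrative
batch, kv_heads, grouped_queries = 2, 4, 2
seq_len, num_blocks, num_selected = 64, 8, 2

# importance scores per individual query head: (b, kv_heads * grouped_queries, n, num_blocks)
importance_scores = torch.randn(batch, kv_heads * grouped_queries, seq_len, num_blocks)

query_heads_share_selected_kv = True

if query_heads_share_selected_kv:
    # mean-pool over the query heads belonging to each kv head,
    # so the whole group shares one top-n block selection per kv head
    importance_scores = reduce(
        importance_scores,
        'b (h grouped_queries) n w -> b h n w',
        'mean',
        grouped_queries = grouped_queries
    )

selected_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

print(selected_block_indices.shape)
# shared selection:     torch.Size([2, 4, 64, 2])  -> one pick per kv head
# per-query-head pick:  torch.Size([2, 8, 64, 2])  -> one pick per query head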

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.40"
+version = "0.0.41"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_sparse_attn.py

Lines changed: 5 additions & 2 deletions
@@ -10,11 +10,13 @@
 @pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
 @pytest.mark.parametrize('kv_heads', (8, 4))
 @pytest.mark.parametrize('selection_block_size', (8, 4, 2))
+@pytest.mark.parametrize('query_heads_share_selected_kv', (False, True))
 def test_sparse_attn(
     use_diff_topk,
     seq_len,
     kv_heads,
-    selection_block_size
+    selection_block_size,
+    query_heads_share_selected_kv
 ):
     attn = SparseAttention(
         dim = 512,
@@ -25,7 +27,8 @@ def test_sparse_attn(
         compress_block_size = 4,
         selection_block_size = selection_block_size,
         num_selected_blocks = 2,
-        use_diff_topk = use_diff_topk
+        use_diff_topk = use_diff_topk,
+        query_heads_share_selected_kv = query_heads_share_selected_kv
     )

     tokens = torch.randn(2, seq_len, 512)
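For reference, a hedged end-to-end usage sketch with the new flag turned off. The import path and the constructor arguments not visible in the diffs above (dim_head, heads, sliding_window_size) are assumptions and may not match the actual signature.

import torch
from native_sparse_attention_pytorch import SparseAttention  # assumed import path

attn = SparseAttention(
    dim = 512,
    dim_head = 64,             # assumption, not shown in this commit
    heads = 8,                 # assumption, not shown in this commit
    kv_heads = 4,              # GQA: 2 query heads per kv head
    sliding_window_size = 2,   # assumption, not shown in this commit
    compress_block_size = 4,
    selection_block_size = 4,
    num_selected_blocks = 2,
    query_heads_share_selected_kv = False  # each query head selects its own kv blocks
)

tokens = torch.randn(2, 31, 512)
out = attn(tokens)
assert out.shape == tokens.shape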

train.py

Lines changed: 4 additions & 2 deletions
@@ -30,7 +30,8 @@
 SEQ_LEN = 256

 USE_SPARSE_ATTN = True
-USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels
+USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels
+QUERY_HEADS_SHARE_SELECTION = False # if set to False, each query head can look at a different segment of their corresponding key / value head in GQA

 # experiment related

@@ -117,7 +118,8 @@ def base_decoding(
         selection_block_size = 32,
         num_selected_blocks = 2,
         use_diff_topk = True,
-        interpolated_importance_score = True
+        interpolated_importance_score = True,
+        query_heads_share_selected_kv = QUERY_HEADS_SHARE_SELECTION
     )
 ).cuda()
