Commit 8d0846a

make the fine flex block mask also aware of gqa
1 parent 298cfec commit 8d0846a

File tree

2 files changed (+8, -9 lines)


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 7 additions & 8 deletions
@@ -67,9 +67,9 @@ def compress_mask(_, __, q_idx, kv_idx):
 
 def create_fine_mask(seq_len, fine_block_size):
 
-    def inner(selected_block_indices: Tensor):
+    def inner(selected_block_indices: Tensor, num_grouped_queries = 1):
         device = selected_block_indices.device
-        batch, heads = selected_block_indices.shape[:2]
+        batch, kv_heads = selected_block_indices.shape[:2]
 
         one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
         one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
@@ -78,15 +78,16 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):
 
             compressed_q_idx = q_idx // fine_block_size
             compressed_kv_idx = kv_idx // fine_block_size
+            kv_head_idx = h_idx // num_grouped_queries
 
-            is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
+            is_selected = one_hot_selected_block_indices[b_idx, kv_head_idx, q_idx, compressed_kv_idx]
 
             causal_mask = q_idx >= kv_idx
             block_diagonal = compressed_q_idx == compressed_kv_idx
 
             return (causal_mask & (block_diagonal | is_selected))
 
-        block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
+        block_mask = create_block_mask(fine_mask, B = batch, H = kv_heads * num_grouped_queries, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
         return block_mask
 
     return inner
@@ -349,11 +350,9 @@ def forward(
         if exists(fine_selection_flex_mask):
             # flex attention for the selection for fine attention
 
-            fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))
+            fine_block_mask = fine_selection_flex_mask(selected_block_indices, num_grouped_queries = self.num_grouped_queries)
 
-            fine_block_mask = fine_selection_flex_mask(selected_block_indices)
-
-            fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask)
+            fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask, enable_gqa = True)
 
         else:
             fmask = selected_importance_values > 1e-10
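The change works because flex attention can handle grouped-query attention natively: with `enable_gqa = True`, the queries may carry more heads than the keys/values, so the selected block indices can stay at key/value-head resolution and the mask simply maps each query head back to its key/value head by integer division, instead of repeating `fk`, `fv` and the indices per query head. Below is a minimal sketch of that head mapping, not the library's exact API; the sizes and the `one_hot` / `fine_mask` names here are illustrative stand-ins mirroring the diff above.

import torch

# hypothetical sizes, chosen only to illustrate the head mapping
batch, kv_heads, num_grouped_queries = 1, 2, 4   # 2 kv heads, 8 query heads
seq_len, fine_block_size = 64, 16
num_blocks = seq_len // fine_block_size

# selections are stored per kv head, as in the diff above
selected_block_indices = torch.randint(0, num_blocks, (batch, kv_heads, seq_len, 2))

one_hot = torch.zeros((batch, kv_heads, seq_len, num_blocks), dtype = torch.bool)
one_hot.scatter_(-1, selected_block_indices, True)

def fine_mask(b_idx, h_idx, q_idx, kv_idx):
    # h_idx runs over all kv_heads * num_grouped_queries query heads;
    # integer division recovers the kv head that owns the selection
    kv_head_idx = h_idx // num_grouped_queries

    compressed_q_idx = q_idx // fine_block_size
    compressed_kv_idx = kv_idx // fine_block_size

    is_selected = one_hot[b_idx, kv_head_idx, q_idx, compressed_kv_idx]
    causal_mask = q_idx >= kv_idx
    block_diagonal = compressed_q_idx == compressed_kv_idx
    return causal_mask & (block_diagonal | is_selected)

# query heads 0-3 share kv head 0's selections, heads 4-7 share kv head 1's
assert bool(fine_mask(0, 0, 20, 5)) == bool(fine_mask(0, 3, 20, 5))

A mask built this way is still created with H = kv_heads * num_grouped_queries (one slot per query head), which is what the `create_block_mask` call in the diff does.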

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.31"
+version = "0.0.33"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
