
Commit 6eb45de

use enable_gqa for flex attention for the sliding windows branch
1 parent 0d96765 commit 6eb45de

File tree

2 files changed: +4 -4 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

3 additions, 3 deletions

@@ -446,11 +446,11 @@ def forward(
         sk = rotated_k
         sv = v

-        sk, sv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
-
         if exists(sliding_window_flex_mask):
-            sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask)
+            sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask, enable_gqa = True)
         else:
+            sk, sv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
+
             sliding_window_attn_out = self.sliding_window(sq, sk, sv)

         # combine strategies
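For context, a minimal standalone sketch (not part of the commit; shapes and head counts are made up) of why the manual repeat becomes unnecessary on the flex path: assuming a recent PyTorch whose flex_attention exposes the enable_gqa flag, passing enable_gqa = True lets the kernel share each key/value head across its group of query heads internally, so the einops repeat of k and v is only needed on the fallback branch that does not use flex_attention.

# Hypothetical demo, not taken from this repository.
import torch
from torch.nn.attention.flex_attention import flex_attention

b, n, d = 2, 128, 64        # batch, sequence length, head dimension (made-up sizes)
q_heads, kv_heads = 8, 2    # grouped-query attention: 4 query heads share each kv head

q = torch.randn(b, q_heads, n, d)
k = torch.randn(b, kv_heads, n, d)
v = torch.randn(b, kv_heads, n, d)

# With enable_gqa = True, k / v may have fewer heads than q;
# no manual repeat of k / v up to q_heads is required.
out = flex_attention(q, k, v, enable_gqa = True)
assert out.shape == (b, q_heads, n, d)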

pyproject.toml

1 addition, 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.29"
+version = "0.0.30"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
