Skip to content

Commit 922a633

Browse files
committed
fix: importance scores were not normalized (softmax) before top-k selection at inference
1 parent c3aa352 commit 922a633

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,8 @@ def forward_inference(
496496
if self.query_heads_share_selected_kv:
497497
importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
498498

499+
importance_scores = importance_scores.softmax(dim = -1)
500+
499501
sel_scores, sel_indices = importance_scores.topk(num_selected, dim = -1)
500502

501503
fine_divisible_seq_len = round_up_mult(seq_len, self.selection_block_size)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.1.19"
3+
version = "0.1.20"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_sparse_attn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ def test_sparse_attn(
5151
assert tokens.shape == attended.shape
5252

5353
@pytest.mark.parametrize('seq_len', (2, 8, 16))
54-
def test_inference(seq_len):
54+
@pytest.mark.parametrize('num_selected_blocks', (0, 2))
55+
def test_inference(
56+
seq_len,
57+
num_selected_blocks
58+
):
5559

5660
attn = SparseAttention(
5761
dim = 512,
@@ -61,7 +65,7 @@ def test_inference(seq_len):
6165
sliding_window_size = 2,
6266
compress_block_size = 5,
6367
selection_block_size = 10,
64-
num_selected_blocks = 0
68+
num_selected_blocks = num_selected_blocks
6569
)
6670

6771
tokens = torch.randn(2, seq_len, 512)

0 commit comments

Comments (0)