Skip to content

Commit 6ee39a3

Browse files
committed
First, handle the case where the sequence length yields fewer compressed windows than the number of selected blocks. Note that it still breaks when there are no blocks at all (sequence length less than the compress block size).
1 parent c391705 commit 6ee39a3

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def forward(
243243

244244
importance_scores = cattn[..., num_mem_compress_kv:]
245245

246-
selected_importance_values, selected_block_indices = importance_scores.topk(self.num_selected_blocks, dim = -1)
246+
topk = min(self.num_selected_blocks, importance_scores.shape[-1])
247+
248+
selected_importance_values, selected_block_indices = importance_scores.topk(topk, dim = -1)
247249

248250
if self.use_diff_topk:
249251
gates = selected_importance_values + (1. - selected_importance_values).detach()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.0.15"
3+
version = "0.0.16"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_sparse_attn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from native_sparse_attention_pytorch import SparseAttention
88

99
@pytest.mark.parametrize('use_diff_topk', (False, True))
10+
@pytest.mark.parametrize('seq_len', (4, 31, 32, 120))
1011
def test_sparse_attn(
11-
use_diff_topk
12+
use_diff_topk,
13+
seq_len
1214
):
1315
attn = SparseAttention(
1416
dim = 512,
@@ -21,7 +23,7 @@ def test_sparse_attn(
2123
use_diff_topk = use_diff_topk
2224
)
2325

24-
tokens = torch.randn(2, 31, 512)
26+
tokens = torch.randn(2, seq_len, 512)
2527

2628
attended = attn(tokens)
2729

0 commit comments

Comments
 (0)