Commit 16d9bbc

mask out the block diagonal for the importance score, as block causal is always included for fine attention
1 parent 36837b2 commit 16d9bbc

File tree

2 files changed: +16, -4 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 15 additions & 3 deletions
@@ -392,16 +392,25 @@ def forward(
         compress_seq_len = score_len * self.compress_block_size
 
         if self.interpolated_importance_score:
-            mask = importance_scores > 1e-10
-            mask = repeat(mask, '... j -> ... (j block_size)', block_size = self.compress_block_size)
             importance_scores = interpolate_1d(importance_scores, compress_seq_len)
-            importance_scores = importance_scores.masked_fill(~mask, 0.)
         else:
             importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
 
         padding = fine_divisible_seq_len - compress_seq_len
+
+        fine_query_seq_len = importance_scores.shape[-2]
+        fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]
+
         importance_scores = F.pad(importance_scores, (0, padding))
 
+        # mask out the diagonal since block causal is included by default for fine attending
+
+        block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
+        block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
+        block_causal_mask = block_causal_mask[:fine_query_seq_len]
+
+        importance_scores = importance_scores.masked_fill(~block_causal_mask, 0.)
+
         importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)
 
         # handle if number of total blocks is less than number to select for fine attention
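To see what the new masking step does in isolation, here is a minimal standalone sketch (not the module's actual forward pass): a strict lower-triangular block mask is expanded to token resolution, the diagonal-block scores are zeroed, and the scores are then mean-pooled into one score per selection block. The batch/head sizes, sequence length, and random scores below are illustrative assumptions.

```python
# minimal standalone sketch of the diagonal masking step - shapes are illustrative assumptions

import torch
from einops import repeat, reduce

selection_block_size = 4
num_fine_blocks = 3
seq_len = num_fine_blocks * selection_block_size  # assume length is already block divisible

# token level importance scores: (batch, heads, query len, key len)
importance_scores = torch.rand(1, 2, seq_len, seq_len)

# strict lower-triangular block mask - False on the diagonal blocks,
# which fine attention already covers through block causal attending
block_causal_mask = torch.ones((num_fine_blocks,) * 2, dtype = torch.bool).tril(-1)
block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = selection_block_size, n2 = selection_block_size)
block_causal_mask = block_causal_mask[:seq_len]

# zero out diagonal block scores so they cannot win a top-k slot
importance_scores = importance_scores.masked_fill(~block_causal_mask, 0.)

# mean pool token scores into one score per selection block
block_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = selection_block_size)

assert block_scores.shape == (1, 2, seq_len, num_fine_blocks)
```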
@@ -411,6 +420,9 @@ def forward(
             fv = v
 
         if has_selected_kv_for_fine_attn:
+
+            # get the top-n kv segments for fine attention
+
             selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
 
             if self.use_diff_topk:
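The comment added in this hunk marks where the masked scores get consumed: `importance_scores.topk(num_selected, dim = -1)` picks the top-n kv blocks per query for fine attention. A tiny self-contained illustration of why the zeroed diagonal block no longer competes for a slot; the per-block scores and `num_selected` below are made-up assumptions.

```python
# toy illustration - per-block scores and num_selected are made-up assumptions

import torch

num_selected = 2

# pooled importance scores for one query row; the current (diagonal) block,
# index 5, was zeroed by the masking step above
block_scores = torch.tensor([0.30, 0.05, 0.20, 0.10, 0.25, 0.00])

values, indices = block_scores.topk(num_selected, dim = -1)
print(indices.tolist())  # [0, 4] - the top-k slots go to other blocks, not the diagonal
```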

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.44"
+version = "0.0.45"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
