Commit fd0c756

complete the fine attention masking with flex attention, not wired up
1 parent 14c90bd commit fd0c756

2 files changed: +12 -5 lines changed

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 11 additions & 4 deletions
@@ -67,17 +67,24 @@ def compress_mask(_, __, q_idx, kv_idx):
 
 
 def create_fine_mask(selected_block_indices: Tensor, seq_len, fine_block_size):
+    device = selected_block_indices.device
     batch, heads = selected_block_indices.shape[:2]
 
+    one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
+    one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
+
     def fine_mask(b_idx, h_idx, q_idx, kv_idx):
-        selected_indices = selected_block_indices[b_idx, h_idx]
 
-        # todo - fill in logic for creating the selected kv ranges per query
+        compressed_q_idx = q_idx // fine_block_size
+        compressed_kv_idx = kv_idx // fine_block_size
+
+        block_causal_mask = compressed_q_idx > compressed_kv_idx
+        is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
 
         causal_mask = q_idx >= kv_idx
-        block_diagonal = (q_idx // fine_block_size) == (kv_idx // fine_block_size)
+        block_diagonal = compressed_q_idx == compressed_kv_idx
 
-        return (block_diagonal & causal_mask)
+        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
 
     block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
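The new mask rule: a query may attend within its own fine block (causally), plus any strictly earlier fine block whose index was scattered into the one-hot selection tensor. Per the commit message this is not wired up yet; the following is a minimal sketch of how create_fine_mask might be driven with PyTorch flex attention, assuming selected_block_indices has shape (batch, heads, seq_len, num_selected_blocks) as implied by the scatter_ above. The import path, the shapes, and the name num_selected_blocks are illustrative assumptions, not code from this commit.

# Hypothetical wiring sketch, not part of this commit. Assumes PyTorch >= 2.5
# (torch.nn.attention.flex_attention) and a CUDA device, since create_block_mask
# is called above without a device argument and defaults to "cuda".

import torch
from torch.nn.attention.flex_attention import flex_attention

from native_sparse_attention_pytorch.native_sparse_attention import create_fine_mask

batch, heads, seq_len, dim_head = 1, 4, 256, 64
fine_block_size = 16
num_selected_blocks = 2  # illustrative: fine blocks selected per query position

# indices of the selected fine blocks per (batch, head, query position),
# shape (batch, heads, seq_len, num_selected_blocks) to match the scatter_ above
selected_block_indices = torch.randint(
    0, seq_len // fine_block_size,
    (batch, heads, seq_len, num_selected_blocks),
    device = 'cuda'
)

# BlockMask combining the causal block-diagonal part with the selected past blocks
fine_block_mask = create_fine_mask(selected_block_indices, seq_len, fine_block_size)

q, k, v = (torch.randn(batch, heads, seq_len, dim_head, device = 'cuda') for _ in range(3))

# fine branch attention restricted by the mask
out = flex_attention(q, k, v, block_mask = fine_block_mask)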

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.23"
+version = "0.0.24"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
