
Commit 10bcfb9

last commit for the day, should be ready for experiments tomorrow
1 parent fd0c756 commit 10bcfb9

4 files changed (+36, −22 lines)


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 21 additions & 17 deletions
@@ -4,7 +4,7 @@
 from math import ceil

 import torch
-from torch import nn, arange, stack, cat
+from torch import nn, arange, stack, cat, Tensor
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList

@@ -65,29 +65,32 @@ def compress_mask(_, __, q_idx, kv_idx):
     block_mask = create_block_mask(compress_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True)
     return block_mask

+def create_fine_mask(seq_len, fine_block_size):

-def create_fine_mask(selected_block_indices: Tensor, seq_len, fine_block_size):
-    device = selected_block_indices.device
-    batch, heads = selected_block_indices.shape[:2]
+    def inner(selected_block_indices: Tensor):
+        device = selected_block_indices.device
+        batch, heads = selected_block_indices.shape[:2]

-    one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
-    one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)
+        one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], seq_len // fine_block_size), device = device, dtype = torch.bool)
+        one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)

-    def fine_mask(b_idx, h_idx, q_idx, kv_idx):
+        def fine_mask(b_idx, h_idx, q_idx, kv_idx):

-        compressed_q_idx = q_idx // fine_block_size
-        compressed_kv_idx = kv_idx // fine_block_size
+            compressed_q_idx = q_idx // fine_block_size
+            compressed_kv_idx = kv_idx // fine_block_size

-        block_causal_mask = compressed_q_idx > compressed_kv_idx
-        is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]
+            block_causal_mask = compressed_q_idx > compressed_kv_idx
+            is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]

-        causal_mask = q_idx >= kv_idx
-        block_diagonal = compressed_q_idx == compressed_kv_idx
+            causal_mask = q_idx >= kv_idx
+            block_diagonal = compressed_q_idx == compressed_kv_idx

-        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
+            return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)

-    block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
-    return block_mask
+        block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
+        return block_mask
+
+    return inner

 # helpers

@@ -241,7 +244,8 @@ def __init__(
     def forward(
         self,
         inp,
-        sliding_window_flex_mask = None
+        sliding_window_flex_mask = None,
+        fine_selection_flex_mask = None
     ):
         batch, seq_len, scale, heads, device = *inp.shape[:2], self.scale, self.heads, inp.device
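The notable change here is that create_fine_mask is now curried: the outer call binds the static sequence length and fine block size, while the returned inner closure builds the flex attention BlockMask only once the per-head selected block indices exist. A minimal usage sketch, with illustrative shapes and assuming a CUDA device on a PyTorch build that ships flex attention (create_block_mask above compiles the mask on the GPU):

import torch
from native_sparse_attention_pytorch.native_sparse_attention import create_fine_mask

seq_len, fine_block_size = 256, 16
batch, heads, num_selected = 1, 8, 2

# bind the static sizes once, e.g. at the top of the transformer forward
fine_mask_fn = create_fine_mask(seq_len, fine_block_size)

# per-query selected key/value blocks, shaped (batch, heads, seq_len, num_selected);
# random indices stand in for the ones an attention layer would actually compute
selected_block_indices = torch.randint(
    0, seq_len // fine_block_size,
    (batch, heads, seq_len, num_selected),
    device = 'cuda'
)

# returns a flex attention BlockMask, ready to hand to flex_attention
fine_block_mask = fine_mask_fn(selected_block_indices)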

native_sparse_attention_pytorch/transformer.py

Lines changed: 12 additions & 4 deletions
@@ -10,8 +10,9 @@

 from native_sparse_attention_pytorch.native_sparse_attention import (
     SparseAttention,
-    create_sliding_mask,
     create_compress_mask,
+    create_fine_mask,
+    create_sliding_mask,
 )

 # flex attention

@@ -121,6 +122,7 @@ def __init__(
         ff_expansion_factor = 4.,
         use_sparse_attn = False,
         use_flex_sliding_window = False,
+        use_flex_fine_selection = False,
         sparse_attn_kwargs: dict = dict(
             sliding_window_size = 32,
             compress_block_size = 4,

@@ -131,11 +133,12 @@ def __init__(
         super().__init__()
         self.token_emb = nn.Embedding(num_tokens, dim)

-        if use_flex_sliding_window:
+        if use_flex_sliding_window or use_flex_fine_selection:
             assert exists(flex_attention), 'flex attention is not available on your current version of pytorch'

         self.use_sparse_attn = use_sparse_attn
-        self.use_flex_sliding_window = use_flex_sliding_window
+        self.use_flex_sliding_window = use_sparse_attn & use_flex_sliding_window
+        self.use_flex_fine_selection = use_sparse_attn & use_flex_fine_selection

         layers = []
         for _ in range(depth):

@@ -186,11 +189,16 @@ def forward(

         attn_kwargs = dict()

-        if not disable_flex and self.use_sparse_attn and self.use_flex_sliding_window:
+        if not disable_flex and self.use_flex_sliding_window:
             attn_kwargs.update(
                 sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size)
             )

+        if not disable_flex and self.use_flex_fine_selection:
+            attn_kwargs.update(
+                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size)
+            )
+
         # layers

         for attn, ff in self.layers:
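Because create_fine_mask is curried, what gets passed down as fine_selection_flex_mask here is a callable rather than a ready-made BlockMask, presumably so each attention layer can invoke it once its own selected block indices are computed. For orientation, a hedged construction sketch; the class name Transformer and the non-sparse hyperparameters are assumptions pieced together from this diff and train.py, not a verified API:

from native_sparse_attention_pytorch.transformer import Transformer

model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    use_sparse_attn = True,
    use_flex_sliding_window = True,    # sliding window flex mask built once per forward
    use_flex_fine_selection = True,    # curried fine mask creator handed to every layer
    sparse_attn_kwargs = dict(
        sliding_window_size = 32,
        compress_block_size = 32,
    )
)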

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.24"
+version = "0.0.25"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

train.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 SEQ_LEN = 256

 USE_SPARSE_ATTN = True
+USE_FLEX_FOR_FINE_SELECTION = False # will push flex a bit, won't be efficient, as each layer needs sparsity dynamically generated, but may be enough just to compare to full attention before going all-in on triton kernels

 # experiment related

@@ -97,6 +98,7 @@ def base_decoding(
     kv_heads = 4,
     use_sparse_attn = USE_SPARSE_ATTN,
     use_flex_sliding_window = True,
+    use_flex_fine_selection = USE_FLEX_FOR_FINE_SELECTION,
     sparse_attn_kwargs = dict(
         sliding_window_size = 32,
         compress_block_size = 32,
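For intuition about what the flex fine-selection path keeps relative to the full attention baseline the comment wants to compare against, here is a small CPU-only illustration of the mask rule from fine_mask above, rewritten with plain tensor ops; the sizes and the "every query selects block 0" choice are purely illustrative:

import torch

seq_len, fine_block_size = 16, 4

q_idx = torch.arange(seq_len)[:, None]    # query positions, as a column
kv_idx = torch.arange(seq_len)[None, :]   # key/value positions, as a row

compressed_q_idx = q_idx // fine_block_size
compressed_kv_idx = kv_idx // fine_block_size

# pretend every query selected key/value block 0
is_selected = (compressed_kv_idx == 0).expand(seq_len, seq_len)

# causal attention within a query's own fine block, plus any selected earlier block
block_diagonal = (q_idx >= kv_idx) & (compressed_q_idx == compressed_kv_idx)
selected_past = (compressed_q_idx > compressed_kv_idx) & is_selected
fine = block_diagonal | selected_past

full_causal = q_idx >= kv_idx

print(fine.int())
print('kept fraction vs full causal:', fine.sum().item() / full_causal.sum().item())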
