
Commit 6aa5fd8 (parent 14beb73)

deviate from the paper and allow for interpolation of the compressed scores for better selected fine blocks, when compress block size > fine block size

3 files changed: +31 -5 lines

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 28 additions & 3 deletions
@@ -15,7 +15,7 @@
 # einstein notation
 
 import einx
-from einops import einsum, repeat, rearrange, reduce
+from einops import einsum, repeat, rearrange, reduce, pack, unpack
 from einops.layers.torch import Rearrange
 
 # b - batch
@@ -109,13 +109,27 @@ def round_up_mult(n, mult):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def pack_one_with_inverse(t, pattern):
+    packed, ps = pack([t], pattern)
+    def inverse(out):
+        return unpack(out, ps, pattern)[0]
+
+    return packed, inverse
+
 # tensor helpers
 
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def interpolate_1d(x, length, mode = 'bilinear'):
+    x, inverse_pack = pack_one_with_inverse(x, '* n')
+    x = rearrange(x, 'b n -> b 1 n 1')
+    x = F.interpolate(x, (length, 1), mode = mode)
+    x = rearrange(x, 'b 1 n 1 -> b n')
+    return inverse_pack(x)
+
 def straight_through(t, target):
     return t + (target - t).detach()
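Since the two new helpers above are self-contained, their behavior is easy to sanity check in isolation. The sketch below is not part of the commit: it copies the helpers verbatim (with explanatory comments added) and confirms that interpolate_1d stretches only the last dimension, with pack_one_with_inverse preserving any leading batch or head dimensions; the tensor shape and target length are arbitrary choices for illustration.

import torch
import torch.nn.functional as F
from einops import rearrange, pack, unpack

def pack_one_with_inverse(t, pattern):
    packed, ps = pack([t], pattern)
    def inverse(out):
        return unpack(out, ps, pattern)[0]

    return packed, inverse

def interpolate_1d(x, length, mode = 'bilinear'):
    x, inverse_pack = pack_one_with_inverse(x, '* n')  # flatten all leading dims into one batch dim
    x = rearrange(x, 'b n -> b 1 n 1')                 # bilinear F.interpolate expects (b, c, h, w)
    x = F.interpolate(x, (length, 1), mode = mode)     # resample the length dimension to `length`
    x = rearrange(x, 'b 1 n 1 -> b n')
    return inverse_pack(x)                             # restore the original leading dims

scores = torch.randn(2, 8, 6)            # (batch, heads, num compressed blocks) - made-up shape
print(interpolate_1d(scores, 24).shape)  # torch.Size([2, 8, 24])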

@@ -135,6 +149,7 @@ def __init__(
         num_compressed_mem_kv = 4,
         norm = True,
         use_diff_topk = False,
+        interpolated_importance_score = False,
         compress_mlp: Module | None = None,
         compress_mlp_expand_factor = 1.,
         strategy_combine_mlp: Module | None = None
@@ -216,6 +231,8 @@ def __init__(
 
         self.use_diff_topk = use_diff_topk
 
+        self.interpolated_importance_score = interpolated_importance_score # in the case fine block size < compressed block size, will weigh space better when selecting
+
         self.selection_block_size = selection_block_size
 
         assert num_selected_blocks > 0
@@ -326,10 +343,18 @@ def forward(
         # first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right with 0s, which should be fine as the sliding window covers the local tokens anyways
 
         if self.compress_block_size != self.selection_block_size:
-            importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
-            padding = fine_divisible_seq_len - importance_scores.shape[-1]
 
+            score_len = importance_scores.shape[-1]
+            compress_seq_len = score_len * self.compress_block_size
+
+            if self.interpolated_importance_score:
+                importance_scores = interpolate_1d(importance_scores, compress_seq_len)
+            else:
+                importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+
+            padding = fine_divisible_seq_len - compress_seq_len
             importance_scores = F.pad(importance_scores, (0, padding))
+
             importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)
 
         # handle if number of total blocks is less than number to select for fine attention
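To see what the new branch buys when the compressed block is coarser than the selection block, here is a toy comparison, not from the repository, of the two paths above; it assumes a compress_block_size of 8 and a selection_block_size of 4. The repeat path gives every selection block inside a compressed block an identical score, so ties between them are arbitrary, while the interpolated path produces a ramp between neighboring compressed blocks, so selection blocks nearer a high-scoring neighbor rank higher.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat, reduce

compress_block_size = 8      # assumed for illustration
selection_block_size = 4     # assumed for illustration

# two compressed blocks: the first scored low, the second high
importance_scores = torch.tensor([[0., 1.]])

# previous behavior: nearest-neighbor style expansion via repeat
repeated = repeat(importance_scores, '... j -> ... (j b)', b = compress_block_size)

# new optional behavior: same reshape as interpolate_1d, then bilinear resampling
interpolated = F.interpolate(
    rearrange(importance_scores, 'b n -> b 1 n 1'),
    (importance_scores.shape[-1] * compress_block_size, 1),
    mode = 'bilinear'
)
interpolated = rearrange(interpolated, 'b 1 n 1 -> b n')

# both paths are then mean-pooled per selection block before the top-k block selection
print(reduce(repeated, '... (j b) -> ... j', 'mean', b = selection_block_size))      # tensor([[0., 0., 1., 1.]])
print(reduce(interpolated, '... (j b) -> ... j', 'mean', b = selection_block_size))  # tensor([[0.0000, 0.2500, 0.7500, 1.0000]])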

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.34"
+version = "0.0.35"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

train.py

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ def base_decoding(
         compress_block_size = 32,
         selection_block_size = 32,
         num_selected_blocks = 2,
-        use_diff_topk = False
+        use_diff_topk = False,
+        interpolated_importance_score = True
     )
 ).cuda()
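As configured in train.py, both block sizes are 32, so the interpolation branch in forward() is skipped; per the diff above, it only runs when compress_block_size differs from selection_block_size. For a case where the flag actually matters, below is a hypothetical standalone usage sketch, assuming the SparseAttention class exported by native_sparse_attention_pytorch and its usual constructor kwargs; all values are illustrative, not taken from the commit.

import torch
from native_sparse_attention_pytorch import SparseAttention

# hypothetical configuration where compression is coarser than selection,
# which is the case the interpolated importance scores are meant to improve
attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 32,
    compress_block_size = 64,
    selection_block_size = 32,
    num_selected_blocks = 2,
    interpolated_importance_score = True   # new flag introduced by this commit
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)                    # same shape as the input tokens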
