Commit 582b844

redo with approach of using compressed similarities for interpolation, before remasking and normalizing for importance scores
1 parent 16d9bbc

3 files changed, +36 −28 lines


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 34 additions & 26 deletions
@@ -113,6 +113,9 @@ def round_up_mult(n, mult):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
 def pack_one_with_inverse(t, pattern):
     packed, ps = pack([t], pattern)
     def inverse(out):
@@ -142,7 +145,7 @@ def straight_through(t, target):
 def attend(
     q, k, v,
     mask = None,
-    return_attn = False,
+    return_sim = False,
     scale = None
 ):
     scale = default(scale, q.shape[-1] ** -0.5)
@@ -154,7 +157,7 @@ def attend(
 
     sim = einsum(q, k, 'b h qh i d, b h j d -> b h qh i j') * scale
 
-    mask_value = -torch.finfo(sim.dtype).max
+    mask_value = max_neg_value(sim)
 
     if exists(mask):
         sim = sim.masked_fill(~mask, mask_value)
@@ -165,12 +168,12 @@ def attend(
 
     attn_out = rearrange(attn_out, 'b h qh ... -> b (h qh) ...')
 
-    if not return_attn:
+    if not return_sim:
         return attn_out
 
-    attn = rearrange(attn, 'b h qh ... -> b (h qh) ...')
+    sim = rearrange(sim, 'b h qh ... -> b (h qh) ...')
 
-    return attn_out, attn
+    return attn_out, sim
 
 # classes
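For context on the rename in the hunks above: `attend` now optionally hands back the pre-softmax similarities rather than the post-softmax attention weights, so the caller can re-mask and re-normalize them later. A minimal, self-contained sketch of that pattern (toy shapes, no grouped-query-head dimension, not the library's exact implementation):

# sketch of the return_attn -> return_sim change: return raw logits, not attention weights
import torch
from torch import einsum

def max_neg_value(t):
    return -torch.finfo(t.dtype).max

def attend_sketch(q, k, v, mask = None, return_sim = False):
    scale = q.shape[-1] ** -0.5
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if mask is not None:
        sim = sim.masked_fill(~mask, max_neg_value(sim))

    attn = sim.softmax(dim = -1)
    out = einsum('b h i j, b h j d -> b h i d', attn, v)

    if not return_sim:
        return out

    # pre-softmax similarities are returned so the caller can re-mask and
    # re-normalize them downstream - the point of this commit
    return out, sim

q = k = v = torch.randn(1, 2, 8, 16)
out, sim = attend_sketch(q, k, v, return_sim = True)
print(out.shape, sim.shape)  # torch.Size([1, 2, 8, 16]) torch.Size([1, 2, 8, 8])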

@@ -360,17 +363,17 @@ def forward(
 
         cmask = einx.less('j, i -> i j', ck_seq, cq_seq)
 
-        compressed_attn_out, cattn = attend(cq, ck, cv, mask = cmask, return_attn = True)
+        compressed_attn_out, csim = attend(cq, ck, cv, mask = cmask, return_sim = True)
 
         # for 2. and 3., will give them relative positions with rotary - compressed needs to be handled separately (even if they already have intra block absolute positions)
 
         rotated_q, rotated_k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
 
         # 2. fine attention over selected based on compressed attention logits - variables prepended with `f` stands for the fine attention pathway
 
-        importance_scores = cattn[..., num_mem_compress_kv:]
+        importance_scores = csim[..., num_mem_compress_kv:]
 
-        num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])
+        num_selected = min(self.num_selected_blocks, num_compress_blocks)
         has_selected_kv_for_fine_attn = num_selected > 0
 
         # maybe average the compressed attention across each grouped queries (per key / values)
@@ -386,32 +389,37 @@ def forward(
         # cannot parse their equation, so will just improvise
         # first we expand all the compressed scores to the full sequence length, then average within each fine / selection block size - pad on the right to 0s, which should be fine as sliding window convers the local anyways
 
-        if has_selected_kv_for_fine_attn and self.compress_block_size != self.selection_block_size:
+        if has_selected_kv_for_fine_attn:
 
-            score_len = importance_scores.shape[-1]
-            compress_seq_len = score_len * self.compress_block_size
+            if self.compress_block_size != self.selection_block_size:
 
-            if self.interpolated_importance_score:
-                importance_scores = interpolate_1d(importance_scores, compress_seq_len)
-            else:
-                importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+                compress_seq_len = num_compress_blocks * self.compress_block_size
+
+                if self.interpolated_importance_score:
+                    importance_scores = interpolate_1d(importance_scores, compress_seq_len)
+                else:
+                    importance_scores = repeat(importance_scores, '... j -> ... (j block_size)', block_size = self.compress_block_size)
+
+                padding = fine_divisible_seq_len - compress_seq_len
 
-            padding = fine_divisible_seq_len - compress_seq_len
+                fine_query_seq_len = importance_scores.shape[-2]
+                fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]
 
-            fine_query_seq_len = importance_scores.shape[-2]
-            fine_query_padding = fine_divisible_seq_len - importance_scores.shape[-2]
+                importance_scores = F.pad(importance_scores, (0, padding))
 
-            importance_scores = F.pad(importance_scores, (0, padding))
+                # mask out the diagonal since block causal is included by default for fine attending
 
-            # mask out the diagonal since block causal is included by default for fine attending
+                block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
+                block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
+                block_causal_mask = block_causal_mask[:fine_query_seq_len]
 
-            block_causal_mask = torch.ones((num_fine_blocks,) * 2, device = device, dtype = torch.bool).tril(-1)
-            block_causal_mask = repeat(block_causal_mask, 'i j -> (i n1) (j n2)', n1 = self.selection_block_size, n2 = self.selection_block_size)
-            block_causal_mask = block_causal_mask[:fine_query_seq_len]
+                importance_scores = importance_scores.masked_fill(~block_causal_mask, max_neg_value(csim))
 
-            importance_scores = importance_scores.masked_fill(~block_causal_mask, 0.)
+                importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)
 
-            importance_scores = reduce(importance_scores, '... (j block_size) -> ... j', 'mean', block_size = self.selection_block_size)
+            importance_scores = F.pad(importance_scores, (1, 0), value = -1e3)
+            importance_scores = importance_scores.softmax(dim = -1)
+            importance_scores = importance_scores[..., 1:]
 
         # handle if number of total blocks is less than number to select for fine attention
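The comments in this hunk describe the intent; as an illustration, here is a small self-contained sketch of the reworked pipeline on toy tensors: expand the compressed similarities to token resolution, re-mask the block-causal diagonal with a max-negative value instead of zero, mean-pool per selection block, then pad, softmax and slice to get normalized importance scores. Names and sizes are chosen for the example, not taken from the library; batch/head dimensions and right-padding to a fine-divisible length are omitted.

# sketch: compressed similarities -> normalized per-selection-block importance scores
import torch
import torch.nn.functional as F
from einops import repeat, reduce

compress_block_size  = 4
selection_block_size = 2
num_compress_blocks  = 3
seq_len              = num_compress_blocks * compress_block_size    # 12
num_fine_blocks      = seq_len // selection_block_size              # 6

# pre-softmax similarities from the compressed pathway: (query, compressed block)
csim = torch.randn(seq_len, num_compress_blocks)

# 1. expand each compressed score across the tokens of its block
scores = repeat(csim, 'i j -> i (j b)', b = compress_block_size)     # (12, 12)

# 2. re-mask: fine attention already covers the block-causal diagonal,
#    so fill those positions with a max-negative value before pooling (previously 0.)
block_causal = torch.ones((num_fine_blocks,) * 2, dtype = torch.bool).tril(-1)
block_causal = repeat(block_causal, 'i j -> (i n1) (j n2)', n1 = selection_block_size, n2 = selection_block_size)
scores = scores.masked_fill(~block_causal, -torch.finfo(scores.dtype).max)

# 3. pool back down to one score per selection block
scores = reduce(scores, 'i (j b) -> i j', 'mean', b = selection_block_size)

# 4. normalize: pad a throwaway slot so fully masked rows stay finite,
#    softmax over blocks, then drop the pad slot
scores = F.pad(scores, (1, 0), value = -1e3)
importance = scores.softmax(dim = -1)[..., 1:]

# top blocks per query, as would feed the fine attention selection
sel = importance.topk(1, dim = -1).indices
print(importance.shape, sel.shape)   # torch.Size([12, 6]) torch.Size([12, 1])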

@@ -496,7 +504,7 @@ def forward(
 
             fsim = einsum(fq, fk, 'b h qh i d, b h i j d -> b h qh i j') * self.scale
 
-            mask_value = -torch.finfo(fsim.dtype).max
+            mask_value = max_neg_value(fsim)
 
             fsim = fsim.masked_fill(~fmask, mask_value)
 
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.45"
+version = "0.0.47"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

train.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 COMPRESS_BLOCK_SIZE = 64
 
 FINE_BLOCK_SIZE = 32
-NUM_FINE_SELECTED = 0
+NUM_FINE_SELECTED = 1
 
 INTERPOLATED_IMPORTANCE_SCORE = False
 USE_DIFF_TOPK = True
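For orientation, a hedged sketch of how these training-script constants presumably wire into the attention module. The parameter names below are inferred from the attributes referenced in the diff (`self.compress_block_size`, `self.selection_block_size`, `self.num_selected_blocks`, `self.interpolated_importance_score`) and from the config flags above; treat the exact constructor signature and the `dim` / `dim_head` / `heads` / `sliding_window_size` arguments as assumptions rather than the repository's verbatim API.

# hypothetical wiring of the train.py constants into SparseAttention
# (signature assumed from attribute names in the diff; values are illustrative)
from native_sparse_attention_pytorch import SparseAttention

COMPRESS_BLOCK_SIZE = 64
FINE_BLOCK_SIZE = 32
NUM_FINE_SELECTED = 1            # was 0 before this commit
INTERPOLATED_IMPORTANCE_SCORE = False
USE_DIFF_TOPK = True

attn = SparseAttention(
    dim = 512,                                                       # assumed
    dim_head = 64,                                                   # assumed
    heads = 8,                                                       # assumed
    sliding_window_size = 32,                                        # assumed
    compress_block_size = COMPRESS_BLOCK_SIZE,
    selection_block_size = FINE_BLOCK_SIZE,
    num_selected_blocks = NUM_FINE_SELECTED,
    interpolated_importance_score = INTERPOLATED_IMPORTANCE_SCORE,
    use_diff_topk = USE_DIFF_TOPK,
)

The bump of NUM_FINE_SELECTED from 0 to 1 matters because, per the forward pass above, `num_selected = min(self.num_selected_blocks, num_compress_blocks)` and the reworked importance-score path only runs when `num_selected > 0`.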
