
Commit ba24fbe

handle no compressed blocks edge case
1 parent 6ee39a3 commit ba24fbe

File tree

4 files changed: +74 −50 lines changed


README.md

Lines changed: 10 additions & 0 deletions
@@ -59,3 +59,13 @@ To record some of your experiments, just invoke `wandb login` first before modif
     url = {https://api.semanticscholar.org/CorpusID:276408911}
 }
 ```
+
+```bibtex
+@inproceedings{Keles2022OnTC,
+    title = {On The Computational Complexity of Self-Attention},
+    author = {Feyza Duman Keles and Pruthuvi Maheshakya Wijewardena and Chinmay Hegde},
+    booktitle = {International Conference on Algorithmic Learning Theory},
+    year = {2022},
+    url = {https://api.semanticscholar.org/CorpusID:252198880}
+}
+```

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 62 additions & 48 deletions
@@ -243,83 +243,97 @@ def forward(

         importance_scores = cattn[..., num_mem_compress_kv:]

-        topk = min(self.num_selected_blocks, importance_scores.shape[-1])
-
-        selected_importance_values, selected_block_indices = importance_scores.topk(topk, dim = -1)
-
-        if self.use_diff_topk:
-            gates = selected_importance_values + (1. - selected_importance_values).detach()
-
-        fmask = selected_importance_values > 1e-10
+        num_selected = min(self.num_selected_blocks, importance_scores.shape[-1])

         fq = rotated_q
         fk = rotated_k
         fv = v

-        if seq_len < fine_divisible_seq_len:
-            remainder = fine_divisible_seq_len - seq_len
-            fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
-            fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
-            fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
+        if num_selected > 0:
+            selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)
+
+            if self.use_diff_topk:
+                gates = selected_importance_values + (1. - selected_importance_values).detach()

-            fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
+            fmask = selected_importance_values > 1e-10

-            selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
+            if seq_len < fine_divisible_seq_len:
+                remainder = fine_divisible_seq_len - seq_len
+                fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
+                fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
+                fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)

-            if self.use_diff_topk:
-                gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+                fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)

-        # handle block causal diagonal in the diagram, but run experiments without to see
+                selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)

-        fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-        fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
-        selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                if self.use_diff_topk:
+                    gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)

-        fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+            # handle block causal diagonal in the diagram, but run experiments without to see

-        causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-        causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)
+            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = heads)
+            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

-        fmask = cat((fmask, causal_mask), dim = -2)
-        fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

-        # select out the spatial crops of keys / values for fine attention
+            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = heads)

-        fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
-        fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+            fmask = cat((fmask, causal_mask), dim = -2)
+            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')

-        # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)
+            # select out the spatial crops of keys / values for fine attention

-        fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
-        fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+            fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+            fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

-        selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])
+            # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)

-        fk = fk.gather(3, selected_block_indices)
-        fv = fv.gather(3, selected_block_indices)
+            fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+            fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])

-        # handle maybe gating
+            selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])

-        if self.use_diff_topk:
-            gates = F.pad(gates, (0, 1), value = 1.)
+            fk = fk.gather(3, selected_block_indices)
+            fv = fv.gather(3, selected_block_indices)

-            fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
-            fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+            # handle maybe gating
+
+            if self.use_diff_topk:
+                gates = F.pad(gates, (0, 1), value = 1.)

-        fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
-        fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+                fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+                fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)

-        # fine attention
+            fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+            fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+
+            # fine attention
+
+            fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+
+            fsim = fsim.masked_fill(~fmask, mask_value)
+
+            fattn = fsim.softmax(dim = -1)
+
+            fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+
+            fine_attn_out = fine_attn_out[..., :seq_len, :]
+        else:
+            # if only first block, just do a simple block causal

-        fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+            seq_len = fk.shape[-2]
+            fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()

-        fsim = fsim.masked_fill(~fmask, mask_value)
+            fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale

-        fattn = fsim.softmax(dim = -1)
+            fsim = fsim.masked_fill(~fmask, mask_value)

-        fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+            fattn = fsim.softmax(dim = -1)

-        fine_attn_out = fine_attn_out[..., :seq_len, :]
+            fine_attn_out = einsum(fattn, fv, 'b h i j, b h j d -> b h i d')

         # 3. overlapping sliding window, this is unsurprising and expected
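
The new `else` branch above handles the case where `num_selected` is zero, i.e. the sequence still fits inside the first compression block, so there are no compressed key/value blocks to rank and a top-k over importance scores would have nothing to select. Below is a minimal, self-contained sketch of what that fallback computes: plain causal attention over the few tokens present. The helper name `fallback_causal_attention` and the tensor shapes are illustrative assumptions, not part of the library.

```python
import torch
from einops import einsum

def fallback_causal_attention(fq, fk, fv, scale, mask_value = -torch.finfo(torch.float32).max):
    # fq, fk, fv: (batch, heads, seq_len, dim_head), with seq_len shorter than one selection block
    seq_len = fk.shape[-2]

    # plain lower-triangular causal mask, as in the else branch of the diff
    causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool).tril()

    sim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * scale
    sim = sim.masked_fill(~causal_mask, mask_value)

    attn = sim.softmax(dim = -1)
    return einsum(attn, fv, 'b h i j, b h j d -> b h i d')

# e.g. a single-token sequence, the case the updated test adds with seq_len = 1
q = k = v = torch.randn(2, 8, 1, 64)
out = fallback_causal_attention(q, k, v, scale = 64 ** -0.5)
assert out.shape == q.shape
```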

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.16"
+version = "0.0.17"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_sparse_attn.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from native_sparse_attention_pytorch import SparseAttention

 @pytest.mark.parametrize('use_diff_topk', (False, True))
-@pytest.mark.parametrize('seq_len', (4, 31, 32, 120))
+@pytest.mark.parametrize('seq_len', (1, 4, 31, 32, 120))
 def test_sparse_attn(
     use_diff_topk,
     seq_len
