
Commit 6ee9078

wire up flex fine selected attention and make sure it runs
1 parent 115279f commit 6ee9078


4 files changed: +57 −47 lines changed


native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 54 additions & 44 deletions
@@ -79,13 +79,12 @@ def fine_mask(b_idx, h_idx, q_idx, kv_idx):
         compressed_q_idx = q_idx // fine_block_size
         compressed_kv_idx = kv_idx // fine_block_size

-        block_causal_mask = compressed_q_idx > compressed_kv_idx
         is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]

         causal_mask = q_idx >= kv_idx
         block_diagonal = compressed_q_idx == compressed_kv_idx

-        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
+        return (causal_mask & (block_diagonal | is_selected))

     block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
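
The simplification above is behavior-preserving: when a selected block lies strictly in the past, block-level causality already implies token-level causality, so the separate block_causal_mask term is redundant. A minimal standalone check of that equivalence (toy sizes, plain tensors, not the library's code):

# brute-force the old and new fine-mask predicates over a toy sequence
import torch

fine_block_size = 4
seq_len = 16
num_blocks = seq_len // fine_block_size

q_idx = torch.arange(seq_len).view(-1, 1)
kv_idx = torch.arange(seq_len).view(1, -1)

compressed_q_idx = q_idx // fine_block_size
compressed_kv_idx = kv_idx // fine_block_size

# random per-query block selection, standing in for one_hot_selected_block_indices
selected = torch.rand(seq_len, num_blocks) > 0.5
is_selected = selected[torch.arange(seq_len).view(-1, 1), compressed_kv_idx]

causal_mask = q_idx >= kv_idx
block_diagonal = compressed_q_idx == compressed_kv_idx
block_causal_mask = compressed_q_idx > compressed_kv_idx

old = (causal_mask & block_diagonal) | (block_causal_mask & is_selected)
new = causal_mask & (block_diagonal | is_selected)

assert torch.equal(old, new)
print('old and new fine mask predicates agree')
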
@@ -344,76 +343,87 @@ def forward(
             selected_importance_values, selected_block_indices = importance_scores.topk(num_selected, dim = -1)

             if self.use_diff_topk:
+                assert not exists(fine_selection_flex_mask)
                 gates = straight_through(selected_importance_values, 1.)

-            fmask = selected_importance_values > 1e-10
+            if exists(fine_selection_flex_mask):
+                # flex attention for the selection for fine attention

-            if seq_len < fine_divisible_seq_len:
-                remainder = fine_divisible_seq_len - seq_len
-                fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
-                fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
-                fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)
+                fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))

-                fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)
+                fine_block_mask = fine_selection_flex_mask(selected_block_indices)

-                selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)
+                fine_attn_out = flex_attention(fq, fk, fv, block_mask = fine_block_mask)

-                if self.use_diff_topk:
-                    gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)
+            else:
+                fmask = selected_importance_values > 1e-10

-            # handle block causal diagonal in the diagram, but run experiments without to see
+                if seq_len < fine_divisible_seq_len:
+                    remainder = fine_divisible_seq_len - seq_len
+                    fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
+                    fv = pad_at_dim(fv, (0, remainder), value = 0., dim = -2)
+                    fq = pad_at_dim(fq, (0, remainder), value = 0., dim = -2)

-            fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
-            fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.kv_heads)
-            selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+                    fmask = pad_at_dim(fmask, (0, remainder), value = False, dim = -2)

-            fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+                    selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)

-            causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
-            causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.kv_heads)
+                    if self.use_diff_topk:
+                        gates = pad_at_dim(gates, (0, remainder), value = 1., dim = -2)

-            fmask = cat((fmask, causal_mask), dim = -2)
-            fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+                # handle block causal diagonal in the diagram, but run experiments without to see

-            # select out the spatial crops of keys / values for fine attention
+                fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+                fine_window_seq = repeat(fine_window_seq, 'n -> b h n 1', b = batch, h = self.kv_heads)
+                selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2

-            fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
-            fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+                fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)

-            # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)
+                causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+                causal_mask = repeat(causal_mask, 'i j -> b h (w i) 1 j', w = num_fine_blocks, b = batch, h = self.kv_heads)

-            fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
-            fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+                fmask = cat((fmask, causal_mask), dim = -2)
+                fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')

-            selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])
+                # select out the spatial crops of keys / values for fine attention

-            fk = fk.gather(3, selected_block_indices)
-            fv = fv.gather(3, selected_block_indices)
+                fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
+                fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)

-            # handle maybe gating
+                # get_at("b h [w] j d, b h i selected -> b h i selected j d", fkv, selected_block_indices)

-            if self.use_diff_topk:
-                gates = F.pad(gates, (0, 1), value = 1.)
+                fk = repeat(fk, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])
+                fv = repeat(fv, 'b h w j d -> b h i w j d', i = selected_block_indices.shape[2])

-                fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
-                fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)
+                selected_block_indices = repeat(selected_block_indices, 'b h i sel -> b h i sel j d', j = fk.shape[-2], d = fk.shape[-1])

-            fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
-            fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')
+                fk = fk.gather(3, selected_block_indices)
+                fv = fv.gather(3, selected_block_indices)

-            # fine attention
+                # handle maybe gating

-            fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+                if self.use_diff_topk:
+                    gates = F.pad(gates, (0, 1), value = 1.)

-            fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+                    fk = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fk)
+                    fv = einx.multiply('b h i w, b h i w j d -> b h i w j d', gates, fv)

-            fsim = fsim.masked_fill(~fmask, mask_value)
+                fk = rearrange(fk, 'b h i w j d -> b h i (w j) d')
+                fv = rearrange(fv, 'b h i w j d -> b h i (w j) d')

-            fattn = fsim.softmax(dim = -1)
+                # fine attention
+
+                fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+
+                fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale
+
+                fsim = fsim.masked_fill(~fmask, mask_value)
+
+                fattn = fsim.softmax(dim = -1)

-            fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
+                fine_attn_out = einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')

-            fine_attn_out = fine_attn_out[..., :seq_len, :]
+                fine_attn_out = fine_attn_out[..., :seq_len, :]
         else:
             # if only first block, just do a simple block causal
native_sparse_attention_pytorch/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ def forward(
             )

         if not disable_flex and self.use_flex_fine_selection:
-            attn_kwargs.udpate(
+            attn_kwargs.update(
                 fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size)
             )


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.28"
+version = "0.0.29"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

train.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 SEQ_LEN = 256

 USE_SPARSE_ATTN = True
-USE_FLEX_FOR_FINE_SELECTION = False # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels
+USE_FLEX_FOR_FINE_SELECTION = True # will push flex a bit, won't be efficient as each layer needs sparsity dynmically generated, but may be enough just to compare to full attention before going all-in on triton kernels

 # experiment related
