Commit 115279f (parent: f1b9c2d)

leverage einmix for a kv compress mlp with separate parameters per head
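
Background on the pattern (not part of the commit): EinMix with a head axis in its weight_shape learns a separate projection matrix and bias per head, instead of a single nn.Linear shared across heads. An illustrative standalone sketch, with made-up sizes:

# illustrative sketch, not from the commit: a leading head axis in weight_shape
# gives every head its own (dim_in -> dim_out) projection
import torch
from einops.layers.torch import EinMix as Mix

heads, dim_in, dim_out = 8, 256, 128  # example sizes only

per_head_proj = Mix(
    'b h w i -> b h w o',
    weight_shape = 'h i o',   # one (i, o) weight matrix per head
    bias_shape = 'h o',       # one bias vector per head
    h = heads, i = dim_in, o = dim_out
)

x = torch.randn(2, heads, 16, dim_in)
out = per_head_proj(x)
assert out.shape == (2, heads, 16, dim_out)

This is what lets the GroupedMLP added below keep per-head compression parameters without looping over heads in Python.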

4 files changed (+113 lines, -54 lines)

native_sparse_attention_pytorch/compress_networks.py (32 additions, 0 deletions)

@@ -3,6 +3,7 @@
 from torch.nn import Module, ModuleList
 
 from einops import einsum, rearrange
+from einops.layers.torch import EinMix as Mix
 
 # helpers
 
@@ -66,3 +67,34 @@ def forward(
 
         return compressed
 
+# mlp per head
+
+class GroupedMLP(Module):
+    def __init__(
+        self,
+        dim_head,
+        compress_block_size,
+        heads,
+        expand_factor = 1.,
+    ):
+        super().__init__()
+
+        dim = dim_head * compress_block_size
+        dim_hidden = int(dim * expand_factor)
+        dim_out = dim_head
+
+        self.net = nn.Sequential(
+            Mix('b h w i -> b h w o', weight_shape = 'h i o', bias_shape = 'h o', h = heads, i = dim, o = dim_hidden),
+            nn.ReLU(),
+            Mix('b h w i -> b h w o', weight_shape = 'h i o', bias_shape = 'h o', h = heads, i = dim_hidden, o = dim_out),
+        )
+
+    def forward(
+        self,
+        kv
+    ):
+        kv = rearrange(kv, 'b h w n d -> b h w (n d)')
+
+        compressed = self.net(kv)
+
+        return compressed
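
For reference, a usage sketch of the new module (not from the commit). The kv shape of (batch, heads, number of compressed blocks, block size, dim per head) is assumed from the 'b h w n d' rearrange pattern above, and the sizes are illustrative:

# usage sketch under assumed shapes; each compressed block of key/value tokens
# is flattened and projected down to a single dim_head vector, per head
import torch
from native_sparse_attention_pytorch.compress_networks import GroupedMLP

mlp = GroupedMLP(dim_head = 64, compress_block_size = 4, heads = 8)

kv = torch.randn(2, 8, 16, 4, 64)          # 16 blocks of 4 tokens per head
compressed = mlp(kv)

assert compressed.shape == (2, 8, 16, 64)  # one compressed token per block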

pyproject.toml (2 additions, 2 deletions)

@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.27"
+version = "0.0.28"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -24,7 +24,7 @@ classifiers=[
 
 dependencies = [
     "einx>=0.3.0",
-    "einops>=0.8.0",
+    "einops>=0.8.1",
     "local-attention>=1.11.1",
     "rotary-embedding-torch",
     "torch>=2.5",

tests/test_custom_compress_mlp.py (79 additions, 0 deletions)

@@ -0,0 +1,79 @@
+import pytest
+
+import torch
+from torch import nn
+from einops.layers.torch import Rearrange
+
+from native_sparse_attention_pytorch import SparseAttention
+
+def test_alternative_compress_mlp():
+
+    dim_head = 64
+    compress_dim = dim_head * 4
+
+    compress_mlp = nn.Sequential(
+        Rearrange('b h w n d -> b h w (n d)'),
+        nn.Linear(compress_dim, compress_dim),
+        nn.SiLU(),
+        nn.Linear(compress_dim, compress_dim),
+        nn.SiLU(),
+        nn.Linear(compress_dim, dim_head),
+    )
+
+    attn = SparseAttention(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        sliding_window_size = 2,
+        compress_block_size = 4,
+        selection_block_size = 4,
+        num_selected_blocks = 2,
+        compress_mlp = compress_mlp
+    )
+
+    tokens = torch.randn(2, 31, 512)
+
+    attended = attn(tokens)
+
+    assert tokens.shape == attended.shape
+
+
+def test_compress_networks():
+    from native_sparse_attention_pytorch.compress_networks import AttentionPool
+
+    attn = SparseAttention(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        sliding_window_size = 2,
+        compress_block_size = 4,
+        selection_block_size = 4,
+        num_selected_blocks = 2,
+        compress_mlp = AttentionPool(64, 4)
+    )
+
+    tokens = torch.randn(2, 31, 512)
+
+    attended = attn(tokens)
+
+    assert tokens.shape == attended.shape
+
+def test_group_mlp():
+    from native_sparse_attention_pytorch.compress_networks import GroupedMLP
+
+    attn = SparseAttention(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        sliding_window_size = 2,
+        compress_block_size = 4,
+        selection_block_size = 4,
+        num_selected_blocks = 2,
+        compress_mlp = GroupedMLP(64, 4, 8)
+    )
+
+    tokens = torch.randn(2, 31, 512)
+
+    attended = attn(tokens)
+
+    assert tokens.shape == attended.shape

tests/test_sparse_attn.py (0 additions, 52 deletions)

@@ -33,55 +33,3 @@ def test_sparse_attn(
     attended = attn(tokens)
 
     assert tokens.shape == attended.shape
-
-def test_alternative_compress_mlp():
-
-    dim_head = 64
-    compress_dim = dim_head * 4
-
-    compress_mlp = nn.Sequential(
-        Rearrange('b h w n d -> b h w (n d)'),
-        nn.Linear(compress_dim, compress_dim),
-        nn.SiLU(),
-        nn.Linear(compress_dim, compress_dim),
-        nn.SiLU(),
-        nn.Linear(compress_dim, dim_head),
-    )
-
-    attn = SparseAttention(
-        dim = 512,
-        dim_head = 64,
-        heads = 8,
-        sliding_window_size = 2,
-        compress_block_size = 4,
-        selection_block_size = 4,
-        num_selected_blocks = 2,
-        compress_mlp = compress_mlp
-    )
-
-    tokens = torch.randn(2, 31, 512)
-
-    attended = attn(tokens)
-
-    assert tokens.shape == attended.shape
-
-
-def test_compress_networks():
-    from native_sparse_attention_pytorch.compress_networks import AttentionPool
-
-    attn = SparseAttention(
-        dim = 512,
-        dim_head = 64,
-        heads = 8,
-        sliding_window_size = 2,
-        compress_block_size = 4,
-        selection_block_size = 4,
-        num_selected_blocks = 2,
-        compress_mlp = AttentionPool(64, 4)
-    )
-
-    tokens = torch.randn(2, 31, 512)
-
-    attended = attn(tokens)
-
-    assert tokens.shape == attended.shape
