Commit d531216

small test for customizable compress mlp

1 parent 8165138

3 files changed (+40, -6 lines)

native_sparse_attention_pytorch/native_sparse_attention.py

Lines changed: 3 additions & 3 deletions
@@ -135,15 +135,15 @@ def __init__(
         compress_dim = compress_block_size * dim_head
         compress_mlp_dim_hidden = int(compress_mlp_expand_factor * compress_dim)
 
-        mlp = nn.Sequential(
+        compress_mlp = nn.Sequential(
             Rearrange('b h w n d -> b h w (n d)'),
             nn.Linear(compress_dim, compress_mlp_dim_hidden),
             nn.SiLU(),
             nn.Linear(compress_mlp_dim_hidden, dim_head),
         )
 
-        self.k_compress = deepcopy(mlp)
-        self.v_compress = deepcopy(mlp)
+        self.k_compress = deepcopy(compress_mlp)
+        self.v_compress = deepcopy(compress_mlp)
 
         # selection related

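For context, the compress MLP renamed here flattens each block of compress_block_size token embeddings and projects the block down to a single dim_head vector. Below is a minimal, self-contained shape sketch of that default MLP, assuming compress_block_size = 4, dim_head = 64, and an expand factor of 1 (the batch, head, and window sizes are arbitrary):

# Shape sketch of the default compress MLP (assumed values:
# compress_block_size = 4, dim_head = 64, expand factor = 1)
import torch
from torch import nn
from einops.layers.torch import Rearrange

compress_block_size = 4
dim_head = 64
compress_dim = compress_block_size * dim_head    # 256
compress_mlp_dim_hidden = compress_dim           # expand factor of 1

compress_mlp = nn.Sequential(
    Rearrange('b h w n d -> b h w (n d)'),       # flatten each block of n tokens into one vector
    nn.Linear(compress_dim, compress_mlp_dim_hidden),
    nn.SiLU(),
    nn.Linear(compress_mlp_dim_hidden, dim_head) # one dim_head vector per compressed block
)

# (batch, heads, num windows, block size, head dim) -> (batch, heads, num windows, head dim)
blocks = torch.randn(2, 8, 16, compress_block_size, dim_head)
assert compress_mlp(blocks).shape == (2, 8, 16, dim_head)

The Rearrange merges the per-block token and feature dimensions so a plain Linear can mix information across the whole block before projecting back down to dim_head.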
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "native-sparse-attention-pytorch"
-version = "0.0.9"
+version = "0.0.11"
 description = "Native Sparse Attention"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_sparse_attn.py

Lines changed: 36 additions & 2 deletions
@@ -1,12 +1,15 @@
 import pytest
+
 import torch
+from torch import nn
+from einops.layers.torch import Rearrange
+
+from native_sparse_attention_pytorch import SparseAttention
 
 @pytest.mark.parametrize('use_diff_topk', (False, True))
 def test_sparse_attn(
     use_diff_topk
 ):
-    from native_sparse_attention_pytorch import SparseAttention
-
     attn = SparseAttention(
         dim = 512,
         dim_head = 64,
@@ -23,3 +26,34 @@ def test_sparse_attn(
     attended = attn(tokens)
 
     assert tokens.shape == attended.shape
+
+def test_alternative_compress_mlp():
+
+    dim_head = 64
+    compress_dim = dim_head * 4
+
+    compress_mlp = nn.Sequential(
+        Rearrange('b h w n d -> b h w (n d)'),
+        nn.Linear(compress_dim, compress_dim),
+        nn.SiLU(),
+        nn.Linear(compress_dim, compress_dim),
+        nn.SiLU(),
+        nn.Linear(compress_dim, dim_head),
+    )
+
+    attn = SparseAttention(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        sliding_window_size = 2,
+        compress_block_size = 4,
+        selection_block_size = 4,
+        num_selected_blocks = 2,
+        compress_mlp = compress_mlp
+    )
+
+    tokens = torch.randn(2, 31, 512)
+
+    attended = attn(tokens)
+
+    assert tokens.shape == attended.shape
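Judging by this test, any module with the same contract can be passed through the compress_mlp keyword: it receives blocks shaped (batch, heads, num windows, compress_block_size, dim_head) and must return (batch, heads, num windows, dim_head). Here the drop-in replacement is a deeper MLP with two SiLU hidden layers, and its input width compress_dim = dim_head * 4 lines up with compress_block_size = 4 in the SparseAttention constructor.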