Commit 8208c85

just remove PreNorm wrapper from all ViTs, as it is unlikely to change at this point
1 parent 4264efd commit 8208c85

21 files changed (+137, -232 lines)
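
In effect, the commit folds the pre-normalization into each block: instead of building layers as PreNorm(dim, fn), every Attention and FeedForward now begins with its own nn.LayerNorm. A minimal sketch of the before/after pattern follows; the final Linear/Dropout tail of FeedForward is filled in from the repo's usual definition and is not part of the hunks below.

import torch
from torch import nn

# before: a generic wrapper normalized the input before calling the wrapped module
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# after: the block owns its norm, so layers are constructed as FeedForward(dim, ...) directly
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),            # the norm PreNorm used to apply
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),   # tail assumed from the repo's standard FeedForward
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# freshly initialized LayerNorms are identical (weight = 1, bias = 0), so wrapping with
# PreNorm and normalizing inside the block agree at initialization
x = torch.randn(2, 16, 64)
assert torch.allclose(PreNorm(64, nn.Identity())(x), nn.LayerNorm(64)(x))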

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.4.2',
+  version = '1.4.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/ats_vit.py

Lines changed: 5 additions & 10 deletions
@@ -110,18 +110,11 @@ def forward(self, attn, value, mask):
 
 # classes
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -138,6 +131,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -154,6 +148,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
     def forward(self, x, *, mask):
         num_tokens = x.shape[1]
 
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
@@ -189,8 +184,8 @@ def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, d
         self.layers = nn.ModuleList([])
         for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
     def forward(self, x):

vit_pytorch/cait.py

Lines changed: 5 additions & 10 deletions
@@ -44,18 +44,11 @@ def __init__(self, dim, fn, depth):
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) * self.scale
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -72,6 +65,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
@@ -89,6 +83,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
     def forward(self, x, context = None):
         b, n, _, h = *x.shape, self.heads
 
+        x = self.norm(x)
         context = x if not exists(context) else torch.cat((x, context), dim = 1)
 
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
@@ -115,8 +110,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dro
 
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
-                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
-                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
+                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
+                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
             ]))
     def forward(self, x, context = None):
         layers = dropout_layers(self.layers, dropout = self.layer_dropout)

vit_pytorch/cross_vit.py

Lines changed: 7 additions & 14 deletions
@@ -13,22 +13,13 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-# pre-layernorm
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 # feedforward
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -47,6 +38,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -60,6 +52,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
 
     def forward(self, x, context = None, kv_include_self = False):
         b, n, _, h = *x.shape, self.heads
+        x = self.norm(x)
         context = default(context, x)
 
         if kv_include_self:
@@ -86,8 +79,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.norm = nn.LayerNorm(dim)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
     def forward(self, x):
@@ -121,8 +114,8 @@ def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
+                ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
             ]))
 
     def forward(self, sm_tokens, lg_tokens):

vit_pytorch/cvt.py

Lines changed: 6 additions & 11 deletions
@@ -34,19 +34,11 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        x = self.norm(x)
-        return self.fn(x, **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            LayerNorm(dim),
             nn.Conv2d(dim, dim * mult, 1),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -75,6 +67,7 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -89,6 +82,8 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
     def forward(self, x):
         shape = x.shape
         b, n, _, y, h = *shape, self.heads
+
+        x = self.norm(x)
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
         q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
 
@@ -107,8 +102,8 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
+                Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_mult, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:

vit_pytorch/deepvit.py

Lines changed: 8 additions & 19 deletions
@@ -5,25 +5,11 @@
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(x, **kwargs) + x
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -40,6 +26,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.dropout = nn.Dropout(dropout)
@@ -59,6 +46,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
 
     def forward(self, x):
         b, n, _, h = *x.shape, self.heads
+        x = self.norm(x)
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -86,13 +75,13 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
     def forward(self, x):
         for attn, ff in self.layers:
-            x = attn(x)
-            x = ff(x)
+            x = attn(x) + x
+            x = ff(x) + x
         return x
 
 class DeepViT(nn.Module):
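
With the Residual wrapper gone from deepvit.py, the Transformer applies the residual connections inline (x = attn(x) + x). A quick smoke test that the refactored module still runs end to end; the constructor arguments below follow the package README and are not part of this diff:

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)                     # (1, 1000), same output shape as before the refactor
assert preds.shape == (1, 1000)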

vit_pytorch/local_vit.py

Lines changed: 6 additions & 12 deletions
@@ -26,16 +26,6 @@ def forward(self, x, **kwargs):
         x = self.fn(x, **kwargs)
         return torch.cat((cls_token, x), dim = 1)
 
-# prenorm
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 # feed forward related classes
 
 class DepthWiseConv2d(nn.Module):
@@ -52,6 +42,7 @@ class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Conv2d(dim, hidden_dim, 1),
             nn.Hardswish(),
             DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
@@ -77,6 +68,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -88,6 +80,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
 
     def forward(self, x):
         b, n, _, h = *x.shape, self.heads
+
+        x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -106,8 +100,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
+                Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
             ]))
     def forward(self, x):
         for attn, ff in self.layers:

vit_pytorch/max_vit.py

Lines changed: 10 additions & 7 deletions
@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):
 
 # helper classes
 
-class PreNormResidual(nn.Module):
+class Residual(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
         self.fn = fn
 
     def forward(self, x):
-        return self.fn(self.norm(x)) + x
+        return self.fn(x) + x
 
 class FeedForward(nn.Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, inner_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -132,6 +132,7 @@ def __init__(
         self.heads = dim // dim_head
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
 
         self.attend = nn.Sequential(
@@ -160,6 +161,8 @@ def __init__(
     def forward(self, x):
         batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
 
+        x = self.norm(x)
+
         # flatten
 
         x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
@@ -259,13 +262,13 @@ def __init__(
                         shrinkage_rate = mbconv_shrinkage_rate
                     ),
                     Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
-                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
-                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
+                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
 
                     Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
-                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
-                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
+                    Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                 )
 
