Skip to content

Commit b194359

Browse files
committed
add a simple vit with qknorm, since authors seem to be promoting the technique on twitter
1 parent 950c901 commit b194359

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
5+
from einops import rearrange
6+
from einops.layers.torch import Rearrange
7+
8+
# helpers
9+
10+
def pair(t):
11+
return t if isinstance(t, tuple) else (t, t)
12+
13+
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
14+
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
15+
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
16+
omega = torch.arange(dim // 4) / (dim // 4 - 1)
17+
omega = 1.0 / (temperature ** omega)
18+
19+
y = y.flatten()[:, None] * omega[None, :]
20+
x = x.flatten()[:, None] * omega[None, :]
21+
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
22+
return pe.type(dtype)
23+
24+
# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper
25+
26+
# in latest tweet, seem to claim more stable training at higher learning rates
27+
# unsure if this has taken off within Brain, or it has some hidden drawback
28+
29+
class RMSNorm(nn.Module):
30+
def __init__(self, heads, dim):
31+
super().__init__()
32+
self.scale = dim ** 0.5
33+
self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale)
34+
35+
def forward(self, x):
36+
normed = F.normalize(x, dim = -1)
37+
return normed * self.scale * self.gamma
38+
39+
# classes
40+
41+
class FeedForward(nn.Module):
42+
def __init__(self, dim, hidden_dim):
43+
super().__init__()
44+
self.net = nn.Sequential(
45+
nn.LayerNorm(dim),
46+
nn.Linear(dim, hidden_dim),
47+
nn.GELU(),
48+
nn.Linear(hidden_dim, dim),
49+
)
50+
def forward(self, x):
51+
return self.net(x)
52+
53+
class Attention(nn.Module):
54+
def __init__(self, dim, heads = 8, dim_head = 64):
55+
super().__init__()
56+
inner_dim = dim_head * heads
57+
self.heads = heads
58+
self.norm = nn.LayerNorm(dim)
59+
60+
self.attend = nn.Softmax(dim = -1)
61+
62+
self.q_norm = RMSNorm(heads, dim_head)
63+
self.k_norm = RMSNorm(heads, dim_head)
64+
65+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
66+
self.to_out = nn.Linear(inner_dim, dim, bias = False)
67+
68+
def forward(self, x):
69+
x = self.norm(x)
70+
71+
qkv = self.to_qkv(x).chunk(3, dim = -1)
72+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
73+
74+
q = self.q_norm(q)
75+
k = self.k_norm(k)
76+
77+
dots = torch.matmul(q, k.transpose(-1, -2))
78+
79+
attn = self.attend(dots)
80+
81+
out = torch.matmul(attn, v)
82+
out = rearrange(out, 'b h n d -> b n (h d)')
83+
return self.to_out(out)
84+
85+
class Transformer(nn.Module):
86+
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
87+
super().__init__()
88+
self.norm = nn.LayerNorm(dim)
89+
self.layers = nn.ModuleList([])
90+
for _ in range(depth):
91+
self.layers.append(nn.ModuleList([
92+
Attention(dim, heads = heads, dim_head = dim_head),
93+
FeedForward(dim, mlp_dim)
94+
]))
95+
def forward(self, x):
96+
for attn, ff in self.layers:
97+
x = attn(x) + x
98+
x = ff(x) + x
99+
return self.norm(x)
100+
101+
class SimpleViT(nn.Module):
102+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
103+
super().__init__()
104+
image_height, image_width = pair(image_size)
105+
patch_height, patch_width = pair(patch_size)
106+
107+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
108+
109+
patch_dim = channels * patch_height * patch_width
110+
111+
self.to_patch_embedding = nn.Sequential(
112+
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
113+
nn.LayerNorm(patch_dim),
114+
nn.Linear(patch_dim, dim),
115+
nn.LayerNorm(dim),
116+
)
117+
118+
self.pos_embedding = posemb_sincos_2d(
119+
h = image_height // patch_height,
120+
w = image_width // patch_width,
121+
dim = dim,
122+
)
123+
124+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
125+
126+
self.pool = "mean"
127+
self.to_latent = nn.Identity()
128+
129+
self.linear_head = nn.LayerNorm(dim)
130+
131+
def forward(self, img):
132+
device = img.device
133+
134+
x = self.to_patch_embedding(img)
135+
x += self.pos_embedding.to(device, dtype=x.dtype)
136+
137+
x = self.transformer(x)
138+
x = x.mean(dim = 1)
139+
140+
x = self.to_latent(x)
141+
return self.linear_head(x)

0 commit comments

Comments
 (0)