Skip to content

Commit d446a41

Browse files
committed
share an idea that should be tried if it has not been
1 parent 0ad09c4 commit d446a41

File tree

2 files changed

+163
-1
lines changed

2 files changed

+163
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '1.6.3',
6+
version = '1.6.4',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
long_description_content_type = 'text/markdown',

vit_pytorch/simple_vit_with_fft.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import torch
2+
from torch.fft import fft
3+
from torch import nn
4+
5+
from einops import rearrange, reduce, pack, unpack
6+
from einops.layers.torch import Rearrange
7+
8+
# helpers
9+
10+
def pair(t):
11+
return t if isinstance(t, tuple) else (t, t)
12+
13+
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
14+
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
15+
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
16+
omega = torch.arange(dim // 4) / (dim // 4 - 1)
17+
omega = 1.0 / (temperature ** omega)
18+
19+
y = y.flatten()[:, None] * omega[None, :]
20+
x = x.flatten()[:, None] * omega[None, :]
21+
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
22+
return pe.type(dtype)
23+
24+
# classes
25+
26+
class FeedForward(nn.Module):
27+
def __init__(self, dim, hidden_dim):
28+
super().__init__()
29+
self.net = nn.Sequential(
30+
nn.LayerNorm(dim),
31+
nn.Linear(dim, hidden_dim),
32+
nn.GELU(),
33+
nn.Linear(hidden_dim, dim),
34+
)
35+
def forward(self, x):
36+
return self.net(x)
37+
38+
class Attention(nn.Module):
39+
def __init__(self, dim, heads = 8, dim_head = 64):
40+
super().__init__()
41+
inner_dim = dim_head * heads
42+
self.heads = heads
43+
self.scale = dim_head ** -0.5
44+
self.norm = nn.LayerNorm(dim)
45+
46+
self.attend = nn.Softmax(dim = -1)
47+
48+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
49+
self.to_out = nn.Linear(inner_dim, dim, bias = False)
50+
51+
def forward(self, x):
52+
x = self.norm(x)
53+
54+
qkv = self.to_qkv(x).chunk(3, dim = -1)
55+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
56+
57+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
58+
59+
attn = self.attend(dots)
60+
61+
out = torch.matmul(attn, v)
62+
out = rearrange(out, 'b h n d -> b n (h d)')
63+
return self.to_out(out)
64+
65+
class Transformer(nn.Module):
66+
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
67+
super().__init__()
68+
self.norm = nn.LayerNorm(dim)
69+
self.layers = nn.ModuleList([])
70+
for _ in range(depth):
71+
self.layers.append(nn.ModuleList([
72+
Attention(dim, heads = heads, dim_head = dim_head),
73+
FeedForward(dim, mlp_dim)
74+
]))
75+
def forward(self, x):
76+
for attn, ff in self.layers:
77+
x = attn(x) + x
78+
x = ff(x) + x
79+
return self.norm(x)
80+
81+
class SimpleViT(nn.Module):
82+
def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
83+
super().__init__()
84+
image_height, image_width = pair(image_size)
85+
patch_height, patch_width = pair(patch_size)
86+
freq_patch_height, freq_patch_width = pair(freq_patch_size)
87+
88+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
89+
assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'
90+
91+
patch_dim = channels * patch_height * patch_width
92+
freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width
93+
94+
self.to_patch_embedding = nn.Sequential(
95+
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
96+
nn.LayerNorm(patch_dim),
97+
nn.Linear(patch_dim, dim),
98+
nn.LayerNorm(dim),
99+
)
100+
101+
self.to_freq_embedding = nn.Sequential(
102+
Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
103+
nn.LayerNorm(freq_patch_dim),
104+
nn.Linear(freq_patch_dim, dim),
105+
nn.LayerNorm(dim)
106+
)
107+
108+
self.pos_embedding = posemb_sincos_2d(
109+
h = image_height // patch_height,
110+
w = image_width // patch_width,
111+
dim = dim,
112+
)
113+
114+
self.freq_pos_embedding = posemb_sincos_2d(
115+
h = image_height // freq_patch_height,
116+
w = image_width // freq_patch_width,
117+
dim = dim
118+
)
119+
120+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
121+
122+
self.pool = "mean"
123+
self.to_latent = nn.Identity()
124+
125+
self.linear_head = nn.Linear(dim, num_classes)
126+
127+
def forward(self, img):
128+
device, dtype = img.device, img.dtype
129+
130+
x = self.to_patch_embedding(img)
131+
freqs = torch.view_as_real(fft(img))
132+
133+
f = self.to_freq_embedding(freqs)
134+
135+
x += self.pos_embedding.to(device, dtype = dtype)
136+
f += self.freq_pos_embedding.to(device, dtype = dtype)
137+
138+
x, ps = pack((f, x), 'b * d')
139+
140+
x = self.transformer(x)
141+
142+
_, x = unpack(x, ps, 'b * d')
143+
x = reduce(x, 'b n d -> b d', 'mean')
144+
145+
x = self.to_latent(x)
146+
return self.linear_head(x)
147+
148+
if __name__ == '__main__':
149+
vit = SimpleViT(
150+
num_classes = 1000,
151+
image_size = 256,
152+
patch_size = 8,
153+
freq_patch_size = 8,
154+
dim = 1024,
155+
depth = 1,
156+
heads = 8,
157+
mlp_dim = 2048,
158+
)
159+
160+
images = torch.randn(8, 3, 256, 256)
161+
162+
logits = vit(images)

0 commit comments

Comments
 (0)