
Commit cfd21c3

Add CCT Model and Test
1 parent 316bd96 commit cfd21c3

File tree

17 files changed: 1853 additions & 0 deletions


Tests/Models/CCT/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .cct import *
from .cvt import *
from .vit import *

Tests/Models/CCT/cct.py

Lines changed: 356 additions & 0 deletions
@@ -0,0 +1,356 @@
from torch.hub import load_state_dict_from_url
import torch.nn as nn
from .utils.transformers import TransformerClassifier
from .utils.tokenizer import Tokenizer
from .utils.helpers import pe_check, fc_check

try:
    from timm.models.registry import register_model
except ImportError:
    from .registry import register_model

model_urls = {
    'cct_7_3x1_32':
        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar10_300epochs.pth',
    'cct_7_3x1_32_sine':
        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar10_5000epochs.pth',
    'cct_7_3x1_32_c100':
        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar100_300epochs.pth',
    'cct_7_3x1_32_sine_c100':
        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar100_5000epochs.pth',
    'cct_7_7x2_224_sine':
        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_7x2_224_flowers102.pth',
    'cct_14_7x2_224':
        'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_14_7x2_224_imagenet.pth',
    'cct_14_7x2_384':
        'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_imagenet.pth',
    'cct_14_7x2_384_fl':
        'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_flowers102.pth',
}

class CCT(nn.Module):
    def __init__(self,
                 img_size=224,
                 embedding_dim=768,
                 n_input_channels=3,
                 n_conv_layers=1,
                 kernel_size=7,
                 stride=2,
                 padding=3,
                 pooling_kernel_size=3,
                 pooling_stride=2,
                 pooling_padding=1,
                 dropout=0.,
                 attention_dropout=0.1,
                 stochastic_depth=0.1,
                 num_layers=14,
                 num_heads=6,
                 mlp_ratio=4.0,
                 num_classes=1000,
                 positional_embedding='learnable',
                 *args, **kwargs):
        super(CCT, self).__init__()

        self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
                                   n_output_channels=embedding_dim,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   pooling_kernel_size=pooling_kernel_size,
                                   pooling_stride=pooling_stride,
                                   pooling_padding=pooling_padding,
                                   max_pool=True,
                                   activation=nn.ReLU,
                                   n_conv_layers=n_conv_layers,
                                   conv_bias=False)

        self.classifier = TransformerClassifier(
            sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
                                                           height=img_size,
                                                           width=img_size),
            embedding_dim=embedding_dim,
            seq_pool=True,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth=stochastic_depth,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            num_classes=num_classes,
            positional_embedding=positional_embedding
        )

    def forward(self, x):
        x = self.tokenizer(x)
        return self.classifier(x)

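As a quick sanity check of the class above, a hedged sketch (not part of this commit; it assumes the utils.tokenizer and utils.transformers modules added elsewhere in this commit are importable): a CCT configured like the cct_2_3x2_32 variant defined further down maps a CIFAR-sized batch to per-class logits.

import torch

# Hypothetical smoke test mirroring the cct_2_3x2_32 configuration.
model = CCT(img_size=32, embedding_dim=128, n_conv_layers=2,
            kernel_size=3, stride=1, padding=1,
            num_layers=2, num_heads=2, mlp_ratio=1,
            num_classes=10)
x = torch.randn(2, 3, 32, 32)    # two 32x32 RGB images
logits = model(x)                # tokenizer -> TransformerClassifier
assert logits.shape == (2, 10)
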
def _cct(arch, pretrained, progress,
         num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=3, stride=None, padding=None,
         positional_embedding='learnable',
         *args, **kwargs):
    stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
    padding = padding if padding is not None else max(1, (kernel_size // 2))
    model = CCT(num_layers=num_layers,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                embedding_dim=embedding_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                *args, **kwargs)

    if pretrained:
        if arch in model_urls:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            if positional_embedding == 'learnable':
                state_dict = pe_check(model, state_dict)
            elif positional_embedding == 'sine':
                state_dict['classifier.positional_emb'] = model.state_dict()['classifier.positional_emb']
            state_dict = fc_check(model, state_dict)
            model.load_state_dict(state_dict)
        else:
            raise RuntimeError(f'Variant {arch} does not yet have pretrained weights.')
    return model

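Two observations on the helper above, offered as a hedged aside rather than as part of the commit: the variant names registered below read as cct_<transformer layers>_<tokenizer kernel>x<conv layers>_<image size>, and when a factory omits stride/padding, _cct derives them from the kernel size, which reproduces the class defaults for the 7x7 tokenizer.

# Sketch of the derived tokenizer geometry (not in the commit).
for kernel_size in (3, 7):
    stride = max(1, (kernel_size // 2) - 1)   # 1 for k=3, 2 for k=7
    padding = max(1, kernel_size // 2)        # 1 for k=3, 3 for k=7
    print(kernel_size, stride, padding)
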
@register_model
def cct_2(arch, pretrained, progress, *args, **kwargs):
    return _cct(arch, pretrained, progress, num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)


@register_model
def cct_4(arch, pretrained, progress, *args, **kwargs):
    return _cct(arch, pretrained, progress, num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
                *args, **kwargs)


@register_model
def cct_6(arch, pretrained, progress, *args, **kwargs):
    return _cct(arch, pretrained, progress, num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)


@register_model
def cct_7(arch, pretrained, progress, *args, **kwargs):
    return _cct(arch, pretrained, progress, num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
                *args, **kwargs)


@register_model
def cct_14(arch, pretrained, progress, *args, **kwargs):
    return _cct(arch, pretrained, progress, num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
                *args, **kwargs)

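A note for reviewers (my reading of the code, not part of the diff): the five base factories above take the arch name, pretrained, and progress positionally and are meant to be reached through the size-specific wrappers below. Calling one directly looks like this hypothetical sketch, equivalent to cct_7_3x1_32(pretrained=False):

model = cct_7('cct_7_3x1_32', False, False,
              kernel_size=3, n_conv_layers=1,
              img_size=32, positional_embedding='learnable', num_classes=10)
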
@register_model
def cct_2_3x2_32(pretrained=False, progress=False,
                 img_size=32, positional_embedding='learnable', num_classes=10,
                 *args, **kwargs):
    return cct_2('cct_2_3x2_32', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_2_3x2_32_sine(pretrained=False, progress=False,
                      img_size=32, positional_embedding='sine', num_classes=10,
                      *args, **kwargs):
    return cct_2('cct_2_3x2_32_sine', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_4_3x2_32(pretrained=False, progress=False,
                 img_size=32, positional_embedding='learnable', num_classes=10,
                 *args, **kwargs):
    return cct_4('cct_4_3x2_32', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_4_3x2_32_sine(pretrained=False, progress=False,
                      img_size=32, positional_embedding='sine', num_classes=10,
                      *args, **kwargs):
    return cct_4('cct_4_3x2_32_sine', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_6_3x1_32(pretrained=False, progress=False,
                 img_size=32, positional_embedding='learnable', num_classes=10,
                 *args, **kwargs):
    return cct_6('cct_6_3x1_32', pretrained, progress,
                 kernel_size=3, n_conv_layers=1,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_6_3x1_32_sine(pretrained=False, progress=False,
                      img_size=32, positional_embedding='sine', num_classes=10,
                      *args, **kwargs):
    return cct_6('cct_6_3x1_32_sine', pretrained, progress,
                 kernel_size=3, n_conv_layers=1,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_6_3x2_32(pretrained=False, progress=False,
                 img_size=32, positional_embedding='learnable', num_classes=10,
                 *args, **kwargs):
    return cct_6('cct_6_3x2_32', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_6_3x2_32_sine(pretrained=False, progress=False,
                      img_size=32, positional_embedding='sine', num_classes=10,
                      *args, **kwargs):
    return cct_6('cct_6_3x2_32_sine', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_3x1_32(pretrained=False, progress=False,
                 img_size=32, positional_embedding='learnable', num_classes=10,
                 *args, **kwargs):
    return cct_7('cct_7_3x1_32', pretrained, progress,
                 kernel_size=3, n_conv_layers=1,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_3x1_32_sine(pretrained=False, progress=False,
                      img_size=32, positional_embedding='sine', num_classes=10,
                      *args, **kwargs):
    return cct_7('cct_7_3x1_32_sine', pretrained, progress,
                 kernel_size=3, n_conv_layers=1,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_3x1_32_c100(pretrained=False, progress=False,
                      img_size=32, positional_embedding='learnable', num_classes=100,
                      *args, **kwargs):
    return cct_7('cct_7_3x1_32_c100', pretrained, progress,
                 kernel_size=3, n_conv_layers=1,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_3x1_32_sine_c100(pretrained=False, progress=False,
                           img_size=32, positional_embedding='sine', num_classes=100,
                           *args, **kwargs):
    return cct_7('cct_7_3x1_32_sine_c100', pretrained, progress,
                 kernel_size=3, n_conv_layers=1,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_3x2_32(pretrained=False, progress=False,
                 img_size=32, positional_embedding='learnable', num_classes=10,
                 *args, **kwargs):
    return cct_7('cct_7_3x2_32', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_3x2_32_sine(pretrained=False, progress=False,
                      img_size=32, positional_embedding='sine', num_classes=10,
                      *args, **kwargs):
    return cct_7('cct_7_3x2_32_sine', pretrained, progress,
                 kernel_size=3, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_7x2_224(pretrained=False, progress=False,
                  img_size=224, positional_embedding='learnable', num_classes=102,
                  *args, **kwargs):
    return cct_7('cct_7_7x2_224', pretrained, progress,
                 kernel_size=7, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_7_7x2_224_sine(pretrained=False, progress=False,
                       img_size=224, positional_embedding='sine', num_classes=102,
                       *args, **kwargs):
    return cct_7('cct_7_7x2_224_sine', pretrained, progress,
                 kernel_size=7, n_conv_layers=2,
                 img_size=img_size, positional_embedding=positional_embedding,
                 num_classes=num_classes,
                 *args, **kwargs)


@register_model
def cct_14_7x2_224(pretrained=False, progress=False,
                   img_size=224, positional_embedding='learnable', num_classes=1000,
                   *args, **kwargs):
    return cct_14('cct_14_7x2_224', pretrained, progress,
                  kernel_size=7, n_conv_layers=2,
                  img_size=img_size, positional_embedding=positional_embedding,
                  num_classes=num_classes,
                  *args, **kwargs)


@register_model
def cct_14_7x2_384(pretrained=False, progress=False,
                   img_size=384, positional_embedding='learnable', num_classes=1000,
                   *args, **kwargs):
    return cct_14('cct_14_7x2_384', pretrained, progress,
                  kernel_size=7, n_conv_layers=2,
                  img_size=img_size, positional_embedding=positional_embedding,
                  num_classes=num_classes,
                  *args, **kwargs)


@register_model
def cct_14_7x2_384_fl(pretrained=False, progress=False,
                      img_size=384, positional_embedding='learnable', num_classes=102,
                      *args, **kwargs):
    return cct_14('cct_14_7x2_384_fl', pretrained, progress,
                  kernel_size=7, n_conv_layers=2,
                  img_size=img_size, positional_embedding=positional_embedding,
                  num_classes=num_classes,
                  *args, **kwargs)
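
To exercise the pretrained path end to end, a hedged usage sketch (not part of the diff; it assumes Tests/ and Tests/Models/ are importable packages on sys.path and that the checkpoint download has network access):

import torch
from Tests.Models.CCT.cct import cct_7_3x1_32

# Fetches the CIFAR-10 checkpoint listed in model_urls on first use;
# torch.hub caches it locally afterwards.
model = cct_7_3x1_32(pretrained=True, progress=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 32, 32))
print(logits.shape)    # expected: torch.Size([1, 10])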
