Commit 0adc5a3

Use the right version of CCT
1 parent 15250c8 commit 0adc5a3

18 files changed: +4184 −13 lines

Tests/Models/RunCCT/CCT/CCT/cct.py

Lines changed: 606 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
def register_model(func):
    """
    Fallback wrapper in case timm isn't installed
    """
    return func
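
For context (not part of this commit), a minimal sketch of how such a fallback is typically wired up: the real `register_model` is imported from timm when it is available, and the no-op wrapper is used otherwise so that decorated model factories stay importable. The `build_cct_variant` factory name below is hypothetical, not something defined in this diff.

# Hypothetical usage sketch; `build_cct_variant` is an illustrative name only.
try:
    from timm.models.registry import register_model  # real registry when timm is present
except ImportError:
    def register_model(func):
        # no-op fallback: leaves the factory function unchanged
        return func


@register_model
def build_cct_variant(pretrained=False, **kwargs):
    # construct and return the model here
    raise NotImplementedError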

Tests/Models/RunCCT/CCT/CCT/utils/__init__.py

Whitespace-only changes.
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
import torch.nn as nn


class Embedder(nn.Module):

    def __init__(self,
                 word_embedding_dim = 300,
                 vocab_size = 100000,
                 padding_idx = 1,
                 pretrained_weight = None,
                 embed_freeze = False,
                 *args,
                 **kwargs):
        super(Embedder, self).__init__()
        self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \
            if pretrained_weight is not None else \
            nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)
        self.embeddings.weight.requires_grad = not embed_freeze

    def forward_mask(self, mask):
        bsz, seq_len = mask.shape
        new_mask = mask.view(bsz, seq_len, 1)
        new_mask = new_mask.sum(-1)
        new_mask = (new_mask > 0)
        return new_mask

    def forward(self, x, mask = None):
        embed = self.embeddings(x)
        embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float()
        return embed, mask

    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std = .02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        else:
            nn.init.normal_(m.weight)
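
As a quick illustration (not part of the diff), the embedder can be exercised on a dummy batch; the batch size, sequence length, and padding convention below are assumptions chosen to match the defaults above.

# Illustrative only; shapes and the padding convention are assumptions.
import torch

embedder = Embedder(word_embedding_dim = 300, vocab_size = 100000, padding_idx = 1)

tokens = torch.randint(0, 100000, (2, 16))   # (batch, seq_len) of token ids
mask = (tokens != 1).long()                  # 1 for real tokens, 0 at padding_idx
embed, mask = embedder(tokens, mask)
print(embed.shape)                           # torch.Size([2, 16, 300])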
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import logging
import math

import torch
import torch.nn.functional as F

_logger = logging.getLogger('train')


def resize_pos_embed(posemb, posemb_new, num_tokens = 1):
    # Copied from `timm` by Ross Wightman:
    # github.com/rwightman/pytorch-image-models
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    gs_new = int(math.sqrt(ntok_new))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size = (gs_new, gs_new), mode = 'bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim = 1)
    return posemb


def pe_check(model, state_dict, pe_key = 'classifier.positional_emb'):
    if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys():
        if model.state_dict()[pe_key].shape != state_dict[pe_key].shape:
            state_dict[pe_key] = resize_pos_embed(state_dict[pe_key],
                                                  model.state_dict()[pe_key],
                                                  num_tokens = model.classifier.num_tokens)
    return state_dict


def fc_check(model, state_dict, fc_key = 'classifier.fc'):
    for key in [f'{fc_key}.weight', f'{fc_key}.bias']:
        if key is not None and key in state_dict.keys() and key in model.state_dict().keys():
            if model.state_dict()[key].shape != state_dict[key].shape:
                _logger.warning(f'Removing {key}, number of classes has changed.')
                state_dict[key] = model.state_dict()[key]
    return state_dict
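
To illustrate what these helpers do when a checkpoint's positional-embedding grid no longer matches the model (again, not part of the diff), here is a self-contained call to `resize_pos_embed`; the grid sizes and embedding width are arbitrary assumptions. In the intended flow, `pe_check` and `fc_check` apply the same kind of fix-up to a loaded `state_dict` before `model.load_state_dict(state_dict)` is called.

# Illustrative only: grow a 7x7 positional grid (+1 class token) to 14x14.
import torch

old_posemb = torch.randn(1, 1 + 7 * 7, 256)    # embedding stored in the checkpoint
new_posemb = torch.zeros(1, 1 + 14 * 14, 256)  # shape expected by the current model
resized = resize_pos_embed(old_posemb, new_posemb, num_tokens = 1)
print(resized.shape)                           # torch.Size([1, 197, 256])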
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
# Thanks to rwightman's timm package
# github.com:rwightman/pytorch-image-models

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype = x.dtype, device = x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Obtained from: github.com:rwightman/pytorch-image-models
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob = None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
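
A small sketch of how `DropPath` is normally attached to the residual branch of a transformer-style block (not taken from this commit; the block below is a made-up minimal example):

# Illustrative residual block; names and hyperparameters are assumptions.
import torch
import torch.nn as nn


class ResidualMLPBlock(nn.Module):
    def __init__(self, dim = 64, drop_path_rate = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        # During training the whole branch is zeroed for a random subset of samples
        # and the surviving samples are rescaled by 1 / keep_prob.
        return x + self.drop_path(self.mlp(self.norm(x)))


block = ResidualMLPBlock()
out = block(torch.randn(2, 16, 64))  # (batch, tokens, dim)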
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Tokenizer(nn.Module):

    def __init__(self,
                 kernel_size,
                 stride,
                 padding,
                 pooling_kernel_size = 3,
                 pooling_stride = 2,
                 pooling_padding = 1,
                 n_conv_layers = 1,
                 n_input_channels = 3,
                 n_output_channels = 64,
                 in_planes = 64,
                 activation = None,
                 max_pool = True,
                 conv_bias = False):
        super(Tokenizer, self).__init__()

        n_filter_list = [n_input_channels] + \
                        [in_planes for _ in range(n_conv_layers - 1)] + \
                        [n_output_channels]

        self.conv_layers = nn.Sequential(*[
            nn.Sequential(
                nn.Conv2d(n_filter_list[i],
                          n_filter_list[i + 1],
                          kernel_size = (kernel_size, kernel_size),
                          stride = (stride, stride),
                          padding = (padding, padding),
                          bias = conv_bias),
                nn.Identity() if activation is None else activation(),
                nn.MaxPool2d(kernel_size = pooling_kernel_size,
                             stride = pooling_stride,
                             padding = pooling_padding) if max_pool else nn.Identity())
            for i in range(n_conv_layers)
        ])

        self.flattener = nn.Flatten(2, 3)
        self.apply(self.init_weight)

    def sequence_length(self, n_channels = 3, height = 224, width = 224):
        return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

    def forward(self, x):
        return self.flattener(self.conv_layers(x)).transpose(-2, -1)

    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)


class TextTokenizer(nn.Module):

    def __init__(self,
                 kernel_size,
                 stride,
                 padding,
                 pooling_kernel_size = 3,
                 pooling_stride = 2,
                 pooling_padding = 1,
                 embedding_dim = 300,
                 n_output_channels = 128,
                 activation = None,
                 max_pool = True,
                 *args,
                 **kwargs):
        super(TextTokenizer, self).__init__()

        self.max_pool = max_pool
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1,
                      n_output_channels,
                      kernel_size = (kernel_size, embedding_dim),
                      stride = (stride, 1),
                      padding = (padding, 0),
                      bias = False),
            nn.Identity() if activation is None else activation(),
            nn.MaxPool2d(kernel_size = (pooling_kernel_size, 1),
                         stride = (pooling_stride, 1),
                         padding = (pooling_padding, 0)) if max_pool else nn.Identity())

        self.apply(self.init_weight)

    def seq_len(self, seq_len = 32, embed_dim = 300):
        return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]

    def forward_mask(self, mask):
        new_mask = mask.unsqueeze(1).float()
        cnn_weight = torch.ones((1, 1, self.conv_layers[0].kernel_size[0]),
                                device = mask.device, dtype = torch.float)
        new_mask = F.conv1d(new_mask, cnn_weight, None,
                            self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1)
        if self.max_pool:
            new_mask = F.max_pool1d(new_mask, self.conv_layers[2].kernel_size[0],
                                    self.conv_layers[2].stride[0], self.conv_layers[2].padding[0],
                                    1, False, False)
        new_mask = new_mask.squeeze(1)
        new_mask = (new_mask > 0)
        return new_mask

    def forward(self, x, mask = None):
        x = x.unsqueeze(1)
        x = self.conv_layers(x)
        x = x.transpose(1, 3).squeeze(1)
        if mask is not None:
            mask = self.forward_mask(mask).unsqueeze(-1).float()
            x = x * mask
        return x, mask

    @staticmethod
    def init_weight(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
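
For a sense of the shapes involved (not part of the commit), the convolutional tokenizer can be run on a dummy image batch; the kernel, stride, and padding choices below are arbitrary CCT-style assumptions.

# Illustrative only; hyperparameter choices are assumptions.
import torch
import torch.nn as nn

tokenizer = Tokenizer(kernel_size = 3, stride = 1, padding = 1,
                      n_conv_layers = 2, n_input_channels = 3,
                      n_output_channels = 128, activation = nn.ReLU)

x = torch.randn(2, 3, 32, 32)                # (batch, channels, height, width)
tokens = tokenizer(x)                        # (batch, seq_len, embed_dim)
print(tokens.shape)                          # torch.Size([2, 64, 128])
print(tokenizer.sequence_length(3, 32, 32))  # 64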
