Skip to content

Commit f944242

Browse files
committed
Fix #262, num_classes arg mixup. Make vision_transformers a bit closer to other models wrt get/reset classifier/forward_features. Fix torchscript for ViT.
1 parent da1b90e commit f944242

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

timm/models/vision_transformer.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
107107

108108
def forward(self, x):
109109
B, N, C = x.shape
110-
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
110+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
111+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
111112

112113
attn = (q @ k.transpose(-2, -1)) * self.scale
113114
attn = attn.softmax(dim=-1)
@@ -204,6 +205,9 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
204205
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
205206
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
206207
super().__init__()
208+
self.num_classes = num_classes
209+
self.embed_dim = embed_dim
210+
207211
if hybrid_backbone is not None:
208212
self.patch_embed = HybridEmbed(
209213
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
@@ -229,7 +233,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
229233
#self.repr_act = nn.Tanh()
230234

231235
# Classifier head
232-
self.head = nn.Linear(embed_dim, num_classes)
236+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
233237

234238
trunc_normal_(self.pos_embed, std=.02)
235239
trunc_normal_(self.cls_token, std=.02)
@@ -244,11 +248,18 @@ def _init_weights(self, m):
244248
nn.init.constant_(m.bias, 0)
245249
nn.init.constant_(m.weight, 1.0)
246250

247-
@property
251+
@torch.jit.ignore
248252
def no_weight_decay(self):
249253
return {'pos_embed', 'cls_token'}
250254

251-
def forward(self, x):
255+
def get_classifier(self):
256+
return self.head
257+
258+
def reset_classifier(self, num_classes, global_pool=''):
259+
self.num_classes = num_classes
260+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
261+
262+
def forward_features(self, x):
252263
B = x.shape[0]
253264
x = self.patch_embed(x)
254265

@@ -261,7 +272,11 @@ def forward(self, x):
261272
x = blk(x)
262273

263274
x = self.norm(x)
264-
x = self.head(x[:, 0])
275+
return x[:, 0]
276+
277+
def forward(self, x):
278+
x = self.forward_features(x)
279+
x = self.head(x)
265280
return x
266281

267282

@@ -284,7 +299,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
284299
model.default_cfg = default_cfgs['vit_small_patch16_224']
285300
if pretrained:
286301
load_pretrained(
287-
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
302+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
288303
return model
289304

290305

@@ -297,7 +312,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
297312
model.default_cfg = default_cfgs['vit_base_patch16_224']
298313
if pretrained:
299314
load_pretrained(
300-
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
315+
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
301316
return model
302317

303318

@@ -308,8 +323,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
308323
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
309324
model.default_cfg = default_cfgs['vit_base_patch16_384']
310325
if pretrained:
311-
load_pretrained(
312-
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
326+
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
313327
return model
314328

315329

@@ -320,8 +334,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
320334
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
321335
model.default_cfg = default_cfgs['vit_base_patch32_384']
322336
if pretrained:
323-
load_pretrained(
324-
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
337+
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
325338
return model
326339

327340

@@ -339,8 +352,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
339352
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
340353
model.default_cfg = default_cfgs['vit_large_patch16_384']
341354
if pretrained:
342-
load_pretrained(
343-
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
355+
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
344356
return model
345357

346358

@@ -351,8 +363,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
351363
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
352364
model.default_cfg = default_cfgs['vit_large_patch32_384']
353365
if pretrained:
354-
load_pretrained(
355-
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
366+
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
356367
return model
357368

358369

timm/optim/optim_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
4343
if weight_decay and filter_bias_and_bn:
4444
skip = {}
4545
if hasattr(model, 'no_weight_decay'):
46-
skip = model.no_weight_decay
46+
skip = model.no_weight_decay()
4747
parameters = add_weight_decay(model, weight_decay, skip)
4848
weight_decay = 0.
4949
else:

0 commit comments

Comments
 (0)