@@ -107,7 +107,8 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.

     def forward(self, x):
         B, N, C = x.shape
-        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
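A minimal standalone sketch (not part of the commit) of why the unpacking is rewritten: TorchScript cannot unpack a Tensor as if it were a tuple, so the scripted code indexes the stacked q/k/v tensor along its first dimension instead. The shapes below are purely illustrative.

import torch

B, N, num_heads, head_dim = 2, 5, 8, 64
C = num_heads * head_dim
# same reshape/permute pattern as the patched Attention.forward, on a dummy tensor
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]           # each (B, num_heads, N, head_dim); scriptable
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
print(attn.shape)                          # torch.Size([2, 8, 5, 5])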
@@ -204,6 +205,9 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
         super().__init__()
+        self.num_classes = num_classes
+        self.embed_dim = embed_dim
+
         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
                 hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
@@ -229,7 +233,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
         #self.repr_act = nn.Tanh()

         # Classifier head
-        self.head = nn.Linear(embed_dim, num_classes)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
@@ -244,11 +248,18 @@ def _init_weights(self, m):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

-    @property
+    @torch.jit.ignore
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}

-    def forward(self, x):
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
@@ -261,7 +272,11 @@ def forward(self, x):
             x = blk(x)

         x = self.norm(x)
-        x = self.head(x[:, 0])
+        return x[:, 0]
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
         return x


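The forward/forward_features split and the new classifier helpers give the model a timm-style interface. A minimal usage sketch under that assumption (illustrative only, not part of the commit; it uses the constructor defaults visible in this file):

import torch

model = VisionTransformer(img_size=224, patch_size=16)
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)   # CLS-token features, shape (1, model.embed_dim)
logits = model(x)                   # same features run through model.head, shape (1, 1000)

head = model.get_classifier()       # current nn.Linear head
model.reset_classifier(10)          # fresh 10-way head for fine-tuning
model.reset_classifier(0)           # num_classes=0 -> head becomes nn.Identity()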
@@ -284,7 +299,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
     model.default_cfg = default_cfgs['vit_small_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
     return model


@@ -297,7 +312,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
     model.default_cfg = default_cfgs['vit_base_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
     return model


@@ -308,8 +323,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model


@@ -320,8 +334,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch32_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model


@@ -339,8 +352,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model


@@ -351,8 +363,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch32_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model

