@@ -37,7 +37,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': '', 'classifier': 'head',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
         **kwargs
     }
 
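Because _cfg spreads **kwargs after the shared defaults, any per-model keyword simply overrides the corresponding default key. A minimal sketch of that merge behaviour (the example URL is made up; the keys come from the defaults above):

    cfg = _cfg(url='https://example.com/weights.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    assert cfg['mean'] == (0.5, 0.5, 0.5)            # per-model kwarg wins over IMAGENET_DEFAULT_MEAN
    assert cfg['first_conv'] == 'patch_embed.proj'   # inherited from the defaults above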
@@ -48,15 +48,18 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
     'vit_base_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
     'vit_base_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
     'vit_base_patch32_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    'vit_large_patch16_224': _cfg(),
+    'vit_large_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
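Note that these jx_* checkpoints are configured with mean/std of 0.5 rather than the ImageNet defaults, so evaluation preprocessing should be built from the entry's config. A hedged sketch using torchvision (this recipe is an illustration of how the cfg fields map to a transform, not the repo's own data pipeline):

    from torchvision import transforms

    cfg = default_cfgs['vit_base_patch16_384']
    size = cfg['input_size'][-1]                 # 384
    scale_size = int(size / cfg['crop_pct'])     # crop_pct=1.0 -> resize straight to 384
    eval_tf = transforms.Compose([
        transforms.Resize(scale_size, interpolation=3),  # 3 == PIL bicubic, matching cfg['interpolation']
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(cfg['mean'], cfg['std']),
    ])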
@@ -206,7 +209,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
         super().__init__()
         self.num_classes = num_classes
-        self.embed_dim = embed_dim
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
 
         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
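Exposing num_features alongside embed_dim follows the convention of timm's other models, where num_features reports the channel width of the final features so generic code can attach a new head without knowing the architecture. A small illustration of the downstream use this enables (the 10-class head is hypothetical):

    import torch.nn as nn

    model = vit_base_patch16_224(pretrained=False)
    new_head = nn.Linear(model.num_features, 10)   # size a replacement classifier from the exposed width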
@@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
-    if pretrained:
-        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
-        kwargs.setdefault('qk_scale', 768 ** -0.5)
-    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224']
     if pretrained:
         load_pretrained(
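The partial(nn.LayerNorm, eps=1e-6) pre-binds the epsilon so every norm layer the model constructs matches the 1e-6 default used when the JAX checkpoints were trained, and qkv_bias=True restores the attention projection biases present in those weights. A quick illustration of the partial pattern:

    from functools import partial
    import torch.nn as nn

    norm_layer = partial(nn.LayerNorm, eps=1e-6)
    norm = norm_layer(768)        # equivalent to nn.LayerNorm(768, eps=1e-6)
    assert norm.eps == 1e-6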
@@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
 
 @register_model
 def vit_large_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_224']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
 
 
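With the URL wired into default_cfgs and load_pretrained called from the factory, the large model can now be instantiated with its converted weights just like the other registered ViTs. A usage sketch, assuming the usual timm entry point:

    import timm
    import torch

    model = timm.create_model('vit_large_patch16_224', pretrained=True)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)   # torch.Size([1, 1000])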