@@ -48,15 +48,18 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
     'vit_base_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
     'vit_base_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
     'vit_base_patch32_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    'vit_large_patch16_224': _cfg(),
+    'vit_large_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
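For context (not part of this commit): mean=std=(0.5, 0.5, 0.5) in these configs corresponds to Inception-style normalization that maps [0, 1] pixel values to roughly [-1, 1]. A minimal torchvision sketch of matching preprocessing follows; the 256/224 resize-and-crop sizes are illustrative assumptions, not values taken from the diff.

import torchvision.transforms as T

# Preprocessing consistent with mean=std=(0.5, 0.5, 0.5) from the configs above
preprocess = T.Compose([
    T.Resize(256),                # assumed resize; pick to match the model's crop settings
    T.CenterCrop(224),            # assumed crop to the 224x224 input_size
    T.ToTensor(),                 # scales pixels to [0, 1]
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # maps to roughly [-1, 1]
])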
@@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
-    if pretrained:
-        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
-        kwargs.setdefault('qk_scale', 768 ** -0.5)
-    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224']
     if pretrained:
         load_pretrained(
@@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
 
 @register_model
 def vit_large_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_224']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
 
 
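Usage sketch (not part of the commit): assuming this revision is installed as the timm package, the large model can now be created with pretrained weights, since load_pretrained() is wired up and its default_cfg carries a weight URL.

import torch
import timm

model = timm.create_model('vit_large_patch16_224', pretrained=True)
model.eval()

# The 0.5 mean/std in default_cfg means inputs are expected normalized to roughly [-1, 1]
x = torch.randn(1, 3, 224, 224)   # dummy batch for a shape check
with torch.no_grad():
    out = model(x)                # logits with shape [1, 1000] (default num_classes)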