
Commit b401952

Add vision transformer large/base 224x224 weights ported from the official JAX repo

1 parent 61200db

File tree

1 file changed: +13 −7 lines changed


timm/models/vision_transformer.py

Lines changed: 13 additions & 7 deletions
@@ -48,15 +48,18 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
     'vit_base_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
     'vit_base_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
     'vit_base_patch32_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    'vit_large_patch16_224': _cfg(),
+    'vit_large_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
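The new base/large 224x224 entries normalize with mean/std of 0.5 rather than ImageNet statistics, so inputs land in [-1, 1] as the JAX-ported weights expect. A minimal preprocessing sketch, assuming torchvision; the exact resize and interpolation should be derived from the model's default_cfg (the 248 below assumes the _cfg default crop_pct of 0.9 and is an illustration, not part of this commit):

from torchvision import transforms

# Normalizing with mean=std=0.5 maps [0, 1] tensors to [-1, 1],
# matching the mean/std added to the configs above.
preprocess = transforms.Compose([
    transforms.Resize(248),           # assumed: int(224 / 0.9) for a 0.9 crop_pct
    transforms.CenterCrop(224),
    transforms.ToTensor(),            # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])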
@@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):

 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
-    if pretrained:
-        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
-        kwargs.setdefault('qk_scale', 768 ** -0.5)
-    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224']
     if pretrained:
         load_pretrained(
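For context on the removed workaround: the attention layer defaults to a per-head scale of head_dim ** -0.5, while the earlier checkpoints had been trained with an embed_dim-based scale, hence the qk_scale override that only applied when loading those weights. The JAX-ported weights use the standard per-head scale, so the override can go. A short sketch of the arithmetic, using the hyperparameters from this diff:

# Default attention scale is per-head: head_dim ** -0.5.
embed_dim, num_heads = 768, 12
head_dim = embed_dim // num_heads    # 64
per_head_scale = head_dim ** -0.5    # 0.125, the default used by the new JAX-ported weights
old_override = embed_dim ** -0.5     # ~0.0361, only matched the earlier checkpoints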
@@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs):

 @register_model
 def vit_large_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model = VisionTransformer(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_224']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
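With a weight URL in the config and the load_pretrained call in place, the large 224x224 model can now be created with pretrained weights. A minimal usage sketch, assuming the timm factory API (create_model resolves the model's default_cfg URL and downloads the checkpoint):

import torch
import timm

# Either of the newly weighted models works here, e.g. 'vit_base_patch16_224'.
model = timm.create_model('vit_large_patch16_224', pretrained=True)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # dummy batch at the 224x224 input size
print(out.shape)                               # torch.Size([1, 1000])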

0 commit comments