Skip to content

Commit af3299b

Browse files
authored
Merge pull request #263 from rwightman/fixes_oct2020
Fixes for upcoming PyPI release
2 parents 4a3df78 + 741572d commit af3299b

File tree

6 files changed

+37
-19
lines changed

6 files changed

+37
-19
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
@pytest.mark.timeout(120)
27-
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
27+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1]))
2828
@pytest.mark.parametrize('batch_size', [1])
2929
def test_model_forward(model_name, batch_size):
3030
"""Run a single forward pass with each model"""

timm/models/helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,12 @@ def build_model_with_cfg(
277277
if pruned:
278278
model = adapt_model_from_file(model, variant)
279279

280+
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
281+
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
280282
if pretrained:
281283
load_pretrained(
282284
model,
283-
num_classes=kwargs.get('num_classes', 0),
284-
in_chans=kwargs.get('in_chans', 3),
285+
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
285286
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
286287

287288
if features:

timm/models/hrnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,7 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
776776
strict = True
777777
if model_kwargs.pop('features_only', False):
778778
model_cls = HighResolutionNetFeatures
779+
model_kwargs['num_classes'] = 0
779780
strict = False
780781

781782
return build_model_with_cfg(

timm/models/layers/create_act.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
from .activations_me import *
77
from .config import is_exportable, is_scriptable, is_no_jit
88

9+
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
10+
# will use native version if present. Eventually, the custom Swish layers will be removed
11+
# and only native 'silu' will be used.
12+
_has_silu = 'silu' in dir(torch.nn.functional)
913

1014
_ACT_FN_DEFAULT = dict(
11-
swish=swish,
15+
silu=F.silu if _has_silu else swish,
16+
swish=F.silu if _has_silu else swish,
1217
mish=mish,
1318
relu=F.relu,
1419
relu6=F.relu6,
@@ -26,23 +31,26 @@
2631
)
2732

2833
_ACT_FN_JIT = dict(
29-
swish=swish_jit,
34+
silu=F.silu if _has_silu else swish_jit,
35+
swish=F.silu if _has_silu else swish_jit,
3036
mish=mish_jit,
3137
hard_sigmoid=hard_sigmoid_jit,
3238
hard_swish=hard_swish_jit,
3339
hard_mish=hard_mish_jit
3440
)
3541

3642
_ACT_FN_ME = dict(
37-
swish=swish_me,
43+
silu=F.silu if _has_silu else swish_me,
44+
swish=F.silu if _has_silu else swish_me,
3845
mish=mish_me,
3946
hard_sigmoid=hard_sigmoid_me,
4047
hard_swish=hard_swish_me,
4148
hard_mish=hard_mish_me,
4249
)
4350

4451
_ACT_LAYER_DEFAULT = dict(
45-
swish=Swish,
52+
silu=nn.SiLU if _has_silu else Swish,
53+
swish=nn.SiLU if _has_silu else Swish,
4654
mish=Mish,
4755
relu=nn.ReLU,
4856
relu6=nn.ReLU6,
@@ -60,15 +68,17 @@
6068
)
6169

6270
_ACT_LAYER_JIT = dict(
63-
swish=SwishJit,
71+
silu=nn.SiLU if _has_silu else SwishJit,
72+
swish=nn.SiLU if _has_silu else SwishJit,
6473
mish=MishJit,
6574
hard_sigmoid=HardSigmoidJit,
6675
hard_swish=HardSwishJit,
6776
hard_mish=HardMishJit
6877
)
6978

7079
_ACT_LAYER_ME = dict(
71-
swish=SwishMe,
80+
silu=nn.SiLU if _has_silu else SwishMe,
81+
swish=nn.SiLU if _has_silu else SwishMe,
7282
mish=MishMe,
7383
hard_sigmoid=HardSigmoidMe,
7484
hard_swish=HardSwishMe,

timm/models/vision_transformer.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _cfg(url='', **kwargs):
3737
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
3838
'crop_pct': .9, 'interpolation': 'bicubic',
3939
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
40-
'first_conv': '', 'classifier': 'head',
40+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
4141
**kwargs
4242
}
4343

@@ -48,15 +48,18 @@ def _cfg(url='', **kwargs):
4848
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
4949
),
5050
'vit_base_patch16_224': _cfg(
51-
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
51+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
52+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
5253
),
5354
'vit_base_patch16_384': _cfg(
5455
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
5556
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
5657
'vit_base_patch32_384': _cfg(
5758
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
5859
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
59-
'vit_large_patch16_224': _cfg(),
60+
'vit_large_patch16_224': _cfg(
61+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
62+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
6063
'vit_large_patch16_384': _cfg(
6164
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
6265
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
@@ -206,7 +209,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
206209
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
207210
super().__init__()
208211
self.num_classes = num_classes
209-
self.embed_dim = embed_dim
212+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
210213

211214
if hybrid_backbone is not None:
212215
self.patch_embed = HybridEmbed(
@@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
305308

306309
@register_model
307310
def vit_base_patch16_224(pretrained=False, **kwargs):
308-
if pretrained:
309-
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
310-
kwargs.setdefault('qk_scale', 768 ** -0.5)
311-
model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
311+
model = VisionTransformer(
312+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
313+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
312314
model.default_cfg = default_cfgs['vit_base_patch16_224']
313315
if pretrained:
314316
load_pretrained(
@@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
340342

341343
@register_model
342344
def vit_large_patch16_224(pretrained=False, **kwargs):
343-
model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
345+
model = VisionTransformer(
346+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
347+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
344348
model.default_cfg = default_cfgs['vit_large_patch16_224']
349+
if pretrained:
350+
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
345351
return model
346352

347353

timm/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.2.2'
1+
__version__ = '0.3.0'

0 commit comments

Comments
 (0)