
Commit 4d5c395

MaxVit, ViT, ConvNeXt, and EfficientNet-v2 updates
* Add support for TF weights and modelling specifics to MaxVit (testing ported weights)
* More fine-tuned CLIP ViT configs
* ConvNeXt and MaxVit updated to use the new pretrained cfgs (see the usage sketch below)
* EfficientNetV2, MaxVit and ConvNeXt high-res models use squash crop/resize
1 parent 3db4e34 commit 4d5c395
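The "new pretrained cfgs" mentioned in the commit message keep one architecture entrypoint per model and select weights via a tag-qualified name (e.g. 'convnext_base.fb_in22k_ft_in1k', defined in the default_cfgs diff below). A minimal usage sketch, assuming a timm build that includes this multi-weight/tagged-cfg support; the tag string comes from the diff, the rest is standard timm API:

import timm
from timm.data import resolve_data_config, create_transform

# Tag-qualified weight name taken from the reworked default_cfgs below.
model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=True).eval()

# Resolve the preprocessing implied by the pretrained cfg (input_size, crop_pct,
# interpolation, mean/std) and build the matching eval transform.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config)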

6 files changed: +1,259 −963 lines


timm/models/convnext.py

Lines changed: 104 additions & 199 deletions
@@ -21,111 +21,13 @@
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
     create_conv2d, get_act_layer, make_divisible, to_ntuple
+from ._pretrained import generate_defaults
 from .registry import register_model


 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0', 'classifier': 'head.fc',
-        **kwargs
-    }
-
-
-default_cfgs = dict(
-    # timm specific variants
-    convnext_atto=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    convnext_atto_ols=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    convnext_femto=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    convnext_femto_ols=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    convnext_pico=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    convnext_pico_ols=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
-        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_nano=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
-        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_nano_ols=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
-        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_tiny_hnf=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
-        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-
-    convnext_tiny=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_small=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_base=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_large=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-
-    convnext_tiny_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_small_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_base_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_large_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    convnext_xlarge_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-
-    convnext_tiny_384_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    convnext_small_384_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    convnext_base_384_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    convnext_large_384_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    convnext_xlarge_384_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-
-    convnext_tiny_in22k=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
-    convnext_small_in22k=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
-    convnext_base_in22k=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
-    convnext_large_in22k=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
-    convnext_xlarge_in22k=_cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
-)
-
-
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
@@ -459,6 +361,107 @@ def _create_convnext(variant, pretrained=False, **kwargs):
     return model


+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_defaults({
+    # timm specific variants
+    'convnext_atto.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_atto_ols.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_femto.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_femto_ols.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_pico.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_pico_ols.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_nano.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_nano_ols.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_tiny_hnf.timm_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    'convnext_tiny.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_small.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_base.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_large.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_xlarge.untrained': _cfg(),
+
+    'convnext_tiny.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_small.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_base.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_large.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_base.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_large.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    'convnext_tiny_in22k.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
+    'convnext_small_in22k.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
+    'convnext_base_in22k.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
+    'convnext_large_in22k.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
+    'convnext_xlarge_in22k.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
+})
+
+
 @register_model
 def convnext_atto(pretrained=False, **kwargs):
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
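The 384-resolution fine-tune cfgs in this hunk switch to crop_pct=1.0 with crop_mode='squash': evaluation resizes the image directly to the target resolution rather than resizing the short side and center-cropping. A rough torchvision-based illustration of the two eval pipelines, offered as an assumption-level sketch rather than timm's actual transform code:

from torchvision import transforms
from torchvision.transforms import InterpolationMode

size = 384
crop_pct = 0.875  # the _cfg default; the 384px cfgs above use 1.0 + squash instead

# Conventional eval pipeline: scale the shorter side to size / crop_pct,
# then center-crop to the target size.
center_crop_eval = transforms.Compose([
    transforms.Resize(int(size / crop_pct), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(size),
])

# 'squash' mode: resize both dimensions straight to (size, size); aspect
# ratio is not preserved, but nothing is cropped away.
squash_eval = transforms.Compose([
    transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
])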
@@ -569,105 +572,7 @@ def convnext_large(pretrained=False, **kwargs):


 @register_model
-def convnext_tiny_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
-    model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_small_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
-    model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_base_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
-    model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_large_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
-    model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
-    model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
-    model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_small_384_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
-    model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
-    model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_large_384_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
-    model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
-    model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_tiny_in22k(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
-    model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_small_in22k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
-    model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_base_in22k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
-    model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_large_in22k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
-    model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_xlarge_in22k(pretrained=False, **kwargs):
+def convnext_xlarge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
-    model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
+    model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
     return model
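With the per-dataset entrypoints above removed, those pretrained variants are now reached through the remaining architecture entrypoints plus a weight tag from default_cfgs. A hedged before/after sketch, assuming create_model resolves the 'model.tag' form on a timm build that includes this support:

import timm

# Before this commit: one entrypoint per pretrain dataset / resolution, e.g.
#   timm.create_model('convnext_tiny_in22ft1k', pretrained=True)
#   timm.create_model('convnext_tiny_384_in22ft1k', pretrained=True)

# After: a single convnext_tiny entrypoint, weights chosen by pretrained tag.
m_224 = timm.create_model('convnext_tiny.fb_in22k_ft_in1k', pretrained=True)
m_384 = timm.create_model('convnext_tiny.fb_in22k_ft_in1k_384', pretrained=True)

# An untagged name should still resolve to whichever tag generate_defaults
# marks as the default for the architecture.
m_default = timm.create_model('convnext_tiny', pretrained=True)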
