
Commit 72cfa57

Add ported TensorFlow MaxViT weights. Add a few more CLIP ViT fine-tunes. Tweak some model tag names. Improve model tag name sorting. Update HF hub push config layout.
1 parent dbe7531 commit 72cfa57

File tree

10 files changed (+262 additions, -157 deletions)

timm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .version import __version__
-from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
     is_scriptable, is_exportable, set_scriptable, set_exportable, \
     is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
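For context, the newly exported `list_pretrained` enumerates `model.tag` names that have weights attached; a minimal sketch (assuming its filter argument mirrors `list_models`):

```python
import timm

# each entry is a 'model.tag' string with pretrained weights available
for name in timm.list_pretrained('convnext_atto*'):
    print(name)  # e.g. 'convnext_atto.timm_in1k'
```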

timm/models/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -70,5 +70,6 @@
 from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .layers import set_fast_norm
-from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
-    is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
+from ._pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
+from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
+    is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

timm/models/_pretrained.py

Lines changed: 40 additions & 20 deletions
@@ -1,5 +1,6 @@
+import copy
 from collections import deque, defaultdict
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field, replace, asdict
 from typing import Any, Deque, Dict, Tuple, Optional, Union

@@ -8,13 +9,13 @@ class PretrainedCfg:
     """
    """
     # weight locations
-    url: str = ''
-    file: str = ''
-    hf_hub_id: str = ''
-    hf_hub_filename: str = ''
+    url: Optional[Union[str, Tuple[str, str]]] = None
+    file: Optional[str] = None
+    hf_hub_id: Optional[str] = None
+    hf_hub_filename: Optional[str] = None

-    source: str = ''  # source of cfg / weight location used (url, file, hf-hub)
-    architecture: str = ''  # architecture variant can be set when not implicit
+    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
+    architecture: Optional[str] = None  # architecture variant can be set when not implicit
     custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

     # input / data config
@@ -31,22 +32,40 @@ class PretrainedCfg:

     # head config
     num_classes: int = 1000
-    label_offset: int = 0
+    label_offset: Optional[int] = None

     # model attributes that vary with above or required for pretrained adaptation
     pool_size: Optional[Tuple[int, ...]] = None
     test_pool_size: Optional[Tuple[int, ...]] = None
-    first_conv: str = ''
-    classifier: str = ''
+    first_conv: Optional[str] = None
+    classifier: Optional[str] = None

-    license: str = ''
-    source_url: str = ''
-    paper: str = ''
-    notes: str = ''
+    license: Optional[str] = None
+    source_url: Optional[str] = None
+    paper: Optional[str] = None
+    notes: Optional[str] = None

     @property
     def has_weights(self):
-        return self.url.startswith('http') or self.file or self.hf_hub_id
+        return self.url or self.file or self.hf_hub_id
+
+    def to_dict(self, remove_source=False, remove_null=True):
+        return filter_pretrained_cfg(
+            asdict(self),
+            remove_source=remove_source,
+            remove_null=remove_null
+        )
+
+
+def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
+    filtered_cfg = {}
+    for k, v in cfg.items():
+        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}:
+            continue
+        if remove_null and v is None:
+            continue
+        filtered_cfg[k] = v
+    return filtered_cfg


 @dataclass
@@ -71,7 +90,7 @@ def split_model_name_tag(model_name: str, no_tag=''):
     return model_name, tag


-def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
+def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
     out = defaultdict(DefaultCfg)
     default_set = set()  # no tag and tags ending with * are prioritized as default

@@ -82,21 +101,22 @@ def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):

         model, tag = split_model_name_tag(k)
         is_default_set = model in default_set
-        priority = not tag or (tag.endswith('*') and not is_default_set)
+        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
         tag = tag.strip('*')

         default_cfg = out[model]
-        if has_weights:
-            default_cfg.is_pretrained = True

         if priority:
             default_cfg.tags.appendleft(tag)
             default_set.add(model)
-        elif has_weights and not default_set:
+        elif has_weights and not default_cfg.is_pretrained:
             default_cfg.tags.appendleft(tag)
         else:
             default_cfg.tags.append(tag)

+        if has_weights:
+            default_cfg.is_pretrained = True
+
         default_cfg.cfgs[tag] = v

     return out
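Note the reordered `is_pretrained` update: a weightless tag no longer marks the whole model as pretrained, so a later weighted tag can still be promoted. A small sketch of the tag handling (hypothetical `mymodel` entries; assumes `DefaultCfg.tags` is the deque the `appendleft` calls imply, with the default tag leftmost):

```python
from timm.models._pretrained import (
    PretrainedCfg, generate_default_cfgs, split_model_name_tag)

# 'model.tag' names split on the first '.'; bare names get an empty tag
assert split_model_name_tag('convnext_atto.timm_in1k') == ('convnext_atto', 'timm_in1k')
assert split_model_name_tag('convnext_atto') == ('convnext_atto', '')

# a trailing '*' marks the preferred default; it is stripped from the tag
# and that tag is pushed to the front of the deque
cfgs = generate_default_cfgs({
    'mymodel.in1k': PretrainedCfg(url='https://example.com/a.pth'),
    'mymodel.in22k*': PretrainedCfg(url='https://example.com/b.pth'),
})
assert cfgs['mymodel'].tags[0] == 'in22k'
assert cfgs['mymodel'].is_pretrained
```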

timm/models/convnext.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
     create_conv2d, get_act_layer, make_divisible, to_ntuple
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model


@@ -373,7 +373,7 @@ def _cfg(url='', **kwargs):
 }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # timm specific variants
     'convnext_atto.timm_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
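The renamed generator changes nothing about name resolution; a minimal usage sketch (assuming `timm_in1k` remains the default tag for `convnext_atto`):

```python
import timm

# an explicit tag picks a specific pretrained config...
m1 = timm.create_model('convnext_atto.timm_in1k', pretrained=True)
# ...while a bare name resolves to the default tag chosen by generate_default_cfgs
m2 = timm.create_model('convnext_atto', pretrained=True)
```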

timm/models/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -575,7 +575,7 @@ def build_model_with_cfg(
     )

     # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
-    pretrained_cfg = dataclasses.asdict(pretrained_cfg)
+    pretrained_cfg = pretrained_cfg.to_dict()

     _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)

timm/models/hub.py

Lines changed: 62 additions & 15 deletions
@@ -15,11 +15,13 @@
 from torch.hub import _get_torch_home as get_dir

 from timm import __version__
+from timm.models._pretrained import filter_pretrained_cfg

 try:
-    from huggingface_hub import (create_repo, get_hf_file_metadata,
-                                 hf_hub_download, hf_hub_url,
-                                 repo_type_and_id_from_hf_id, upload_folder)
+    from huggingface_hub import (
+        create_repo, get_hf_file_metadata,
+        hf_hub_download, hf_hub_url,
+        repo_type_and_id_from_hf_id, upload_folder)
     from huggingface_hub.utils import EntryNotFoundError
     hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
     _has_hf_hub = True
@@ -46,8 +48,11 @@ def get_cache_dir(child_dir=''):


 def download_cached_file(url, check_hash=True, progress=False):
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
     cached_file = os.path.join(get_cache_dir(), filename)
     if not os.path.exists(cached_file):
         _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
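This pairs with the widened `PretrainedCfg.url` type above (`Union[str, Tuple[str, str]]`): a `(url, filename)` tuple supplies an explicit cache filename when the URL path doesn't carry one. A short sketch with illustrative URLs:

```python
from timm.models.hub import download_cached_file

# plain string: the cache filename comes from the URL path
path_a = download_cached_file('https://example.com/weights/model_a.pth')

# (url, filename) tuple: cache under an explicit name, e.g. for
# download endpoints whose paths end in a query id rather than a file
path_b = download_cached_file(
    ('https://example.com/get-weights?id=1234', 'model_b.pth'))
```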
@@ -90,10 +95,27 @@ def _download_from_hf(model_id: str, filename: str):
 def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    pretrained_cfg = load_cfg_from_json(cached_file)
+
+    hf_config = load_cfg_from_json(cached_file)
+    if 'pretrained_cfg' not in hf_config:
+        # old form, pull pretrain_cfg out of the base dict
+        pretrained_cfg = hf_config
+        hf_config = {}
+        hf_config['architecture'] = pretrained_cfg.pop('architecture')
+        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
+        if 'labels' in pretrained_cfg:
+            hf_config['label_name'] = pretrained_cfg.pop('labels')
+        hf_config['pretrained_cfg'] = pretrained_cfg
+
+    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
+    pretrained_cfg = hf_config['pretrained_cfg']
     pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
     pretrained_cfg['source'] = 'hf-hub'
-    model_name = pretrained_cfg.get('architecture')
+    if 'num_classes' in hf_config:
+        # model should be created with parent num_classes if they exist
+        pretrained_cfg['num_classes'] = hf_config['num_classes']
+    model_name = hf_config['architecture']
+
     return pretrained_cfg, model_name
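Both config layouts (legacy flat and new nested) funnel into the same `pretrained_cfg` dict here, so existing hub checkpoints keep loading; a minimal sketch with a placeholder repo id:

```python
import timm

# 'hf-hub:' prefixed names route model creation through
# load_model_config_from_hf ('your-user/your-model' is a placeholder)
model = timm.create_model('hf-hub:your-user/your-model', pretrained=True)
```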
@@ -114,10 +136,34 @@ def save_for_hf(model, save_directory, model_config=None):
     torch.save(model.state_dict(), weights_path)

     config_path = save_directory / 'config.json'
-    hf_config = model.pretrained_cfg
-    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
-    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
-    hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
+    hf_config = {}
+    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
+    # set some values at root config level
+    hf_config['architecture'] = pretrained_cfg.pop('architecture')
+    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
+    hf_config['num_features'] = model_config.get('num_features', model.num_features)
+    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+
+    if 'label' in model_config:
+        _logger.warning(
+            "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
+            "Using provided 'label' field as 'label_name'.")
+        model_config['label_name'] = model_config.pop('label')
+
+    label_name = model_config.pop('label_name', None)
+    if label_name:
+        assert isinstance(label_name, (dict, list, tuple))
+        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
+        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
+        hf_config['label_name'] = label_name
+
+    display_name = model_config.pop('display_name', None)
+    if display_name:
+        assert isinstance(display_name, dict)
+        # map label_name -> user interface display name
+        hf_config['display_name'] = display_name
+
+    hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)

     with config_path.open('w') as f:
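Under the new layout, `config.json` keeps a few fields at the root and nests the rest under `pretrained_cfg`; a sketch with illustrative values:

```python
import json

# root level: architecture, head/feature sizes, optional label metadata;
# the remaining timm cfg nests under 'pretrained_cfg' (illustrative values)
hf_config = {
    'architecture': 'convnext_atto',
    'num_classes': 1000,
    'num_features': 320,
    'global_pool': 'avg',
    'pretrained_cfg': {
        'input_size': [3, 224, 224],
        'crop_pct': 0.95,
        'interpolation': 'bicubic',
        'num_classes': 1000,
    },
}
print(json.dumps(hf_config, indent=2))
```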
@@ -127,14 +173,14 @@ def save_for_hf(model, save_directory, model_config=None):
 def push_to_hf_hub(
     model,
     repo_id: str,
-    commit_message: str ='Add model',
+    commit_message: str = 'Add model',
     token: Optional[str] = None,
     revision: Optional[str] = None,
     private: bool = False,
     create_pr: bool = False,
     model_config: Optional[dict] = None,
 ):
-    # Create repo if doesn't exist yet
+    # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

     # Infer complete repo_id from repo_url
@@ -154,10 +200,11 @@
     # Save model weights and config.
     save_for_hf(model, tmpdir, model_config=model_config)

-    # Add readme if does not exist
+    # Add readme if it does not exist
     if not has_readme:
+        model_name = repo_id.split('/')[-1]
         readme_path = Path(tmpdir) / "README.md"
-        readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
+        readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
         readme_path.write_text(readme_text)

     # Upload model and return
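An end-to-end sketch of a push under the updated layout (placeholder repo id; the label field is optional and illustrative; assumes a prior `huggingface-cli login` or token):

```python
import timm
from timm.models.hub import push_to_hf_hub

model = timm.create_model('convnext_atto', pretrained=True)

# extra model_config entries merge into config.json at the root level
push_to_hf_hub(
    model,
    repo_id='your-user/convnext-atto-demo',  # placeholder
    commit_message='Add model',
    model_config={
        'label_name': ['n01440764', 'n01443537'],  # illustrative synsets
    },
)
```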
