@@ -15,11 +15,13 @@
 from torch.hub import _get_torch_home as get_dir
 
 from timm import __version__
+from timm.models._pretrained import filter_pretrained_cfg
 
 try:
-    from huggingface_hub import (create_repo, get_hf_file_metadata,
-                                 hf_hub_download, hf_hub_url,
-                                 repo_type_and_id_from_hf_id, upload_folder)
+    from huggingface_hub import (
+        create_repo, get_hf_file_metadata,
+        hf_hub_download, hf_hub_url,
+        repo_type_and_id_from_hf_id, upload_folder)
     from huggingface_hub.utils import EntryNotFoundError
     hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
     _has_hf_hub = True
@@ -46,8 +48,11 @@ def get_cache_dir(child_dir=''):
 
 
 def download_cached_file(url, check_hash=True, progress=False):
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
     cached_file = os.path.join(get_cache_dir(), filename)
     if not os.path.exists(cached_file):
         _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
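
Note: with this change download_cached_file() also accepts a (url, filename) pair, so the cached filename no longer has to come from the URL basename. A minimal sketch of both call forms, assuming the function is imported from timm's hub module (module path varies by timm version); the URLs and filenames below are placeholders, not real files:

    # both forms cache into the same timm cache dir; URLs/filenames are placeholders
    download_cached_file('https://example.com/weights/model.pth')          # filename taken from URL basename
    download_cached_file(('https://example.com/dl?id=123', 'model.pth'))   # explicit filename for opaque URLs
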
@@ -90,10 +95,27 @@ def _download_from_hf(model_id: str, filename: str):
 def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    pretrained_cfg = load_cfg_from_json(cached_file)
+
+    hf_config = load_cfg_from_json(cached_file)
+    if 'pretrained_cfg' not in hf_config:
+        # old form, pull pretrained_cfg out of the base dict
+        pretrained_cfg = hf_config
+        hf_config = {}
+        hf_config['architecture'] = pretrained_cfg.pop('architecture')
+        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
+        if 'labels' in pretrained_cfg:
+            hf_config['label_name'] = pretrained_cfg.pop('labels')
+        hf_config['pretrained_cfg'] = pretrained_cfg
+
+    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
+    pretrained_cfg = hf_config['pretrained_cfg']
     pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
     pretrained_cfg['source'] = 'hf-hub'
-    model_name = pretrained_cfg.get('architecture')
+    if 'num_classes' in hf_config:
+        # model should be created with parent num_classes if they exist
+        pretrained_cfg['num_classes'] = hf_config['num_classes']
+    model_name = hf_config['architecture']
+
     return pretrained_cfg, model_name
 
 
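
Note: the loader now expects a nested config.json where model-level fields sit at the root and the timm pretrained config lives under 'pretrained_cfg', while still unwrapping the older flat layout on the fly. A rough sketch of the new-style layout, shown as a Python dict (all values below are illustrative, not taken from a real hub repo):

    # new-style config.json contents; values are made up
    hf_config = {
        'architecture': 'resnet50',
        'num_classes': 1000,
        'num_features': 2048,
        'pretrained_cfg': {
            'input_size': (3, 224, 224),
            'interpolation': 'bicubic',
            'num_classes': 1000,
        },
    }
    # old-style configs are just the flat pretrained_cfg with 'architecture'
    # (and optionally 'labels') mixed in; load_model_config_from_hf() rewraps them.
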
@@ -114,10 +136,34 @@ def save_for_hf(model, save_directory, model_config=None):
     torch.save(model.state_dict(), weights_path)
 
     config_path = save_directory / 'config.json'
-    hf_config = model.pretrained_cfg
-    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
-    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
-    hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
+    hf_config = {}
+    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
+    # set some values at root config level
+    hf_config['architecture'] = pretrained_cfg.pop('architecture')
+    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
+    hf_config['num_features'] = model_config.get('num_features', model.num_features)
+    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+
+    if 'label' in model_config:
+        _logger.warning(
+            "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
+            "Using provided 'label' field as 'label_name'.")
+        model_config['label_name'] = model_config.pop('label')
+
+    label_name = model_config.pop('label_name', None)
+    if label_name:
+        assert isinstance(label_name, (dict, list, tuple))
+        # map label id (classifier index) -> unique label name (i.e. synset for ImageNet, MID for OpenImages)
+        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
+        hf_config['label_name'] = label_name
+
+    display_name = model_config.pop('display_name', None)
+    if display_name:
+        assert isinstance(display_name, dict)
+        # map label_name -> user interface display name
+        hf_config['display_name'] = display_name
+
+    hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)
 
     with config_path.open('w') as f:
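
Note: label metadata can now be passed through model_config when exporting; a minimal sketch using save_for_hf() as defined in this module (the model choice, label names, and output path below are placeholders):

    # hedged example: a 3-class model with explicit label names and display names
    import timm

    model = timm.create_model('resnet18', pretrained=False, num_classes=3)
    save_for_hf(
        model,
        './hf_export',  # save_directory (placeholder path)
        model_config={
            'label_name': ['cat', 'dog', 'bird'],                          # classifier index -> unique label name
            'display_name': {'cat': 'Cat', 'dog': 'Dog', 'bird': 'Bird'},  # label name -> display name
        },
    )
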
@@ -127,14 +173,14 @@ def save_for_hf(model, save_directory, model_config=None):
 def push_to_hf_hub(
     model,
     repo_id: str,
-    commit_message: str ='Add model',
+    commit_message: str = 'Add model',
     token: Optional[str] = None,
     revision: Optional[str] = None,
     private: bool = False,
     create_pr: bool = False,
     model_config: Optional[dict] = None,
 ):
-    # Create repo if doesn't exist yet
+    # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
 
     # Infer complete repo_id from repo_url
@@ -154,10 +200,11 @@ def push_to_hf_hub(
         # Save model weights and config.
         save_for_hf(model, tmpdir, model_config=model_config)
 
-        # Add readme if it does not exist
+        # Add readme if it does not exist
         if not has_readme:
+            model_name = repo_id.split('/')[-1]
             readme_path = Path(tmpdir) / "README.md"
-            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
+            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
             readme_path.write_text(readme_text)
 
         # Upload model and return
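
Note: end to end, a typical push now looks like the sketch below, using push_to_hf_hub() as defined above; it assumes you are logged in to the Hugging Face Hub (or pass token=...), and the repo id is a placeholder:

    import timm

    model = timm.create_model('resnet18', pretrained=True)
    push_to_hf_hub(
        model,
        repo_id='your-username/resnet18-demo',  # placeholder repo id
        commit_message='Add model',
    )
    # if the repo has no README.md yet, one is generated titled
    # "Model card for resnet18-demo" (the part of repo_id after the '/')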