99import numpy as np
1010import safetensors
1111from huggingface_hub import ModelCard , ModelCardData
12+ from huggingface_hub .constants import HF_HUB_CACHE
1213from safetensors .numpy import save_file
1314from tokenizers import Tokenizer
1415
@@ -24,6 +25,8 @@ def save_pretrained(
2425 config : dict [str , Any ],
2526 create_model_card : bool = True ,
2627 subfolder : str | None = None ,
28+ weights : np .ndarray | None = None ,
29+ mapping : np .ndarray | None = None ,
2730 ** kwargs : Any ,
2831) -> None :
2932 """
@@ -35,11 +38,20 @@ def save_pretrained(
3538 :param config: A metadata config.
3639 :param create_model_card: Whether to create a model card.
3740 :param subfolder: The subfolder to save the model in.
41+ :param weights: The weights of the model. If None, no weights are saved.
42+ :param mapping: The token mapping of the model. If None, there is no token mapping.
3843 :param **kwargs: Any additional arguments.
3944 """
4045 folder_path = folder_path / subfolder if subfolder else folder_path
4146 folder_path .mkdir (exist_ok = True , parents = True )
42- save_file ({"embeddings" : embeddings }, folder_path / "model.safetensors" )
47+
48+ model_weights = {"embeddings" : embeddings }
49+ if weights is not None :
50+ model_weights ["weights" ] = weights
51+ if mapping is not None :
52+ model_weights ["mapping" ] = mapping
53+
54+ save_file (model_weights , folder_path / "model.safetensors" )
4355 tokenizer .save (str (folder_path / "tokenizer.json" ), pretty = False )
4456 json .dump (config , open (folder_path / "config.json" , "w" ), indent = 4 )
4557
@@ -96,10 +108,11 @@ def _create_model_card(
96108
97109def load_pretrained (
98110 folder_or_repo_path : str | Path ,
99- subfolder : str | None = None ,
100- token : str | None = None ,
101- from_sentence_transformers : bool = False ,
102- ) -> tuple [np .ndarray , Tokenizer , dict [str , Any ], dict [str , Any ]]:
111+ subfolder : str | None ,
112+ token : str | None ,
113+ from_sentence_transformers : bool ,
114+ force_download : bool ,
115+ ) -> tuple [np .ndarray , Tokenizer , dict [str , Any ], dict [str , Any ], np .ndarray | None , np .ndarray | None ]:
103116 """
104117 Loads a pretrained model from a folder.
105118
@@ -109,8 +122,10 @@ def load_pretrained(
109122 :param subfolder: The subfolder to load from.
110123 :param token: The huggingface token to use.
111124 :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
125+ :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
126+ already present in the cache.
112127 :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
113- :return: The embeddings, tokenizer, config, and metadata .
128+ :return: The embeddings, tokenizer, config, metadata, weights and mapping .
114129
115130 """
116131 if from_sentence_transformers :
@@ -122,7 +137,13 @@ def load_pretrained(
122137 tokenizer_file = "tokenizer.json"
123138 config_name = "config.json"
124139
125- folder_or_repo_path = Path (folder_or_repo_path )
140+ cached_folder = _get_latest_model_path (str (folder_or_repo_path ))
141+ if cached_folder and not force_download :
142+ logger .info (f"Found cached model at { cached_folder } , loading from cache." )
143+ folder_or_repo_path = cached_folder
144+ else :
145+ logger .info (f"No cached model found for { folder_or_repo_path } , loading from local or hub." )
146+ folder_or_repo_path = Path (folder_or_repo_path )
126147
127148 local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
128149
@@ -139,9 +160,7 @@ def load_pretrained(
139160 if not tokenizer_path .exists ():
140161 raise FileNotFoundError (f"Tokenizer file does not exist in { local_folder } " )
141162
142- # README is optional, so this is a bit finicky.
143163 readme_path = local_folder / "README.md"
144- metadata = _get_metadata_from_readme (readme_path )
145164
146165 else :
147166 logger .info ("Folder does not exist locally, attempting to use huggingface hub." )
@@ -150,18 +169,11 @@ def load_pretrained(
150169 folder_or_repo_path .as_posix (), model_file , token = token , subfolder = subfolder
151170 )
152171 )
153-
154- try :
155- readme_path = Path (
156- huggingface_hub .hf_hub_download (
157- folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
158- )
172+ readme_path = Path (
173+ huggingface_hub .hf_hub_download (
174+ folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
159175 )
160- metadata = _get_metadata_from_readme (Path (readme_path ))
161- except Exception as e :
162- # NOTE: we don't want to raise an error here, since the README is optional.
163- logger .info (f"No README found in the model folder: { e } No model card loaded." )
164- metadata = {}
176+ )
165177
166178 config_path = Path (
167179 huggingface_hub .hf_hub_download (
@@ -175,20 +187,27 @@ def load_pretrained(
175187 )
176188
177189 opened_tensor_file = cast (SafeOpenProtocol , safetensors .safe_open (embeddings_path , framework = "numpy" ))
178- if from_sentence_transformers :
179- embeddings = opened_tensor_file .get_tensor ("embedding.weight" )
190+ embedding_name = "embedding.weight" if from_sentence_transformers else "embeddings"
191+ embeddings = opened_tensor_file .get_tensor (embedding_name )
192+ try :
193+ weights = opened_tensor_file .get_tensor ("weights" )
194+ except Exception :
195+ # Broad `except Exception` because safetensors does not export its own error types.
196+ weights = None
197+ try :
198+ mapping = opened_tensor_file .get_tensor ("mapping" )
199+ except Exception :
200+ mapping = None
201+
202+ if readme_path .exists ():
203+ metadata = _get_metadata_from_readme (readme_path )
180204 else :
181- embeddings = opened_tensor_file . get_tensor ( "embeddings" )
205+ metadata = {}
182206
183207 tokenizer : Tokenizer = Tokenizer .from_file (str (tokenizer_path ))
184208 config = json .load (open (config_path ))
185209
186- if len (tokenizer .get_vocab ()) != len (embeddings ):
187- logger .warning (
188- f"Number of tokens does not match number of embeddings: `{ len (tokenizer .get_vocab ())} ` vs `{ len (embeddings )} `"
189- )
190-
191- return embeddings , tokenizer , config , metadata
210+ return embeddings , tokenizer , config , metadata , weights , mapping
192211
193212
194213def _get_metadata_from_readme (readme_path : Path ) -> dict [str , Any ]:
@@ -223,3 +242,28 @@ def push_folder_to_hub(
223242 huggingface_hub .upload_folder (repo_id = repo_id , folder_path = folder_path , token = token , path_in_repo = subfolder )
224243
225244 logger .info (f"Pushed model to { repo_id } " )
245+
246+
def _get_latest_model_path(model_id: str) -> Path | None:
    """
    Look up the most recently modified cached snapshot of a model in the
    Hugging Face hub cache.

    :param model_id: The model identifier (e.g. ``org/name``).
    :return: The path to the newest snapshot directory, or None when no cached
        copy exists (in which case the model will be downloaded).
    """
    # Hub cache layout: <cache>/models--<org>--<name>/snapshots/<revision>/
    snapshots_dir = Path(HF_HUB_CACHE) / f"models--{model_id.replace('/', '--')}" / "snapshots"

    if not snapshots_dir.exists():
        return None

    # Each subdirectory is one cached revision of the repo.
    candidates = [entry for entry in snapshots_dir.iterdir() if entry.is_dir()]
    if not candidates:
        return None

    # The freshest revision is the one touched last on disk.
    return max(candidates, key=lambda entry: entry.stat().st_mtime)
0 commit comments