99import numpy as np
1010import safetensors
1111from huggingface_hub import ModelCard , ModelCardData
12+ from huggingface_hub .constants import HF_HUB_CACHE
1213from safetensors .numpy import save_file
1314from tokenizers import Tokenizer
1415
@@ -96,9 +97,10 @@ def _create_model_card(
9697
9798def load_pretrained (
9899 folder_or_repo_path : str | Path ,
99- subfolder : str | None = None ,
100- token : str | None = None ,
101- from_sentence_transformers : bool = False ,
100+ subfolder : str | None ,
101+ token : str | None ,
102+ from_sentence_transformers : bool ,
103+ force_download : bool ,
102104) -> tuple [np .ndarray , Tokenizer , dict [str , Any ], dict [str , Any ]]:
103105 """
104106 Loads a pretrained model from a folder.
@@ -109,6 +111,8 @@ def load_pretrained(
109111 :param subfolder: The subfolder to load from.
110112 :param token: The huggingface token to use.
111113 :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
114+ :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
115+ already present in the cache.
112116 :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
113117 :return: The embeddings, tokenizer, config, and metadata.
114118
@@ -122,7 +126,13 @@ def load_pretrained(
122126 tokenizer_file = "tokenizer.json"
123127 config_name = "config.json"
124128
125- folder_or_repo_path = Path (folder_or_repo_path )
129+ cached_folder = _get_latest_model_path (str (folder_or_repo_path ))
130+ if cached_folder and not force_download :
131+ logger .info (f"Found cached model at { cached_folder } , loading from cache." )
132+ folder_or_repo_path = cached_folder
133+ else :
134+ logger .info (f"No cached model found for { folder_or_repo_path } , loading from local or hub." )
135+ folder_or_repo_path = Path (folder_or_repo_path )
126136
127137 local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
128138
@@ -139,9 +149,7 @@ def load_pretrained(
139149 if not tokenizer_path .exists ():
140150 raise FileNotFoundError (f"Tokenizer file does not exist in { local_folder } " )
141151
142- # README is optional, so this is a bit finicky.
143152 readme_path = local_folder / "README.md"
144- metadata = _get_metadata_from_readme (readme_path )
145153
146154 else :
147155 logger .info ("Folder does not exist locally, attempting to use huggingface hub." )
@@ -150,18 +158,11 @@ def load_pretrained(
150158 folder_or_repo_path .as_posix (), model_file , token = token , subfolder = subfolder
151159 )
152160 )
153-
154- try :
155- readme_path = Path (
156- huggingface_hub .hf_hub_download (
157- folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
158- )
161+ readme_path = Path (
162+ huggingface_hub .hf_hub_download (
163+ folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
159164 )
160- metadata = _get_metadata_from_readme (Path (readme_path ))
161- except Exception as e :
162- # NOTE: we don't want to raise an error here, since the README is optional.
163- logger .info (f"No README found in the model folder: { e } No model card loaded." )
164- metadata = {}
165+ )
165166
166167 config_path = Path (
167168 huggingface_hub .hf_hub_download (
@@ -175,10 +176,13 @@ def load_pretrained(
175176 )
176177
177178 opened_tensor_file = cast (SafeOpenProtocol , safetensors .safe_open (embeddings_path , framework = "numpy" ))
178- if from_sentence_transformers :
179- embeddings = opened_tensor_file .get_tensor ("embedding.weight" )
179+ embedding_key = "embedding.weight" if from_sentence_transformers else "embeddings"
180+ embeddings = opened_tensor_file .get_tensor (embedding_key )
181+
182+ if readme_path .exists ():
183+ metadata = _get_metadata_from_readme (readme_path )
180184 else :
181- embeddings = opened_tensor_file . get_tensor ( "embeddings" )
185+ metadata = {}
182186
183187 tokenizer : Tokenizer = Tokenizer .from_file (str (tokenizer_path ))
184188 config = json .load (open (config_path ))
@@ -223,3 +227,28 @@ def push_folder_to_hub(
223227 huggingface_hub .upload_folder (repo_id = repo_id , folder_path = folder_path , token = token , path_in_repo = subfolder )
224228
225229 logger .info (f"Pushed model to { repo_id } " )
230+
231+
232+ def _get_latest_model_path (model_id : str ) -> Path | None :
233+ """
234+ Gets the latest model path for a given identifier from the hugging face hub cache.
235+
236+ Returns None if there is no cached model. In this case, the model will be downloaded.
237+ """
238+ # Make path object
239+ cache_dir = Path (HF_HUB_CACHE )
240+ # This is specific to how HF stores the files.
241+ normalized = model_id .replace ("/" , "--" )
242+ repo_dir = cache_dir / f"models--{ normalized } " / "snapshots"
243+
244+ if not repo_dir .exists ():
245+ return None
246+
247+ # Find all directories.
248+ snapshots = [p for p in repo_dir .iterdir () if p .is_dir ()]
249+ if not snapshots :
250+ return None
251+
252+ # Get the latest directory by modification time.
253+ latest_snapshot = max (snapshots , key = lambda p : p .stat ().st_mtime )
254+ return latest_snapshot
0 commit comments