
Commit 13095c9

feat: faster loading if model already cached (#278)
* feat: faster loading if model already cached
* fix: add force download and remove readme stuff
* switch force_download to True by default
* remove defaults in submodule
1 parent 55b955a commit 13095c9

7 files changed: +1120 -1012 lines changed

model2vec/distill/inference.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def create_embeddings(
     :param pad_token_id: The pad token id. Used to pad sequences.
     :return: The output embeddings.
     """
-    model = model.to(device)
+    model = model.to(device)  # type: ignore

     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
@@ -98,7 +98,7 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
     """
     encodings = {k: v.to(model.device) for k, v in encodings.items()}
     encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
-    out: torch.Tensor = encoded.last_hidden_state.cpu()
+    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore # typing is wrong.
     # NOTE: If the dtype is bfloat 16, we convert to float32,
     # because numpy does not suport bfloat16
     # See here: https://github.com/numpy/numpy/issues/19808
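The context lines above reference numpy's missing bfloat16 support. A minimal standalone sketch of that conversion (not part of this commit, assuming only that torch is installed):

```python
import torch

# numpy cannot represent bfloat16 (https://github.com/numpy/numpy/issues/19808),
# so tensors are upcast to float32 before conversion.
t = torch.ones(4, dtype=torch.bfloat16)
if t.dtype == torch.bfloat16:
    t = t.float()  # upcast; calling t.numpy() directly on bfloat16 raises
print(t.numpy().dtype)  # float32
```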

model2vec/hf_utils.py

Lines changed: 49 additions & 20 deletions
@@ -9,6 +9,7 @@
 import numpy as np
 import safetensors
 from huggingface_hub import ModelCard, ModelCardData
+from huggingface_hub.constants import HF_HUB_CACHE
 from safetensors.numpy import save_file
 from tokenizers import Tokenizer

@@ -96,9 +97,10 @@ def _create_model_card(

 def load_pretrained(
     folder_or_repo_path: str | Path,
-    subfolder: str | None = None,
-    token: str | None = None,
-    from_sentence_transformers: bool = False,
+    subfolder: str | None,
+    token: str | None,
+    from_sentence_transformers: bool,
+    force_download: bool,
 ) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
     """
     Loads a pretrained model from a folder.
@@ -109,6 +111,8 @@ def load_pretrained(
     :param subfolder: The subfolder to load from.
     :param token: The huggingface token to use.
     :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
+    :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
+        already present in the cache.
     :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
     :return: The embeddings, tokenizer, config, and metadata.
     """
@@ -122,7 +126,13 @@
     tokenizer_file = "tokenizer.json"
     config_name = "config.json"

-    folder_or_repo_path = Path(folder_or_repo_path)
+    cached_folder = _get_latest_model_path(str(folder_or_repo_path))
+    if cached_folder and not force_download:
+        logger.info(f"Found cached model at {cached_folder}, loading from cache.")
+        folder_or_repo_path = cached_folder
+    else:
+        logger.info(f"No cached model found for {folder_or_repo_path}, loading from local or hub.")
+        folder_or_repo_path = Path(folder_or_repo_path)

     local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path

@@ -139,9 +149,7 @@
         if not tokenizer_path.exists():
             raise FileNotFoundError(f"Tokenizer file does not exist in {local_folder}")

-        # README is optional, so this is a bit finicky.
         readme_path = local_folder / "README.md"
-        metadata = _get_metadata_from_readme(readme_path)

     else:
         logger.info("Folder does not exist locally, attempting to use huggingface hub.")
@@ -150,18 +158,11 @@
                 folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
             )
         )
-
-        try:
-            readme_path = Path(
-                huggingface_hub.hf_hub_download(
-                    folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
-                )
+        readme_path = Path(
+            huggingface_hub.hf_hub_download(
+                folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
             )
-            metadata = _get_metadata_from_readme(Path(readme_path))
-        except Exception as e:
-            # NOTE: we don't want to raise an error here, since the README is optional.
-            logger.info(f"No README found in the model folder: {e} No model card loaded.")
-            metadata = {}
+        )

         config_path = Path(
             huggingface_hub.hf_hub_download(
@@ -175,10 +176,13 @@
         )

     opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
-    if from_sentence_transformers:
-        embeddings = opened_tensor_file.get_tensor("embedding.weight")
+    embedding_key = "embedding.weight" if from_sentence_transformers else "embeddings"
+    embeddings = opened_tensor_file.get_tensor(embedding_key)
+
+    if readme_path.exists():
+        metadata = _get_metadata_from_readme(readme_path)
     else:
-        embeddings = opened_tensor_file.get_tensor("embeddings")
+        metadata = {}

     tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
     config = json.load(open(config_path))
@@ -223,3 +227,28 @@ def push_folder_to_hub(
     huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)

     logger.info(f"Pushed model to {repo_id}")
+
+
+def _get_latest_model_path(model_id: str) -> Path | None:
+    """
+    Gets the latest model path for a given identifier from the hugging face hub cache.
+
+    Returns None if there is no cached model. In this case, the model will be downloaded.
+    """
+    # Make path object
+    cache_dir = Path(HF_HUB_CACHE)
+    # This is specific to how HF stores the files.
+    normalized = model_id.replace("/", "--")
+    repo_dir = cache_dir / f"models--{normalized}" / "snapshots"
+
+    if not repo_dir.exists():
+        return None
+
+    # Find all directories.
+    snapshots = [p for p in repo_dir.iterdir() if p.is_dir()]
+    if not snapshots:
+        return None
+
+    # Get the latest directory by modification time.
+    latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
+    return latest_snapshot
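Because the keyword defaults on load_pretrained were removed in this commit, every caller now passes all arguments explicitly. A hedged usage sketch (the repo id is hypothetical):

```python
from model2vec.hf_utils import load_pretrained

# All arguments are now explicit; force_download=False reuses the newest
# cached snapshot found by _get_latest_model_path, skipping the network.
embeddings, tokenizer, config, metadata = load_pretrained(
    "some-org/some-model",  # hypothetical repo id
    subfolder=None,
    token=None,
    from_sentence_transformers=False,
    force_download=False,
)
```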

model2vec/model.py

Lines changed: 9 additions & 26 deletions
@@ -13,7 +13,7 @@
 from tqdm import tqdm

 from model2vec.quantization import DType, quantize_and_reduce_dim
-from model2vec.utils import ProgressParallel, load_local_model
+from model2vec.utils import ProgressParallel

 PathLike = Union[Path, str]

@@ -156,6 +156,7 @@ def from_pretrained(
         subfolder: str | None = None,
         quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
+        force_download: bool = True,
     ) -> StaticModel:
         """
         Load a StaticModel from a local path or huggingface hub path.
@@ -171,6 +172,8 @@
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
+        :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
+            already present in the cache.
         :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained
@@ -180,6 +183,7 @@
             token=token,
             from_sentence_transformers=False,
             subfolder=subfolder,
+            force_download=force_download,
         )

         embeddings = quantize_and_reduce_dim(
@@ -205,6 +209,7 @@ def from_sentence_transformers(
         normalize: bool | None = None,
         quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
+        force_download: bool = True,
     ) -> StaticModel:
         """
         Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -219,6 +224,8 @@
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
+        :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
+            already present in the cache.
         :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained
@@ -228,6 +235,7 @@
             token=token,
             from_sentence_transformers=True,
             subfolder=None,
+            force_download=force_download,
         )

         embeddings = quantize_and_reduce_dim(
@@ -447,28 +455,3 @@ def push_to_hub(
         with TemporaryDirectory() as temp_dir:
             self.save_pretrained(temp_dir, model_name=repo_id)
             push_folder_to_hub(Path(temp_dir), subfolder=subfolder, repo_id=repo_id, private=private, token=token)
-
-    @classmethod
-    def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
-        """
-        Loads a model from a local path.
-
-        You should only use this code path if you are concerned with start-up time.
-        Loading via the `from_pretrained` method is safer, and auto-downloads, but
-        also means we import a whole bunch of huggingface code that we don't need.
-
-        Additionally, huggingface will check the most recent version of the model,
-        which can be slow.
-
-        :param path: The path to load the model from. The path is a directory saved by the
-            `save_pretrained` method.
-        :return: A StaticModel
-        :raises: ValueError if the path is not a directory.
-        """
-        path = Path(path)
-        if not path.is_dir():
-            raise ValueError(f"Path {path} is not a directory.")
-
-        embeddings, tokenizer, config = load_local_model(path)
-
-        return StaticModel(embeddings, tokenizer, config)
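A hedged sketch of the new public knob (the model id is hypothetical; the timing only illustrates the intent of the commit):

```python
import time

from model2vec import StaticModel

start = time.perf_counter()
# force_download=False loads straight from the local HF cache when the model
# has been downloaded before; the default (True) preserves the previous
# always-check-the-hub behaviour.
model = StaticModel.from_pretrained("some-org/some-model", force_download=False)
print(f"Loaded in {time.perf_counter() - start:.2f}s")  # fast on a cache hit
```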

model2vec/utils.py

Lines changed: 0 additions & 24 deletions
@@ -102,27 +102,3 @@ def setup_logging() -> None:
         datefmt="%Y-%m-%d %H:%M:%S",
         handlers=[RichHandler(rich_tracebacks=True)],
     )
-
-
-def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
-    """Load a local model."""
-    embeddings_path = folder / "model.safetensors"
-    tokenizer_path = folder / "tokenizer.json"
-    config_path = folder / "config.json"
-
-    opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
-    embeddings = opened_tensor_file.get_tensor("embeddings")
-
-    if config_path.exists():
-        config = json.load(open(config_path))
-    else:
-        config = {}
-
-    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
-
-    if len(tokenizer.get_vocab()) != len(embeddings):
-        logger.warning(
-            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
-        )
-
-    return embeddings, tokenizer, config
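With load_local_model (and StaticModel.load_local) removed, loading a directory written by save_pretrained goes through the unified load_pretrained path, which falls back to the folder on disk when the path is not a cached hub repo. A rough equivalent for former load_local users, assuming the save_pretrained folder layout (the path is hypothetical):

```python
from model2vec import StaticModel

# Local directories still work through the unified loader.
model = StaticModel.from_pretrained("path/to/saved/model")
```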

tests/test_model.py

Lines changed: 2 additions & 22 deletions
@@ -118,9 +118,9 @@ def test_encode_as_tokens_empty(
     encoded = model.encode_as_sequence("")
     assert np.array_equal(encoded, np.zeros(shape=(0, 2), dtype=model.embedding.dtype))

-    encoded = model.encode_as_sequence(["", ""])
+    encoded_list = model.encode_as_sequence(["", ""])
     out = [np.zeros(shape=(0, 2), dtype=model.embedding.dtype) for _ in range(2)]
-    assert [np.array_equal(x, y) for x, y in zip(encoded, out)]
+    assert [np.array_equal(x, y) for x, y in zip(encoded_list, out)]


 def test_encode_empty_sentence(
@@ -273,23 +273,3 @@ def test_dim(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: d
     model = StaticModel(mock_vectors, mock_tokenizer, mock_config)
     assert model.dim == 2
     assert model.dim == model.embedding.shape[1]
-
-
-def test_local_load_from_model(mock_tokenizer: Tokenizer) -> None:
-    """Test local load from a model."""
-    x = np.ones((mock_tokenizer.get_vocab_size(), 2))
-    with TemporaryDirectory() as tempdir:
-        tempdir_path = Path(tempdir)
-        safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
-        mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
-
-        model = StaticModel.load_local(tempdir_path)
-        assert model.embedding.shape == x.shape
-        assert model.tokenizer.to_str() == mock_tokenizer.to_str()
-        assert model.config == {"normalize": False}
-
-
-def test_local_load_from_model_no_folder() -> None:
-    """Test local load from a model with no folder."""
-    with pytest.raises(ValueError):
-        StaticModel.load_local("woahbuddy_relax_this_is_just_a_test")
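The rename to encoded_list reflects that encode_as_sequence returns different types for str and list inputs, which reusing one variable obscured for the type checker. A minimal sketch of the two shapes the test asserts (illustrative only):

```python
import numpy as np

single = np.zeros((0, 2))                      # encode_as_sequence("") -> one array
batch = [np.zeros((0, 2)) for _ in range(2)]   # encode_as_sequence(["", ""]) -> list of arrays
```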

tests/test_utils.py

Lines changed: 1 addition & 42 deletions
@@ -14,7 +14,7 @@

 from model2vec.distill.utils import select_optimal_device
 from model2vec.hf_utils import _get_metadata_from_readme
-from model2vec.utils import get_package_extras, importable, load_local_model
+from model2vec.utils import get_package_extras, importable


 def test__get_metadata_from_readme_not_exists() -> None:
@@ -78,44 +78,3 @@ def test_get_package_extras() -> None:
 def test_get_package_extras_empty() -> None:
     """Test package extras with an empty package."""
     assert not list(get_package_extras("tqdm", ""))
-
-
-@pytest.mark.parametrize(
-    "config, expected",
-    [
-        ({"dog": "cat"}, {"dog": "cat"}),
-        ({}, {}),
-        (None, {}),
-    ],
-)
-def test_local_load(mock_tokenizer: Tokenizer, config: dict[str, Any], expected: dict[str, Any]) -> None:
-    """Test local loading."""
-    x = np.ones((mock_tokenizer.get_vocab_size(), 2))
-
-    with TemporaryDirectory() as tempdir:
-        tempdir_path = Path(tempdir)
-        safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
-        mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
-        if config is not None:
-            json.dump(config, open(tempdir_path / "config.json", "w"))
-        arr, tokenizer, config = load_local_model(tempdir_path)
-        assert config == expected
-        assert tokenizer.to_str() == mock_tokenizer.to_str()
-        assert arr.shape == x.shape
-
-
-def test_local_load_mismatch(mock_tokenizer: Tokenizer, caplog: pytest.LogCaptureFixture) -> None:
-    """Test local loading."""
-    x = np.ones((10, 2))
-
-    with TemporaryDirectory() as tempdir:
-        tempdir_path = Path(tempdir)
-        safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
-        mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
-
-        load_local_model(tempdir_path)
-        expected = (
-            f"Number of tokens does not match number of embeddings: `{len(mock_tokenizer.get_vocab())}` vs `{len(x)}`"
-        )
-        assert len(caplog.records) == 1
-        assert caplog.records[0].message == expected
