Commit dde95cd (merge)

2 parents: 3f39da4 + 7bf0bf0

20 files changed: +854 -510 lines

.github/workflows/ci.yaml

Lines changed: 1 addition & 22 deletions
@@ -9,17 +9,8 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: ["ubuntu-latest", "windows-latest"]
+        os: ["ubuntu-latest"]
         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-        exclude:
-          - os: windows-latest
-            python-version: "3.9"
-          - os: windows-latest
-            python-version: "3.11"
-          - os: windows-latest
-            python-version: "3.12"
-          - os: windows-latest
-            python-version: "3.13"
       fail-fast: false

     steps:
@@ -31,19 +22,7 @@ jobs:
         python-version: ${{ matrix.python-version }}
         allow-prereleases: true

-      # Step for Windows: Create and activate a virtual environment
-      - name: Create and activate a virtual environment (Windows)
-        if: ${{ runner.os == 'Windows' }}
-        run: |
-          irm https://astral.sh/uv/install.ps1 | iex
-          $env:Path = "C:\Users\runneradmin\.local\bin;$env:Path"
-          uv venv .venv
-          "VIRTUAL_ENV=.venv" | Out-File -FilePath $env:GITHUB_ENV -Append
-          "$PWD/.venv/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Append
-
-      # Step for Unix: Create and activate a virtual environment
       - name: Create and activate a virtual environment (Unix)
-        if: ${{ runner.os != 'Windows' }}
         run: |
           curl -LsSf https://astral.sh/uv/install.sh | sh
           uv venv .venv

model2vec/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from model2vec.model import StaticModel
+from model2vec.model import StaticModel, quantize_model
 from model2vec.version import __version__

-__all__ = ["StaticModel", "__version__"]
+__all__ = ["StaticModel", "quantize_model", "__version__"]

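quantize_model is now re-exported from the package root alongside StaticModel. The diff confirms only the name and import path, so the call below is a minimal sketch: the model id is illustrative, and the quantize_to keyword is an assumption mirroring the convention used elsewhere in this commit.

from model2vec import StaticModel, quantize_model

# Illustrative model id; any saved StaticModel would do.
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
# Assumed call shape: reduce the embedding dtype to float16.
quantized = quantize_model(model, quantize_to="float16")
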
model2vec/distill/distillation.py

Lines changed: 18 additions & 2 deletions
@@ -18,6 +18,7 @@
 from model2vec.quantization import DType, quantize_embeddings
 from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
 from model2vec.tokenizer.tokenizer import _patch_tokenizer
+from model2vec.vocabulary_quantization import quantize_vocabulary

 logger = logging.getLogger(__name__)

@@ -32,6 +33,7 @@ def distill_from_model(
     token_remove_pattern: str | None = r"\[unused\d+\]",
     quantize_to: DType | str = DType.Float16,
     lower_case: bool = True,
+    vocabulary_quantization: int | None = None,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -56,6 +58,7 @@
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param lower_case: If this is set, all tokens in the model vocabulary will be converted to lowercase, and
         a lowercase normalizer will be inserted. This almost always improves performance.
+    :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :return: A StaticModel
     :raises: ValueError if the vocabulary is empty after preprocessing.

@@ -106,8 +109,16 @@
     pad_token_id = vocab[pad_token]
     embeddings = create_embeddings(tokenized=token_ids, model=model, device=device, pad_token_id=pad_token_id)

-    # Post process the embeddings by applying PCA and Zipf weighting.
-    embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
+    if vocabulary_quantization is not None:
+        _, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
+        embeddings, token_mapping, weights = quantize_vocabulary(
+            n_clusters=vocabulary_quantization, weights=weights, embeddings=np.asarray(embeddings)
+        )
+        embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
+    else:
+        # Post-process the embeddings.
+        embeddings, weights = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
+        token_mapping = None
     # Quantize the embeddings.
     embeddings = quantize_embeddings(embeddings, quantize_to)

@@ -140,6 +151,8 @@

     return StaticModel(
         vectors=embeddings,
+        weights=weights,
+        token_mapping=token_mapping,
         tokenizer=backend_tokenizer,
         config=config,
         base_model_name=model_name,
@@ -186,6 +199,7 @@ def distill(
     trust_remote_code: bool = False,
     quantize_to: DType | str = DType.Float16,
     lower_case: bool = True,
+    vocabulary_quantization: int | None = None,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -209,6 +223,7 @@
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param lower_case: If this is set, all tokens in the model vocabulary will be converted to lowercase, and
         a lowercase normalizer will be inserted. This almost always improves performance.
+    :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :return: A StaticModel

     """
@@ -228,4 +243,5 @@ def distill(
         sif_coefficient=sif_coefficient,
         quantize_to=quantize_to,
         lower_case=lower_case,
+        vocabulary_quantization=vocabulary_quantization,
     )

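The new vocabulary_quantization option threads through the public distill entry point, so callers can cluster the token embeddings into a fixed number of centroids at distillation time. A minimal sketch using the signature from this diff; the base model id, PCA size, cluster count, and output path are illustrative, not taken from the commit.

from model2vec.distill import distill

# Cluster the vocabulary into 8192 centroids; tokens then share centroid
# vectors via the token_mapping produced by quantize_vocabulary. Passing
# None (the default) keeps one vector per token, as in the else-branch above.
m2v_model = distill(
    model_name="BAAI/bge-base-en-v1.5",
    pca_dims=256,
    vocabulary_quantization=8192,
)
m2v_model.save_pretrained("my-vocab-quantized-model")
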
model2vec/distill/inference.py

Lines changed: 7 additions & 5 deletions
@@ -11,8 +11,8 @@
 from sklearn.decomposition import PCA
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
-from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.modeling_utils import PreTrainedModel

 logger = logging.getLogger(__name__)

@@ -46,7 +46,7 @@ def create_embeddings(
     :param pad_token_id: The pad token id. Used to pad sequences.
     :return: The output embeddings.
     """
-    model = model.to(device)  # type: ignore
+    model = model.to(device)  # type: ignore  # Transformers error

     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
@@ -117,7 +117,7 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.

 def post_process_embeddings(
     embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
-) -> np.ndarray:
+) -> tuple[np.ndarray, np.ndarray]:
     """Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
     if pca_dims is not None:
         if pca_dims == "auto":
@@ -154,6 +154,8 @@ def post_process_embeddings(
         logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
         inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
         proba = inv_rank / np.sum(inv_rank)
-        embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]
+        weight = sif_coefficient / (sif_coefficient + proba)
+    else:
+        weight = np.ones(embeddings.shape[0])

-    return embeddings
+    return embeddings, weight

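post_process_embeddings now returns the SIF weights instead of multiplying them into the embeddings, so callers such as distill_from_model can store them separately. The weighting itself is unchanged: token probabilities p are estimated from Zipf's law (p proportional to 1/rank) and each token receives weight a / (a + p), where a is the sif_coefficient (default 1e-4). A self-contained sketch of that computation, mirroring the hunk above:

import numpy as np

def zipf_sif_weights(vocab_size: int, sif_coefficient: float = 1e-4) -> np.ndarray:
    """Estimate per-token SIF weights via Zipf's law, as in the diff above."""
    # Ranks start at 2, mirroring np.arange(2, n + 2) in the hunk.
    inv_rank = 1 / np.arange(2, vocab_size + 2)
    proba = inv_rank / inv_rank.sum()  # Zipf-style probability estimate per rank
    return sif_coefficient / (sif_coefficient + proba)

weights = zipf_sif_weights(vocab_size=32_000)
print(weights[:3], weights[-3:])  # frequent (low-rank) tokens get the smallest weights
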
model2vec/hf_utils.py

Lines changed: 73 additions & 29 deletions
@@ -9,6 +9,7 @@
 import numpy as np
 import safetensors
 from huggingface_hub import ModelCard, ModelCardData
+from huggingface_hub.constants import HF_HUB_CACHE
 from safetensors.numpy import save_file
 from tokenizers import Tokenizer

@@ -24,6 +25,8 @@ def save_pretrained(
     config: dict[str, Any],
     create_model_card: bool = True,
     subfolder: str | None = None,
+    weights: np.ndarray | None = None,
+    mapping: np.ndarray | None = None,
     **kwargs: Any,
 ) -> None:
     """
@@ -35,11 +38,20 @@
     :param config: A metadata config.
     :param create_model_card: Whether to create a model card.
     :param subfolder: The subfolder to save the model in.
+    :param weights: The weights of the model. If None, no weights are saved.
+    :param mapping: The token mapping of the model. If None, there is no token mapping.
     :param **kwargs: Any additional arguments.
     """
     folder_path = folder_path / subfolder if subfolder else folder_path
     folder_path.mkdir(exist_ok=True, parents=True)
-    save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
+
+    model_weights = {"embeddings": embeddings}
+    if weights is not None:
+        model_weights["weights"] = weights
+    if mapping is not None:
+        model_weights["mapping"] = mapping
+
+    save_file(model_weights, folder_path / "model.safetensors")
     tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
     json.dump(config, open(folder_path / "config.json", "w"), indent=4)

@@ -96,10 +108,11 @@

 def load_pretrained(
     folder_or_repo_path: str | Path,
-    subfolder: str | None = None,
-    token: str | None = None,
-    from_sentence_transformers: bool = False,
-) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
+    subfolder: str | None,
+    token: str | None,
+    from_sentence_transformers: bool,
+    force_download: bool,
+) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any], np.ndarray | None, np.ndarray | None]:
     """
     Loads a pretrained model from a folder.

@@ -109,8 +122,10 @@
     :param subfolder: The subfolder to load from.
     :param token: The huggingface token to use.
     :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
+    :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
+        already present in the cache.
     :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
-    :return: The embeddings, tokenizer, config, and metadata.
+    :return: The embeddings, tokenizer, config, metadata, weights and mapping.

     """
@@ -122,7 +137,13 @@
         tokenizer_file = "tokenizer.json"
         config_name = "config.json"

-    folder_or_repo_path = Path(folder_or_repo_path)
+    cached_folder = _get_latest_model_path(str(folder_or_repo_path))
+    if cached_folder and not force_download:
+        logger.info(f"Found cached model at {cached_folder}, loading from cache.")
+        folder_or_repo_path = cached_folder
+    else:
+        logger.info(f"No cached model found for {folder_or_repo_path}, loading from local or hub.")
+        folder_or_repo_path = Path(folder_or_repo_path)

     local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path

@@ -139,9 +160,7 @@
         if not tokenizer_path.exists():
             raise FileNotFoundError(f"Tokenizer file does not exist in {local_folder}")

-        # README is optional, so this is a bit finicky.
         readme_path = local_folder / "README.md"
-        metadata = _get_metadata_from_readme(readme_path)

     else:
         logger.info("Folder does not exist locally, attempting to use huggingface hub.")
@@ -150,18 +169,11 @@
                 folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
             )
         )
-
-        try:
-            readme_path = Path(
-                huggingface_hub.hf_hub_download(
-                    folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
-                )
+        readme_path = Path(
+            huggingface_hub.hf_hub_download(
+                folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
             )
-            metadata = _get_metadata_from_readme(Path(readme_path))
-        except Exception as e:
-            # NOTE: we don't want to raise an error here, since the README is optional.
-            logger.info(f"No README found in the model folder: {e} No model card loaded.")
-            metadata = {}
+        )

         config_path = Path(
             huggingface_hub.hf_hub_download(
@@ -175,20 +187,27 @@
     )

     opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
-    if from_sentence_transformers:
-        embeddings = opened_tensor_file.get_tensor("embedding.weight")
+    embedding_name = "embedding.weight" if from_sentence_transformers else "embeddings"
+    embeddings = opened_tensor_file.get_tensor(embedding_name)
+    try:
+        weights = opened_tensor_file.get_tensor("weights")
+    except Exception:
+        # Bare except because safetensors does not export its own errors.
+        weights = None
+    try:
+        mapping = opened_tensor_file.get_tensor("mapping")
+    except Exception:
+        mapping = None
+
+    if readme_path.exists():
+        metadata = _get_metadata_from_readme(readme_path)
     else:
-        embeddings = opened_tensor_file.get_tensor("embeddings")
+        metadata = {}

     tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
     config = json.load(open(config_path))

-    if len(tokenizer.get_vocab()) != len(embeddings):
-        logger.warning(
-            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
-        )
-
-    return embeddings, tokenizer, config, metadata
+    return embeddings, tokenizer, config, metadata, weights, mapping


 def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
@@ -223,3 +242,28 @@
     huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)

     logger.info(f"Pushed model to {repo_id}")
+
+
+def _get_latest_model_path(model_id: str) -> Path | None:
+    """
+    Gets the latest model path for a given identifier from the hugging face hub cache.
+
+    Returns None if there is no cached model. In this case, the model will be downloaded.
+    """
+    # Make path object
+    cache_dir = Path(HF_HUB_CACHE)
+    # This is specific to how HF stores the files.
+    normalized = model_id.replace("/", "--")
+    repo_dir = cache_dir / f"models--{normalized}" / "snapshots"
+
+    if not repo_dir.exists():
+        return None
+
+    # Find all directories.
+    snapshots = [p for p in repo_dir.iterdir() if p.is_dir()]
+    if not snapshots:
+        return None
+
+    # Get the latest directory by modification time.
+    latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
+    return latest_snapshot

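The cache lookup added at the bottom of this file relies on the Hugging Face hub's on-disk layout, where each repo is cached under models--<org>--<name>/snapshots/<revision>. A short sketch of what _get_latest_model_path resolves; the model id is illustrative:

from pathlib import Path
from huggingface_hub.constants import HF_HUB_CACHE

model_id = "minishlab/potion-base-8M"  # illustrative id
repo_dir = Path(HF_HUB_CACHE) / f"models--{model_id.replace('/', '--')}" / "snapshots"
# Mirror the helper: newest snapshot directory by mtime, or None when not cached.
snapshots = [p for p in repo_dir.iterdir() if p.is_dir()] if repo_dir.exists() else []
print(max(snapshots, key=lambda p: p.stat().st_mtime) if snapshots else None)

Note that load_pretrained now prefers this cached snapshot unless force_download is set, which is why the function gained that new required argument.
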
model2vec/inference/model.py

Lines changed: 15 additions & 9 deletions
@@ -3,7 +3,7 @@
 import re
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Sequence, TypeVar
+from typing import Sequence, TypeVar, Union, cast

 import huggingface_hub
 import numpy as np
@@ -273,14 +273,14 @@ def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> Non
     )


-def _is_multi_label_shaped(y: LabelType) -> bool:
+def _is_multi_label_shaped(y: list[int] | list[str] | list[list[int]] | list[list[str]]) -> bool:
     """Check if the labels are in a multi-label shape."""
     return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set))


 def evaluate_single_or_multi_label(
     predictions: np.ndarray,
-    y: LabelType,
+    y: list[int] | list[str] | list[list[int]] | list[list[str]],
     output_dict: bool = False,
 ) -> str | dict[str, dict[str, float]]:
     """
@@ -292,16 +292,22 @@
     :return: A classification report.
     """
     if _is_multi_label_shaped(y):
+        # Cast because the type checker doesn't understand that y is a list of lists.
+        y = cast(Union[list[list[str]], list[list[int]]], y)
         classes = sorted(set([label for labels in y for label in labels]))
         mlb = MultiLabelBinarizer(classes=classes)
-        y = mlb.fit_transform(y)
-        predictions = mlb.transform(predictions)
-    elif isinstance(y[0], (str, int)):
-        classes = sorted(set(y))
+        y_transformed = mlb.fit_transform(y)
+        predictions_transformed = mlb.transform(predictions)
+    else:
+        if all(isinstance(label, (str, int)) for label in y):
+            y = cast(Union[list[str], list[int]], y)
+            classes = sorted(set(y))
+        y_transformed = np.array(y)
+        predictions_transformed = np.array(predictions)

     report = classification_report(
-        y,
-        predictions,
+        y_transformed,
+        predictions_transformed,
         output_dict=output_dict,
         zero_division=0,
     )

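For reference, the refactored evaluation helper accepts both label shapes. A hedged usage sketch, assuming the function is importable from model2vec.inference.model where this diff defines it; the labels and predictions are made up for illustration:

import numpy as np
from model2vec.inference.model import evaluate_single_or_multi_label

# Single-label: y is a flat list; predictions is a matching 1-d array.
print(evaluate_single_or_multi_label(np.array(["a", "b"]), y=["a", "a"]))

# Multi-label: y is a list of label lists; predictions must be shaped the same
# way, since both go through the same MultiLabelBinarizer.
preds = np.array([["a"], ["a", "b"]], dtype=object)
print(evaluate_single_or_multi_label(preds, y=[["a"], ["b"]]))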