Commit 7bf0bf0

feat: add vocabulary quantization (#271)
* remove multiword warning
* add superbpe tokenizers
* fix: pretokenize tokens before checking vocabulary
* feat: add quantization
* wip
* wip
* wip
* fixes
* fixes
* fix issue with mwe
* wip
* wip
* wip
* wip
* wip
* wip
* fixes
* fix: refactor quantization
* fix: refactor quantization
* wip
* wip
* typing
* fixes
* fix typing/linting
* add quantization helper to top
* change init to random
* fix: annotations import
* fix test import
* import Union for 3.9
* fix: union again
* store all relevant info in safetensors
* make weights float in training
1 parent 13095c9 commit 7bf0bf0

19 files changed: +425 −100 lines changed

model2vec/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from model2vec.model import StaticModel
+from model2vec.model import StaticModel, quantize_model
 from model2vec.version import __version__
 
-__all__ = ["StaticModel", "__version__"]
+__all__ = ["StaticModel", "quantize_model", "__version__"]
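
The quantization helper is now importable directly from the package root. A minimal usage sketch, assuming `quantize_model` takes a loaded StaticModel and a target DType; the keyword name `quantize_to` is an assumption here, not something this diff shows, so check the helper's actual signature in model2vec/model.py:

from model2vec import StaticModel, quantize_model
from model2vec.quantization import DType

# Load an existing static model and quantize its embeddings in place or as a copy.
# NOTE: the `quantize_to` keyword is assumed; only the helper's existence is
# confirmed by this commit.
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
quantized = quantize_model(model, quantize_to=DType.Float16)
quantized.save_pretrained("potion-base-8M-fp16")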

model2vec/distill/distillation.py

Lines changed: 22 additions & 5 deletions
@@ -6,14 +6,17 @@
 from typing import Optional, cast
 
 import numpy as np
-from huggingface_hub import model_info
-from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
+from huggingface_hub.hf_api import model_info
+from transformers import AutoModel, AutoTokenizer
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
 from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
+from model2vec.vocabulary_quantization import quantize_vocabulary
 
 logger = logging.getLogger(__name__)
 
@@ -29,6 +32,7 @@ def distill_from_model(
     token_remove_pattern: str | None = r"\[unused\d+\]",
     quantize_to: DType | str = DType.Float16,
     use_subword: bool | None = None,
+    vocabulary_quantization: int | None = None,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -54,6 +58,7 @@ def distill_from_model(
         If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
+    :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :return: A StaticModel
     :raises: ValueError if the vocabulary is empty after preprocessing.
 
@@ -103,7 +108,6 @@ def distill_from_model(
 
     # Replace the vocabulary in the tokenizer with the new vocabulary.
     backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
-
     logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
     # Convert tokens to IDs
     token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
@@ -113,8 +117,16 @@ def distill_from_model(
         tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
     )
 
-    # Post process the embeddings by applying PCA and Zipf weighting.
-    embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
+    if vocabulary_quantization is not None:
+        _, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
+        embeddings, token_mapping, weights = quantize_vocabulary(
+            n_clusters=vocabulary_quantization, weights=weights, embeddings=np.asarray(embeddings)
+        )
+        embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
+    else:
+        # Post-process the embeddings.
+        embeddings, weights = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
+        token_mapping = None
     # Quantize the embeddings.
     embeddings = quantize_embeddings(embeddings, quantize_to)
 
@@ -148,6 +160,8 @@ def distill_from_model(
 
     return StaticModel(
         vectors=embeddings,
+        weights=weights,
+        token_mapping=token_mapping,
         tokenizer=backend_tokenizer,
         config=config,
         base_model_name=model_name,
@@ -211,6 +225,7 @@ def distill(
     trust_remote_code: bool = False,
     quantize_to: DType | str = DType.Float16,
     use_subword: bool | None = None,
+    vocabulary_quantization: int | None = None,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -235,6 +250,7 @@ def distill(
     :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
+    :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :return: A StaticModel
 
     """
@@ -255,4 +271,5 @@ def distill(
        sif_coefficient=sif_coefficient,
        quantize_to=quantize_to,
        use_subword=use_subword,
+       vocabulary_quantization=vocabulary_quantization,
    )
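
Taken together, the new argument threads from distill through distill_from_model into quantize_vocabulary. A minimal caller-side sketch; the model name and pca_dims value are only examples, and vocabulary_quantization=None (the default) skips clustering entirely:

from model2vec.distill import distill

# Distill a static model and cluster its vocabulary into 256 centroids.
m2v_model = distill(
    model_name="baai/bge-base-en-v1.5",
    pca_dims=256,
    vocabulary_quantization=256,
)
m2v_model.save_pretrained("m2v_model_quantized_vocab")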

model2vec/distill/inference.py

Lines changed: 8 additions & 6 deletions
@@ -11,8 +11,8 @@
 from sklearn.decomposition import PCA
 from torch.nn.utils.rnn import pad_sequence
 from tqdm import tqdm
-from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.modeling_utils import PreTrainedModel
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +46,7 @@ def create_embeddings(
     :param pad_token_id: The pad token id. Used to pad sequences.
     :return: The output embeddings.
     """
-    model = model.to(device)  # type: ignore
+    model = model.to(device)  # type: ignore # Transformers error
 
     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
@@ -98,7 +98,7 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
     """
     encodings = {k: v.to(model.device) for k, v in encodings.items()}
     encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
-    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore # typing is wrong.
+    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore # False positive
     # NOTE: If the dtype is bfloat 16, we convert to float32,
     # because numpy does not suport bfloat16
     # See here: https://github.com/numpy/numpy/issues/19808
@@ -116,7 +116,7 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
 
 def post_process_embeddings(
     embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
-) -> np.ndarray:
+) -> tuple[np.ndarray, np.ndarray]:
     """Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
     if pca_dims is not None:
         if pca_dims == "auto":
@@ -153,6 +153,8 @@ def post_process_embeddings(
         logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
         inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
         proba = inv_rank / np.sum(inv_rank)
-        embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]
+        weight = sif_coefficient / (sif_coefficient + proba)
+    else:
+        weight = np.ones(embeddings.shape[0])
 
-    return embeddings
+    return embeddings, weight
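
The change above splits the SIF weight out of the embedding matrix so it can be returned separately (and, with vocabulary quantization, re-applied after clustering). A standalone sketch of the same Zipf-based estimate, mirroring the lines in the hunk above; the helper name is just for illustration:

import numpy as np

def zipf_sif_weights(n_tokens: int, sif_coefficient: float = 1e-4) -> np.ndarray:
    """Estimate per-token SIF weights from rank alone, as in post_process_embeddings."""
    # Zipf's law: a token's probability is proportional to 1 / rank.
    inv_rank = 1 / np.arange(2, n_tokens + 2)
    proba = inv_rank / np.sum(inv_rank)
    # SIF downweights frequent (low-rank) tokens.
    return sif_coefficient / (sif_coefficient + proba)

weights = zipf_sif_weights(n_tokens=5)
print(weights)  # the most frequent (lowest-rank) tokens get the smallest weights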

model2vec/hf_utils.py

Lines changed: 26 additions & 11 deletions
@@ -25,6 +25,8 @@ def save_pretrained(
     config: dict[str, Any],
     create_model_card: bool = True,
     subfolder: str | None = None,
+    weights: np.ndarray | None = None,
+    mapping: np.ndarray | None = None,
     **kwargs: Any,
 ) -> None:
     """
@@ -36,11 +38,20 @@ def save_pretrained(
     :param config: A metadata config.
     :param create_model_card: Whether to create a model card.
     :param subfolder: The subfolder to save the model in.
+    :param weights: The weights of the model. If None, no weights are saved.
+    :param mapping: The token mapping of the model. If None, there is no token mapping.
     :param **kwargs: Any additional arguments.
     """
     folder_path = folder_path / subfolder if subfolder else folder_path
     folder_path.mkdir(exist_ok=True, parents=True)
-    save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
+
+    model_weights = {"embeddings": embeddings}
+    if weights is not None:
+        model_weights["weights"] = weights
+    if mapping is not None:
+        model_weights["mapping"] = mapping
+
+    save_file(model_weights, folder_path / "model.safetensors")
     tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
     json.dump(config, open(folder_path / "config.json", "w"), indent=4)
 
@@ -101,7 +112,7 @@ def load_pretrained(
     token: str | None,
     from_sentence_transformers: bool,
    force_download: bool,
-) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
+) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any], np.ndarray | None, np.ndarray | None]:
     """
     Loads a pretrained model from a folder.
 
@@ -114,7 +125,7 @@ def load_pretrained(
     :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
         already present in the cache.
     :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
-    :return: The embeddings, tokenizer, config, and metadata.
+    :return: The embeddings, tokenizer, config, metadata, weights and mapping.
 
     """
     if from_sentence_transformers:
@@ -176,8 +187,17 @@ def load_pretrained(
     )
 
     opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
-    embedding_key = "embedding.weight" if from_sentence_transformers else "embeddings"
-    embeddings = opened_tensor_file.get_tensor(embedding_key)
+    embedding_name = "embedding.weight" if from_sentence_transformers else "embeddings"
+    embeddings = opened_tensor_file.get_tensor(embedding_name)
+    try:
+        weights = opened_tensor_file.get_tensor("weights")
+    except Exception:
+        # Bare except because safetensors does not export its own errors.
+        weights = None
+    try:
+        mapping = opened_tensor_file.get_tensor("mapping")
+    except Exception:
+        mapping = None
 
     if readme_path.exists():
         metadata = _get_metadata_from_readme(readme_path)
@@ -187,12 +207,7 @@ def load_pretrained(
     tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
     config = json.load(open(config_path))
 
-    if len(tokenizer.get_vocab()) != len(embeddings):
-        logger.warning(
-            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
-        )
-
-    return embeddings, tokenizer, config, metadata
+    return embeddings, tokenizer, config, metadata, weights, mapping
 
 
 def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
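
The safetensors file now carries up to three arrays, and the loader probes for the optional ones instead of assuming they exist. A small round-trip sketch with dummy data; the path and shapes are illustrative only:

import numpy as np
import safetensors
from safetensors.numpy import save_file

# Write the embeddings plus the optional per-token weights and token mapping.
tensors = {
    "embeddings": np.random.rand(10, 4).astype(np.float32),
    "weights": np.ones(10, dtype=np.float32),
    "mapping": np.arange(10),
}
save_file(tensors, "model.safetensors")

# Read them back; a missing optional tensor simply falls back to None.
with safetensors.safe_open("model.safetensors", framework="numpy") as f:
    embeddings = f.get_tensor("embeddings")
    try:
        weights = f.get_tensor("weights")
    except Exception:
        weights = None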

model2vec/inference/model.py

Lines changed: 15 additions & 9 deletions
@@ -3,7 +3,7 @@
 import re
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Sequence, TypeVar
+from typing import Sequence, TypeVar, Union, cast
 
 import huggingface_hub
 import numpy as np
@@ -273,14 +273,14 @@ def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> Non
     )
 
 
-def _is_multi_label_shaped(y: LabelType) -> bool:
+def _is_multi_label_shaped(y: list[int] | list[str] | list[list[int]] | list[list[str]]) -> bool:
     """Check if the labels are in a multi-label shape."""
     return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set))
 
 
 def evaluate_single_or_multi_label(
     predictions: np.ndarray,
-    y: LabelType,
+    y: list[int] | list[str] | list[list[int]] | list[list[str]],
     output_dict: bool = False,
 ) -> str | dict[str, dict[str, float]]:
     """
@@ -292,16 +292,22 @@
     :return: A classification report.
     """
     if _is_multi_label_shaped(y):
+        # Cast because the type checker doesn't understand that y is a list of lists.
+        y = cast(Union[list[list[str]], list[list[int]]], y)
         classes = sorted(set([label for labels in y for label in labels]))
         mlb = MultiLabelBinarizer(classes=classes)
-        y = mlb.fit_transform(y)
-        predictions = mlb.transform(predictions)
-    elif isinstance(y[0], (str, int)):
-        classes = sorted(set(y))
+        y_transformed = mlb.fit_transform(y)
+        predictions_transformed = mlb.transform(predictions)
+    else:
+        if all(isinstance(label, (str, int)) for label in y):
+            y = cast(Union[list[str], list[int]], y)
+            classes = sorted(set(y))
+            y_transformed = np.array(y)
+            predictions_transformed = np.array(predictions)
 
     report = classification_report(
-        y,
-        predictions,
+        y_transformed,
+        predictions_transformed,
        output_dict=output_dict,
        zero_division=0,
    )
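
For reference, the multi-label branch binarizes both the gold labels and the predictions over the same class list before scoring. A minimal sketch of that path with toy data, using the same scikit-learn calls as the hunk above:

import numpy as np
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer

# Toy multi-label data: each sample can carry several labels.
y = [["cat", "dog"], ["dog"], ["cat"]]
predictions = [["cat"], ["dog"], ["cat", "dog"]]

classes = sorted({label for labels in y for label in labels})
mlb = MultiLabelBinarizer(classes=classes)
y_transformed = mlb.fit_transform(y)
predictions_transformed = mlb.transform(predictions)

# classification_report accepts the resulting binary indicator matrices.
print(classification_report(y_transformed, predictions_transformed, zero_division=0))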
