Commit 9d41dfc

Commit message: fixes
Parent: 3b61fec

8 files changed: +91 additions, -93 deletions


model2vec/distill/distillation.py

Lines changed: 9 additions & 13 deletions
@@ -7,15 +7,14 @@
 
 import numpy as np
 from huggingface_hub.hf_api import model_info
-from sklearn.cluster import KMeans
 from transformers import AutoModel, AutoTokenizer
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
-from model2vec.quantization import DType, quantize_embeddings
+from model2vec.quantization import DType, quantize_embeddings, quantize_vocabulary
 from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
 
 logger = logging.getLogger(__name__)
@@ -58,6 +57,7 @@ def distill_from_model(
         If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
+    :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :return: A StaticModel
     :raises: ValueError if the vocabulary is empty after preprocessing.
@@ -118,19 +118,14 @@ def distill_from_model(
 
     if vocabulary_quantization is not None:
         _, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
-        km = KMeans(vocabulary_quantization, random_state=42)
-        km.fit(embeddings)
-        clustered_embeddings = km.predict(embeddings)
-        mapping = {idx: int(x) for idx, x in enumerate(clustered_embeddings)}
-
-        embeddings = km.cluster_centers_
+        embeddings, token_mapping, weights = quantize_vocabulary(
+            n_clusters=vocabulary_quantization, weights=weights, embeddings=np.asarray(embeddings)
+        )
         embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
     else:
         # Post-process the embeddings.
-        embeddings, weights = post_process_embeddings(
-            np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient
-        )
-        mapping = {idx: idx for idx in range(len(all_tokens))}
+        embeddings, weights = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
+        token_mapping = None
     # Quantize the embeddings.
     embeddings = quantize_embeddings(embeddings, quantize_to)
@@ -165,7 +160,7 @@ def distill_from_model(
     return StaticModel(
         vectors=embeddings,
         weights=weights,
-        token_mapping=mapping,
+        token_mapping=token_mapping,
         tokenizer=backend_tokenizer,
         config=config,
         base_model_name=model_name,
@@ -254,6 +249,7 @@ def distill(
     :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
+    :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :return: A StaticModel
 
     """

model2vec/hf_utils.py

Lines changed: 1 addition & 5 deletions
@@ -36,6 +36,7 @@ def save_pretrained(
     :param config: A metadata config.
     :param create_model_card: Whether to create a model card.
     :param subfolder: The subfolder to save the model in.
+    :param weights: The weights of the model. If None, no weights are saved.
     :param **kwargs: Any additional arguments.
     """
     folder_path = folder_path / subfolder if subfolder else folder_path
@@ -195,11 +196,6 @@ def load_pretrained(
     tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
     config = json.load(open(config_path))
 
-    if len(tokenizer.get_vocab()) != len(embeddings):
-        logger.warning(
-            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
-        )
-
     return embeddings, tokenizer, config, metadata, weights

model2vec/model.py

Lines changed: 53 additions & 39 deletions
@@ -5,14 +5,14 @@
 from logging import getLogger
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Iterator, Sequence, Union, overload
+from typing import Any, Iterator, Sequence, Union, cast, overload
 
 import numpy as np
 from joblib import delayed
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm
 
-from model2vec.quantization import DType, quantize_and_reduce_dim, vocabulary_quantization
+from model2vec.quantization import DType, quantize_and_reduce_dim, quantize_vocabulary
 from model2vec.utils import ProgressParallel, load_local_model
 
 PathLike = Union[Path, str]
@@ -25,12 +25,12 @@ def __init__(
         self,
         vectors: np.ndarray,
         tokenizer: Tokenizer,
-        weights: np.ndarray | None = None,
-        token_mapping: dict[int, int] | None = None,
         config: dict[str, Any] | None = None,
         normalize: bool | None = None,
         base_model_name: str | None = None,
         language: list[str] | None = None,
+        weights: np.ndarray | None = None,
+        token_mapping: list[int] | None = None,
     ) -> None:
         """
         Initialize the StaticModel.
@@ -41,6 +41,12 @@ def __init__(
         :param normalize: Whether to normalize the embeddings.
         :param base_model_name: The used base model name. Used for creating a model card.
         :param language: The language of the model. Used for creating a model card.
+        :param weights: The weights to use for the embeddings. If None, no weights are used.
+            We always assume the norm of the embeddings is an implicit weight anyway.
+            This is only used for models that have undergone vocabulary quantization.
+        :param token_mapping: A mapping from token ids to indices in the vectors.
+            If None, we don't remap the tokens during inference.
+            This is only used for models that have undergone vocabulary quantization.
         :raises: ValueError if the number of tokens does not match the number of vectors.
         """
         super().__init__()
@@ -55,7 +61,9 @@ def __init__(
 
         self.embedding = vectors
         self.weights = weights
-        self.token_mapping = token_mapping
+        # Convert to an array for fast lookups
+        # We can't use or short circuit here because np.ndarray as booleans are ambiguous.
+        self.token_mapping = None if token_mapping is None else np.asarray(token_mapping)
 
         self.tokenizer = tokenizer
         self.unk_token_id: int | None
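
To illustrate why token_mapping is stored as an array: a minimal numpy sketch (toy sizes are assumptions) of the vectorized lookup this enables.

import numpy as np

centroids = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 cluster centers of dimension 3
token_mapping = np.asarray([0, 2, 2, 1, 3, 0])             # 6 vocabulary ids -> 4 centroid rows
token_ids = [1, 4, 5]
# One fancy-indexing lookup replaces the previous per-token dict.get() calls:
print(centroids[token_mapping[token_ids]])                 # rows 2, 3 and 0 of the centroid table
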
@@ -114,7 +122,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
         from model2vec.hf_utils import save_pretrained
 
         if self.token_mapping is not None:
-            self.config["token_mapping"] = list(self.token_mapping.items())
+            self.config["token_mapping"] = self.token_mapping.tolist()
 
         save_pretrained(
             folder_path=Path(path),
@@ -167,7 +175,7 @@ def from_pretrained(
         subfolder: str | None = None,
         quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
-        quantize_vocabulary: int | None = None,
+        vocabulary_quantization: int | None = None,
     ) -> StaticModel:
         """
         Load a StaticModel from a local path or huggingface hub path.
@@ -183,6 +191,7 @@ def from_pretrained(
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
            Note that this only applies if you have trained your model using mrl or PCA.
+        :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained
@@ -194,31 +203,27 @@ def from_pretrained(
             subfolder=subfolder,
         )
 
+        # Quantize the vocabulary at full precision and dimensionality
+        if vocabulary_quantization is not None:
+            embeddings, token_mapping, weights = quantize_vocabulary(
+                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
+            )
+        else:
+            token_mapping = config.pop("token_mapping", None)
+
+        # Reduce dimensionality and quantize if requested
         embeddings = quantize_and_reduce_dim(
             embeddings=embeddings,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
         )
 
-        if quantize_vocabulary is not None:
-            embeddings, token_mapping, weights = vocabulary_quantization(
-                n_clusters=quantize_vocabulary, weights=weights, embeddings=embeddings
-            )
-        else:
-            token_mapping = config.pop("token_mapping", None)
-            if isinstance(token_mapping, list):
-                # If the token mapping is a list, convert it to a dict
-                token_mapping = {int(k): int(v) for k, v in token_mapping}
-            elif token_mapping is None:
-                # If no token mapping is provided, use the default mapping
-                token_mapping = {i: i for i in range(len(embeddings))}
-
         return cls(
-            embeddings,
-            tokenizer,
-            weights,
-            token_mapping,
-            config,
+            vectors=embeddings,
+            tokenizer=tokenizer,
+            weights=weights,
+            token_mapping=token_mapping,
+            config=config,
             normalize=normalize,
             base_model_name=metadata.get("base_model"),
             language=metadata.get("language"),
@@ -232,6 +237,7 @@ def from_sentence_transformers(
         normalize: bool | None = None,
         quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
+        vocabulary_quantization: int | None = None,
     ) -> StaticModel:
         """
         Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -246,6 +252,7 @@ def from_sentence_transformers(
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
+        :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained
@@ -254,26 +261,29 @@ def from_sentence_transformers(
             folder_or_repo_path=path,
             token=token,
             from_sentence_transformers=True,
-            subfolder=None,
         )
 
+        # Quantize the vocabulary at full precision and dimensionality
+        if vocabulary_quantization is not None:
+            embeddings, token_mapping, weights = quantize_vocabulary(
+                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
+            )
+        else:
+            token_mapping = config.pop("token_mapping", None)
+
+        # Reduce dimensionality and quantize if requested
         embeddings = quantize_and_reduce_dim(
             embeddings=embeddings,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
        )
 
-        token_mapping = config.pop("token_mapping", None)
-        if token_mapping is None:
-            # If no token mapping is provided, use the default mapping
-            token_mapping = {i: i for i in range(len(embeddings))}
-
         return cls(
-            embeddings,
-            tokenizer,
-            weights,
-            token_mapping,
-            config,
+            vectors=embeddings,
+            tokenizer=tokenizer,
+            weights=weights,
+            token_mapping=token_mapping,
+            config=config,
             normalize=normalize,
             base_model_name=metadata.get("base_model"),
             language=metadata.get("language"),
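
A hedged loading sketch for the new keyword on from_pretrained and from_sentence_transformers; the repo id, cluster count, and target dtype are illustrative assumptions, not part of the diff.

# Hypothetical usage (repo id, cluster count and dtype are assumptions):
from model2vec import StaticModel

model = StaticModel.from_pretrained(
    "minishlab/potion-base-8M",
    vocabulary_quantization=4096,  # cluster the vocabulary at full precision and dimensionality
    quantize_to="int8",            # then quantize the (much smaller) centroid table
)
embeddings = model.encode(["It's dangerous to go alone!"])
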
@@ -446,10 +456,11 @@ def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.
         out: list[np.ndarray] = []
         for id_list in ids:
             if id_list:
+                id_list_remapped: list[int] | np.ndarray
                 if self.token_mapping is None:
                     id_list_remapped = id_list
                 else:
-                    id_list_remapped = [self.token_mapping.get(token_id, token_id) for token_id in id_list]
+                    id_list_remapped = self.token_mapping[id_list]
                 emb = self.embedding[id_list_remapped]
                 if self.weights is not None:
                     emb = emb * self.weights[id_list][:, None]
@@ -512,6 +523,9 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
         if not path.is_dir():
             raise ValueError(f"Path {path} is not a directory.")
 
-        embeddings, tokenizer, config = load_local_model(path)
+        embeddings, tokenizer, config, weights = load_local_model(path)
+        token_mapping = cast(list[int], config.pop("token_mapping", None))
 
-        return StaticModel(embeddings, tokenizer, config=config)
+        return StaticModel(
+            vectors=embeddings, tokenizer=tokenizer, config=config, weights=weights, token_mapping=token_mapping
+        )
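
A numpy sketch (toy shapes are assumptions) of the encode-time math after this change: embeddings are looked up through the remapped ids, while the per-token weights stay indexed by the original token ids.

import numpy as np

rng = np.random.default_rng(0)
centroids = rng.normal(size=(4, 3)).astype(np.float32)  # quantized embedding table
weights = rng.random(6).astype(np.float32)              # one weight per original token
mapping = np.asarray([0, 2, 2, 1, 3, 0])                # token id -> centroid row

id_list = [1, 4, 5]                                      # ids produced by the tokenizer
emb = centroids[mapping[id_list]] * weights[id_list][:, None]
sentence_vector = emb.mean(axis=0)                       # mean over the token vectors
print(sentence_vector.shape)                             # (3,)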

model2vec/quantization.py

Lines changed: 4 additions & 3 deletions
@@ -64,9 +64,9 @@ def quantize_and_reduce_dim(
     return embeddings
 
 
-def vocabulary_quantization(
+def quantize_vocabulary(
     n_clusters: int, weights: np.ndarray | None, embeddings: np.ndarray
-) -> tuple[np.ndarray, dict[int, int], np.ndarray]:
+) -> tuple[np.ndarray, list[int], np.ndarray]:
     """Quantize the vocabulary of embeddings using KMeans clustering."""
     # If the model does not have weights, we assume the norm to be informative.
     if weights is None:
@@ -80,7 +80,8 @@ def vocabulary_quantization(
     kmeans = KMeans(n_clusters=n_clusters, random_state=42)
     kmeans.fit(embeddings)
     # Create a mapping from the original token index to the cluster index
-    token_mapping = {idx: x for idx, x in enumerate(kmeans.predict(embeddings))}
+    # Make sure to convert to list, otherwise we get np.int32 which is not jsonable.
+    token_mapping = cast(list[int], kmeans.predict(embeddings).tolist())
     # The cluster centers are the new embeddings.
     embeddings = kmeans.cluster_centers_
 
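
A standalone sketch of what quantize_vocabulary does, pieced together from this hunk; the norm-based weight handling is an assumption inferred from the comment in the diff, and all sizes are toy values.

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1000, 32)).astype(np.float32)  # original vocabulary vectors
weights = np.linalg.norm(embeddings, axis=1)                 # assumption: norms act as implicit weights
embeddings = embeddings / weights[:, None]                   # assumption: cluster unit-norm vectors

kmeans = KMeans(n_clusters=64, random_state=42).fit(embeddings)
token_mapping = kmeans.predict(embeddings).tolist()          # token index -> cluster index (JSON-safe list)
centroids = kmeans.cluster_centers_                          # the new, much smaller embedding table
print(centroids.shape, len(token_mapping))                   # (64, 32) 1000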

model2vec/train/base.py

Lines changed: 4 additions & 2 deletions
@@ -33,6 +33,8 @@ def __init__(
         :param tokenizer: The tokenizer.
         :param out_dim: The output dimension of the head.
         :param pad_id: The padding id. This is set to 0 in almost all model2vec models
+        :param token_mapping: The token mapping. If None, the token mapping is set to the range of the number of vectors.
+        :param weights: The weights of the model. If None, the weights are initialized to zeros.
         """
         super().__init__()
         self.pad_id = pad_id
@@ -82,7 +84,7 @@ def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int
         weights = torch.from_numpy(model.weights) if model.weights is not None else None
         embeddings_converted = torch.from_numpy(model.embedding)
         if model.token_mapping is not None:
-            token_mapping = [i for _, i in sorted(model.token_mapping.items(), key=lambda x: x[0])]
+            token_mapping = model.token_mapping.tolist()
         else:
             token_mapping = None
         return cls(
@@ -148,7 +150,7 @@ def to_static_model(self) -> StaticModel:
         """Convert the model to a static model."""
         emb = self.embeddings.weight.detach().cpu().numpy()
         w = torch.sigmoid(self.w).detach().cpu().numpy()
-        token_mapping = {i: int(token_id) for i, token_id in enumerate(self.token_mapping.tolist())}
+        token_mapping = self.token_mapping.tolist()
 
         return StaticModel(
             vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=True, token_mapping=token_mapping
