Commit 3880632

fix typing/linting
1 parent 9d41dfc commit 3880632

10 files changed: +265 -103 lines changed

model2vec/distill/distillation.py

Lines changed: 2 additions & 1 deletion

@@ -14,8 +14,9 @@
 from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
-from model2vec.quantization import DType, quantize_embeddings, quantize_vocabulary
+from model2vec.quantization import DType, quantize_embeddings
 from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
+from model2vec.vocabulary_quantization import quantize_vocabulary
 
 logger = logging.getLogger(__name__)
 
model2vec/model.py

Lines changed: 136 additions & 71 deletions

@@ -12,7 +12,7 @@
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm
 
-from model2vec.quantization import DType, quantize_and_reduce_dim, quantize_vocabulary
+from model2vec.quantization import DType
 from model2vec.utils import ProgressParallel, load_local_model
 
 PathLike = Union[Path, str]

@@ -63,7 +63,7 @@ def __init__(
         self.weights = weights
         # Convert to an array for fast lookups
         # We can't use or short circuit here because np.ndarray as booleans are ambiguous.
-        self.token_mapping = None if token_mapping is None else np.asarray(token_mapping)
+        self.token_mapping: np.ndarray | None = None if token_mapping is None else np.asarray(token_mapping)
 
         self.tokenizer = tokenizer
         self.unk_token_id: int | None

@@ -194,39 +194,16 @@ def from_pretrained(
         :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
-        from model2vec.hf_utils import load_pretrained
-
-        embeddings, tokenizer, config, metadata, weights = load_pretrained(
-            folder_or_repo_path=path,
+        return _loading_helper(
+            cls=cls,
+            path=path,
             token=token,
-            from_sentence_transformers=False,
-            subfolder=subfolder,
-        )
-
-        # Quantize the vocabulary at full precision and dimensionality
-        if vocabulary_quantization is not None:
-            embeddings, token_mapping, weights = quantize_vocabulary(
-                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
-            )
-        else:
-            token_mapping = config.pop("token_mapping", None)
-
-        # Reduce dimensionality and quantize if requested
-        embeddings = quantize_and_reduce_dim(
-            embeddings=embeddings,
+            vocabulary_quantization=vocabulary_quantization,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
-        )
-
-        return cls(
-            vectors=embeddings,
-            tokenizer=tokenizer,
-            weights=weights,
-            token_mapping=token_mapping,
-            config=config,
+            from_sentence_transformers=False,
             normalize=normalize,
-            base_model_name=metadata.get("base_model"),
-            language=metadata.get("language"),
+            subfolder=subfolder,
         )
 
     @classmethod

@@ -255,38 +232,16 @@ def from_sentence_transformers(
         :param vocabulary_quantization: The number of clusters to use for vocabulary quantization.
         :return: A StaticModel.
         """
-        from model2vec.hf_utils import load_pretrained
-
-        embeddings, tokenizer, config, metadata, weights = load_pretrained(
-            folder_or_repo_path=path,
+        return _loading_helper(
+            cls=cls,
+            path=path,
             token=token,
-            from_sentence_transformers=True,
-        )
-
-        # Quantize the vocabulary at full precision and dimensionality
-        if vocabulary_quantization is not None:
-            embeddings, token_mapping, weights = quantize_vocabulary(
-                n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
-            )
-        else:
-            token_mapping = config.pop("token_mapping", None)
-
-        # Reduce dimensionality and quantize if requested
-        embeddings = quantize_and_reduce_dim(
-            embeddings=embeddings,
+            vocabulary_quantization=vocabulary_quantization,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
-        )
-
-        return cls(
-            vectors=embeddings,
-            tokenizer=tokenizer,
-            weights=weights,
-            token_mapping=token_mapping,
-            config=config,
+            from_sentence_transformers=True,
             normalize=normalize,
-            base_model_name=metadata.get("base_model"),
-            language=metadata.get("language"),
+            subfolder=None,
         )
 
     @overload

@@ -381,7 +336,7 @@ def _encode_batch_as_sequence(self, sentences: Sequence[str], max_length: int |
         out: list[np.ndarray] = []
         for id_list in ids:
             if id_list:
-                out.append(self.embedding[id_list])
+                out.append(self._encode_helper(id_list))
             else:
                 out.append(np.zeros((0, self.dim)))
 

@@ -450,23 +405,35 @@ def encode(
             return out_array[0]
         return out_array
 
+    def _encode_helper(self, id_list: list[int]) -> np.ndarray:
+        """
+        Helper function to encode a list of ids.
+
+        This function is used to deduplicate the logic in `encode` and `encode_as_sequence`.
+        It retrieves the embeddings for the given list of ids, applying weights if available.
+
+        :param id_list: A list of token ids.
+        :return: The embeddings for the given ids, as a sequence of vectors.
+        """
+        id_list_remapped: list[int] | np.ndarray
+        if self.token_mapping is None:
+            id_list_remapped = id_list
+        else:
+            id_list_remapped = self.token_mapping[id_list]
+        emb = self.embedding[id_list_remapped]
+        if self.weights is not None:
+            emb = emb * self.weights[id_list][:, None]
+
+        return emb
+
     def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
         """Encode a batch of sentences."""
         ids = self.tokenize(sentences=sentences, max_length=max_length)
         out: list[np.ndarray] = []
         for id_list in ids:
             if id_list:
-                id_list_remapped: list[int] | np.ndarray
-                if self.token_mapping is None:
-                    id_list_remapped = id_list
-                else:
-                    id_list_remapped = self.token_mapping[id_list]
-                emb = self.embedding[id_list_remapped]
-                if self.weights is not None:
-                    emb = emb * self.weights[id_list][:, None]
-                emb = emb.mean(axis=0)
-
-                out.append(emb)
+                emb = self._encode_helper(id_list)
+                out.append(emb.mean(axis=0))
             else:
                 out.append(np.zeros(self.dim))
 

@@ -529,3 +496,101 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
         return StaticModel(
             vectors=embeddings, tokenizer=tokenizer, config=config, weights=weights, token_mapping=token_mapping
         )
+
+
+def quantize_model(
+    model: StaticModel,
+    vocabulary_quantization: int | None = None,
+    quantize_to: str | DType | None = None,
+    dimensionality: int | None = None,
+) -> StaticModel:
+    """
+    Quantize the model to a lower precision and possibly lower dimensionality.
+
+    :param model: The model to quantize.
+    :param vocabulary_quantization: The number of clusters to use for quantization.
+    :param quantize_to: The dtype to quantize the model to.
+    :param dimensionality: The desired dimensionality of the model.
+        This needs to be < than the current model dimensionality.
+    :return: A new StaticModel with the quantized embeddings.
+    :raises: ValueError if the model is already quantized.
+    """
+    from model2vec.quantization import quantize_and_reduce_dim
+
+    token_mapping: list[int] | None
+    weights: np.ndarray | None
+    if vocabulary_quantization is not None:
+        from model2vec.vocabulary_quantization import quantize_vocabulary
+
+        if len(model.tokens) != len(model.embedding):
+            raise ValueError("Model already has been vocabulary quantized, cannot quantize again.")
+
+        embeddings, token_mapping, weights = quantize_vocabulary(
+            n_clusters=vocabulary_quantization, weights=model.weights, embeddings=model.embedding
+        )
+    else:
+        embeddings = model.embedding
+        token_mapping = cast(list[int], model.token_mapping.tolist()) if model.token_mapping is not None else None
+        weights = model.weights
+    if quantize_to is not None or dimensionality is not None:
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )
+
+    return StaticModel(
+        vectors=embeddings,
+        tokenizer=model.tokenizer,
+        config=model.config,
+        weights=weights,
+        token_mapping=token_mapping,
+        normalize=model.normalize,
+        base_model_name=model.base_model_name,
+        language=model.language,
+    )
+
+
+def _loading_helper(
+    cls: type[StaticModel],
+    path: PathLike,
+    token: str | None,
+    vocabulary_quantization: int | None = None,
+    quantize_to: str | DType | None = None,
+    dimensionality: int | None = None,
+    from_sentence_transformers: bool = False,
+    normalize: bool | None = None,
+    subfolder: str | None = None,
+) -> StaticModel:
+    """Helper function to load a model from a directory."""
+    from model2vec.hf_utils import load_pretrained
+
+    if from_sentence_transformers and subfolder is not None:
+        raise ValueError("Subfolder is not supported for sentence transformers models.")
+
+    embeddings, tokenizer, config, metadata, weights = load_pretrained(
+        folder_or_repo_path=path,
+        token=token,
+        from_sentence_transformers=from_sentence_transformers,
+        subfolder=subfolder,
+    )
+
+    token_mapping = config.pop("token_mapping", None)
+
+    model = cls(
+        vectors=embeddings,
+        tokenizer=tokenizer,
+        weights=weights,
+        token_mapping=token_mapping,
+        config=config,
+        normalize=normalize,
+        base_model_name=metadata.get("base_model"),
+        language=metadata.get("language"),
+    )
+
+    return quantize_model(
+        model=model,
+        vocabulary_quantization=vocabulary_quantization,
+        quantize_to=quantize_to,
+        dimensionality=dimensionality,
    )
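
A minimal usage sketch of the refactored loading path, using only names that appear in this diff (StaticModel.from_pretrained, quantize_model and their parameters); the repo id and the numeric settings are illustrative placeholders, not part of the commit:

from model2vec.model import StaticModel, quantize_model

# Vocabulary quantization now happens while loading: from_pretrained delegates to
# _loading_helper, which constructs the model and then runs quantize_model on it.
model = StaticModel.from_pretrained(
    "minishlab/potion-base-8M",   # placeholder repo id
    vocabulary_quantization=256,  # cluster the vocabulary into 256 centroids
    quantize_to="float16",
    dimensionality=64,
)

# The same post-processing is available for a model that is already in memory.
smaller = quantize_model(model=model, quantize_to="float16", dimensionality=32)
vectors = smaller.encode(["an example sentence"])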

model2vec/quantization.py

Lines changed: 3 additions & 25 deletions

@@ -1,10 +1,12 @@
 from __future__ import annotations
 
+import logging
 from enum import Enum
-from typing import cast
 
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 class DType(str, Enum):
     Float16 = "float16"

@@ -62,27 +64,3 @@ def quantize_and_reduce_dim(
         embeddings = embeddings[:, :dimensionality]
 
     return embeddings
-
-
-def quantize_vocabulary(
-    n_clusters: int, weights: np.ndarray | None, embeddings: np.ndarray
-) -> tuple[np.ndarray, list[int], np.ndarray]:
-    """Quantize the vocabulary of embeddings using KMeans clustering."""
-    # If the model does not have weights, we assume the norm to be informative.
-    if weights is None:
-        weights = cast(np.ndarray, np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-32)
-    # Divide by the norm to normalize the embeddings, so we don't bias the clustering.
-    embeddings = embeddings / weights
-
-    # Quantize the vocabulary
-    from sklearn.cluster import KMeans
-
-    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
-    kmeans.fit(embeddings)
-    # Create a mapping from the original token index to the cluster index
-    # Make sure to convert to list, otherwise we get np.int32 which is not jsonable.
-    token_mapping = cast(list[int], kmeans.predict(embeddings).tolist())
-    # The cluster centers are the new embeddings.
-    embeddings = kmeans.cluster_centers_
-
-    return embeddings, token_mapping, weights
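
With quantize_vocabulary moved to its own module, only dtype and dimensionality reduction remain here. A small sketch of calling the remaining helper directly, with an arbitrary random matrix standing in for real embeddings (the printed shape/dtype is what the code above implies, stated as an expectation rather than a guarantee):

import numpy as np
from model2vec.quantization import DType, quantize_and_reduce_dim

emb = np.random.rand(30_000, 256).astype(np.float32)
# Cast to float16 and keep only the first 64 dimensions.
small = quantize_and_reduce_dim(embeddings=emb, quantize_to=DType.Float16, dimensionality=64)
print(small.shape, small.dtype)  # expected: (30000, 64) float16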

model2vec/utils.py

Lines changed: 0 additions & 4 deletions

@@ -125,8 +125,4 @@ def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str
 
     tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
 
-    if len(tokenizer.get_vocab()) != len(embeddings):
-        logger.warning(
-            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
-        )
     return embeddings, tokenizer, config, weights

model2vec/vocabulary_quantization.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import logging
+from typing import cast
+
+import numpy as np
+
+# Lazy import
+try:
+    from sklearn.cluster import KMeans
+except ImportError:
+    raise ImportError(
+        "scikit-learn is required for quantizing the vocabulary. "
+        "Please install model2vec with the quantization extra."
+    )
+
+
+logger = logging.getLogger(__name__)
+
+
+def quantize_vocabulary(
+    n_clusters: int, weights: np.ndarray | None, embeddings: np.ndarray
+) -> tuple[np.ndarray, list[int], np.ndarray]:
+    """Quantize the vocabulary of embeddings using KMeans clustering."""
+    logger.info(f"Quantizing vocabulary to {n_clusters} clusters.")
+    # If the model does not have weights, we assume the norm to be informative.
+    if weights is None:
+        weights = cast(np.ndarray, np.linalg.norm(embeddings, axis=1) + 1e-32)
+    # Divide by the norm to normalize the embeddings, so we don't bias the clustering.
+    embeddings = embeddings / weights[:, None]
+
+    # Ensure the embeddings are in float32 for KMeans
+    # Store the original dtype to restore it later
+    orig_dtype = embeddings.dtype
+
+    kmeans = KMeans(n_clusters=n_clusters, random_state=42, init="k-means++")
+    cast_embeddings = embeddings.astype(np.float32)
+    # Fit KMeans to the embeddings
+    kmeans.fit(cast_embeddings)
+    # Create a mapping from the original token index to the cluster index
+    # Make sure to convert to list, otherwise we get np.int32 which is not jsonable.
+    token_mapping = cast(list[int], kmeans.predict(cast_embeddings).tolist())
+    # The cluster centers are the new embeddings.
+    # Convert them back to the original dtype
+    embeddings = kmeans.cluster_centers_.astype(orig_dtype)
+
+    return embeddings, token_mapping, weights
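
A toy sketch of the new function's contract, with an arbitrary random matrix standing in for real token embeddings (shapes and the cluster count are illustrative; scikit-learn, i.e. the new quantization extra, must be installed):

import numpy as np
from model2vec.vocabulary_quantization import quantize_vocabulary

emb = np.random.rand(1_000, 32).astype(np.float16)  # 1_000 "tokens", 32 dims
centroids, token_mapping, weights = quantize_vocabulary(n_clusters=16, weights=None, embeddings=emb)

# centroids: (16, 32) cluster centers, cast back to the original dtype (float16 here)
# token_mapping: a plain Python list of 1_000 cluster indices, one per original token
# weights: (1_000,) norms, used later to re-scale each token's centroid at encode time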

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@ onnx = ["onnx", "torch"]
 train = ["torch", "lightning", "scikit-learn", "skops"]
 inference = ["scikit-learn", "skops"]
 tokenizer = ["transformers"]
+quantization = ["scikit-learn"]
 
 [project.urls]
 "Homepage" = "https://github.com/MinishLab"

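With the new optional dependency group, the scikit-learn requirement for vocabulary quantization should be installable as an extra; the spelling below assumes the standard pip extras syntax for this package:

pip install "model2vec[quantization]"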