
Commit b8da1b2

Merge branch 'bump-version' of https://github.com/MinishLab/model2vec into bump-version
2 parents: 194508e + e3f58b9

File tree: 12 files changed, +1240 -634 lines


Makefile

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ install:
 
 install-no-pre-commit:
 	uv pip install ".[dev,distill]"
+	uv pip install "torch<2.5.0"
 
 install-base:
 	uv sync --extra dev

README.md

Lines changed: 167 additions & 131 deletions
Large diffs are not rendered by default.

model2vec/distill/distillation.py

Lines changed: 41 additions & 11 deletions
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
 import logging
+import re
 from typing import Literal, Union
 
 import numpy as np
 from huggingface_hub import model_info
 from sklearn.decomposition import PCA
+from tokenizers import Tokenizer
 from tokenizers.models import BPE, Unigram
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
 
@@ -39,6 +41,7 @@ def distill_from_model(
     pca_dims: PCADimType = 256,
     apply_zipf: bool = True,
     use_subword: bool = True,
+    token_remove_pattern: str | None = r"\[unused\d+\]",
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -58,8 +61,12 @@ def distill_from_model(
         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
+    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
+        If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
     :raises: ValueError if the vocabulary contains duplicate tokens.
+    :raises: ValueError if the regex can't be compiled.
+    :raises: ValueError if the vocabulary is empty after token removal.
     :return: A StaticModel
 
     """
@@ -81,17 +88,7 @@ def distill_from_model(
     if use_subword:
         # Create the subword embeddings.
         tokens, embeddings = create_output_embeddings_from_model_name(model=model, tokenizer=tokenizer, device=device)
-
-        # Remove any unused tokens from the tokenizer and embeddings.
-        wrong_tokens = [x for x in tokens if x.startswith("[unused")]
-        vocab = tokenizer.get_vocab()
-        # Get the ids of the unused token.
-        wrong_token_ids = [vocab[token] for token in wrong_tokens]
-        # Remove the unused tokens from the tokenizer.
-        new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
-        # Remove the embeddings of the unused tokens.
-        embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
-        logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
+        new_tokenizer, embeddings = _remove_tokens_and_embeddings(tokenizer, token_remove_pattern, tokens, embeddings)
     else:
         # We need to keep the unk token in the tokenizer.
         unk_token = tokenizer.backend_tokenizer.model.unk_token
@@ -136,6 +133,8 @@ def distill_from_model(
     model_name = getattr(model, "name_or_path", "")
 
     config = {
+        "model_type": "model2vec",
+        "architectures": ["StaticModel"],
         "tokenizer_name": model_name,
         "apply_pca": pca_dims,
         "apply_zipf": apply_zipf,
@@ -155,6 +154,37 @@ def distill_from_model(
     )
 
 
+def _remove_tokens_and_embeddings(
+    tokenizer: PreTrainedTokenizerFast, token_remove_pattern: str | None, tokens: list[str], embeddings: np.ndarray
+) -> tuple[Tokenizer, np.ndarray]:
+    if not token_remove_pattern:
+        return tokenizer.backend_tokenizer, embeddings
+
+    try:
+        token_regex = re.compile(token_remove_pattern)
+    except re.error as e:
+        raise ValueError(f"Invalid regex pattern: {token_remove_pattern}") from e
+    # Remove any tokens matching the pattern from the tokenizer and embeddings.
+    wrong_tokens = [x for x in tokens if token_regex.match(x)]
+    vocab = tokenizer.get_vocab()
+    # Get the ids of the unused tokens.
+    wrong_token_ids = [vocab[token] for token in wrong_tokens]
+
+    if len(wrong_token_ids) == len(vocab):
+        raise ValueError(
+            "All tokens in the vocabulary are unused tokens. This will result in an empty tokenizer. "
+            f"Please provide a valid token removal pattern. The pattern is now: {token_remove_pattern}"
+        )
+
+    # Remove the unused tokens from the tokenizer.
+    new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
+    # Remove the embeddings of the unused tokens.
+    embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
+    logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
+
+    return new_tokenizer, embeddings
+
+
 def distill(
     model_name: str,
     vocabulary: list[str] | None = None,
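
Note: a minimal usage sketch of the new parameter (not part of the commit). The checkpoint name is illustrative, and the model/tokenizer keyword arguments are assumed from the function body shown above.

from transformers import AutoModel, AutoTokenizer

from model2vec.distill.distillation import distill_from_model

model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# The default pattern reproduces the old hard-coded behaviour of stripping
# BERT-style "[unused0]" placeholders; pass None to keep every token, or a
# broader regex to prune more aggressively.
static_model = distill_from_model(
    model=model,
    tokenizer=tokenizer,
    pca_dims=256,
    token_remove_pattern=r"\[unused\d+\]",
)

A pattern that fails to compile, or one that matches the entire vocabulary, now raises a ValueError instead of silently producing a broken tokenizer.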

model2vec/distill/tokenizer.py

Lines changed: 4 additions & 2 deletions
@@ -68,10 +68,12 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenize
         tokenizer_data["model"]["vocab"] = reindexed
 
     elif model_type == "Unigram":
-        raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
+        logger.warning("Removing tokens from a unigram tokenizer is not supported.")
+        return tokenizer
 
     elif model_type == "BPE":
-        raise ValueError("Removing tokens from a BPE tokenizer is not supported.")
+        logger.warning("Removing tokens from a BPE tokenizer is not supported.")
+        return tokenizer
 
     else:
         raise ValueError(f"Unknown model type {model_type}")
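
Note: a sketch of the behavioural change, under the assumption (suggested by the context lines above) that the WordPiece branch of remove_tokens actually prunes and reindexes the vocabulary.

from tokenizers import Tokenizer

from model2vec.distill.tokenizer import remove_tokens

# WordPiece backend: tokens are removed and the vocabulary is reindexed.
wordpiece = Tokenizer.from_pretrained("bert-base-uncased")
pruned = remove_tokens(wordpiece, ["[unused0]", "[unused1]"])

# BPE (or Unigram) backend: previously a ValueError, now a logged warning,
# and the tokenizer comes back unchanged, so distillation can proceed.
bpe = Tokenizer.from_pretrained("roberta-base")
unchanged = remove_tokens(bpe, ["<some_token>"])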

model2vec/model.py

Lines changed: 58 additions & 19 deletions
@@ -1,20 +1,21 @@
 from __future__ import annotations
 
 import math
+import os
 from logging import getLogger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any, Iterator, Union
 
 import numpy as np
+from joblib import delayed
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm
 
-from model2vec.utils import load_local_model
+from model2vec.utils import ProgressParallel, load_local_model
 
 PathLike = Union[Path, str]
 
-
 logger = getLogger(__name__)
 
 
@@ -171,6 +172,8 @@ def encode_as_sequence(
         max_length: int | None = None,
         batch_size: int = 1024,
         show_progress_bar: bool = False,
+        use_multiprocessing: bool = True,
+        multiprocessing_threshold: int = 10_000,
     ) -> list[np.ndarray] | np.ndarray:
         """
         Encode a list of sentences as a list of numpy arrays of tokens.
@@ -186,24 +189,42 @@ def encode_as_sequence(
             If this is None, no truncation is done.
         :param batch_size: The batch size to use.
         :param show_progress_bar: Whether to show the progress bar.
+        :param use_multiprocessing: Whether to use multiprocessing.
+            By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
+        :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
         :return: The encoded sentences with an embedding per token.
         """
         was_single = False
         if isinstance(sentences, str):
             sentences = [sentences]
             was_single = True
 
-        out_array: list[np.ndarray] = []
-        for batch in tqdm(
-            self._batch(sentences, batch_size),
-            total=math.ceil(len(sentences) / batch_size),
-            disable=not show_progress_bar,
-        ):
-            out_array.extend(self._encode_batch_as_sequence(batch, max_length))
+        # Prepare all batches
+        sentence_batches = list(self._batch(sentences, batch_size))
+        total_batches = math.ceil(len(sentences) / batch_size)
+
+        # Use joblib for multiprocessing if requested, and if we have enough sentences
+        if use_multiprocessing and len(sentences) > multiprocessing_threshold:
+            # Disable parallelism for tokenizers
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+            results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
+                delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in sentence_batches
+            )
+            out_array: list[np.ndarray] = []
+            for r in results:
+                out_array.extend(r)
+        else:
+            out_array = []
+            for batch in tqdm(
+                sentence_batches,
+                total=total_batches,
+                disable=not show_progress_bar,
+            ):
+                out_array.extend(self._encode_batch_as_sequence(batch, max_length))
 
         if was_single:
             return out_array[0]
-
         return out_array
 
     def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None) -> list[np.ndarray]:
@@ -224,6 +245,8 @@ def encode(
         show_progress_bar: bool = False,
         max_length: int | None = 512,
         batch_size: int = 1024,
+        use_multiprocessing: bool = True,
+        multiprocessing_threshold: int = 10_000,
         **kwargs: Any,
     ) -> np.ndarray:
         """
@@ -237,6 +260,9 @@ def encode(
         :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
             If this is None, no truncation is done.
         :param batch_size: The batch size to use.
+        :param use_multiprocessing: Whether to use multiprocessing.
+            By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
+        :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
         :param **kwargs: Any additional arguments. These are ignored.
         :return: The encoded sentences. If a single sentence was passed, a vector is returned.
         """
@@ -245,19 +271,32 @@ def encode(
             sentences = [sentences]
             was_single = True
 
-        out_arrays: list[np.ndarray] = []
-        for batch in tqdm(
-            self._batch(sentences, batch_size),
-            total=math.ceil(len(sentences) / batch_size),
-            disable=not show_progress_bar,
-        ):
-            out_arrays.append(self._encode_batch(batch, max_length))
+        # Prepare all batches
+        sentence_batches = list(self._batch(sentences, batch_size))
+        total_batches = math.ceil(len(sentences) / batch_size)
 
-        out_array = np.concatenate(out_arrays, axis=0)
+        # Use joblib for multiprocessing if requested, and if we have enough sentences
+        if use_multiprocessing and len(sentences) > multiprocessing_threshold:
+            # Disable parallelism for tokenizers
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+            results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
+                delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
+            )
+            out_array = np.concatenate(results, axis=0)
+        else:
+            # Don't use multiprocessing
+            out_arrays: list[np.ndarray] = []
+            for batch in tqdm(
+                sentence_batches,
+                total=total_batches,
+                disable=not show_progress_bar,
+            ):
+                out_arrays.append(self._encode_batch(batch, max_length))
+            out_array = np.concatenate(out_arrays, axis=0)
 
         if was_single:
             return out_array[0]
-
         return out_array
 
     def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
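
Note: a hedged usage example of the new encode flags; the model id is one of MinishLab's published models and is illustrative here.

from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/M2V_base_output")
sentences = ["an example sentence"] * 50_000

# 50_000 > 10_000 (the default multiprocessing_threshold), so batches are
# fanned out across all cores via joblib, with TOKENIZERS_PARALLELISM
# disabled to avoid nested parallelism inside the workers.
embeddings = model.encode(sentences, show_progress_bar=True)

# Force the sequential tqdm path, e.g. where forking worker processes is
# undesirable.
embeddings = model.encode(sentences, use_multiprocessing=False)

The threshold is exclusive: multiprocessing only engages for inputs strictly larger than multiprocessing_threshold sentences.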

model2vec/utils.py

Lines changed: 42 additions & 2 deletions
@@ -1,18 +1,56 @@
 # -*- coding: utf-8 -*-
+from __future__ import annotations
+
 import json
 import logging
+import re
 from importlib import import_module
 from importlib.metadata import metadata
 from pathlib import Path
-from typing import Iterator, Protocol, cast
+from typing import Any, Iterator, Protocol, cast
 
 import numpy as np
 import safetensors
+from joblib import Parallel
 from tokenizers import Tokenizer
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
 
+class ProgressParallel(Parallel):
+    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""
+
+    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
+        """
+        Initialize the ProgressParallel object.
+
+        :param use_tqdm: Whether to show the progress bar.
+        :param total: Total number of tasks (batches) you expect to process. If None,
+            it updates the total dynamically to the number of dispatched tasks.
+        :param *args: Additional arguments to pass to `Parallel.__init__`.
+        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
+        """
+        self._use_tqdm = use_tqdm
+        self._total = total
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        """Create a tqdm context."""
+        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
+            return super().__call__(*args, **kwargs)
+
+    def print_progress(self) -> None:
+        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
+        if self._total is None:
+            # If no fixed total was given, we dynamically set the total
+            self._pbar.total = self.n_dispatched_tasks
+        # Move the bar to the number of completed tasks
+        self._pbar.n = self.n_completed_tasks
+        self._pbar.refresh()
+
+
 class SafeOpenProtocol(Protocol):
     """Protocol to fix safetensors safe open."""
 
@@ -22,6 +60,7 @@ def get_tensor(self, key: str) -> np.ndarray:
 
 
 _MODULE_MAP = (("scikit-learn", "sklearn"),)
+_DIVIDERS = re.compile(r"[=<>!]+")
 
 
 def get_package_extras(package: str, extra: str) -> Iterator[str]:
@@ -38,7 +77,8 @@ def get_package_extras(package: str, extra: str) -> Iterator[str]:
     # Extract and clean the extra requirement
     found_extra = rest[0].split("==")[-1].strip(" \"'")
     if found_extra == extra:
-        yield name.strip()
+        prefix, *_ = _DIVIDERS.split(name)
+        yield prefix.strip()
 
 
 def importable(module: str, extra: str) -> None:
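
Note: ProgressParallel is a thin subclass of joblib.Parallel, so it can be exercised on its own; a self-contained sketch (math.sqrt is just a stand-in workload):

from math import sqrt

from joblib import delayed

from model2vec.utils import ProgressParallel

# Call it exactly like joblib.Parallel; use_tqdm/total only control the bar.
# joblib invokes print_progress after each completed task, which is where
# the bar advances to n_completed_tasks.
results = ProgressParallel(n_jobs=2, use_tqdm=True, total=100)(
    delayed(sqrt)(i) for i in range(100)
)

The _DIVIDERS change is related housekeeping: a pinned requirement such as "torch<2.5.0" is now split on version operators, so get_package_extras yields the bare package name "torch" instead of the full specifier.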

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -24,6 +24,7 @@ classifiers = [
 
 dependencies = [
     "jinja2",
+    "joblib",
     "numpy",
     "rich",
     "safetensors",
@@ -50,11 +51,10 @@ dev = [
     "mypy",
     "pre-commit",
     "pytest",
-    "pytest-coverage",
+    "pytest-cov",
     "ruff",
 ]
 distill = ["torch", "transformers", "scikit-learn"]
-
 onnx = ["onnx", "torch"]
 
 [project.urls]
