Commit 4f6b716

chore: Deprecate apply_zipf and use_subword parameters (#289)
* Deprecated apply_zipf and use_subword parameters
1 parent 0609f5e commit 4f6b716

File tree

2 files changed: 40 additions & 107 deletions

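For downstream callers, the change amounts to dropping the two deprecated keyword arguments: Zipf/SIF weighting is now controlled solely by `sif_coefficient`. A minimal before/after sketch (the teacher model name is a placeholder, not part of this commit):

    from model2vec.distill.distillation import distill

    # Before (deprecated): apply_zipf / use_subword only triggered warnings or were ignored.
    # model = distill(model_name="hf-org/teacher-model", apply_zipf=True, use_subword=True)

    # After: weighting is driven by sif_coefficient alone (None disables weighting).
    model = distill(
        model_name="hf-org/teacher-model",  # placeholder model name
        pca_dims=256,
        sif_coefficient=1e-4,
    )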

model2vec/distill/distillation.py

Lines changed: 1 addition & 34 deletions
@@ -27,11 +27,9 @@ def distill_from_model(
     vocabulary: list[str] | None = None,
     device: str | None = None,
     pca_dims: PCADimType = 256,
-    apply_zipf: bool | None = None,
     sif_coefficient: float | None = 1e-4,
     token_remove_pattern: str | None = r"\[unused\d+\]",
     quantize_to: DType | str = DType.Float16,
-    use_subword: bool | None = None,
     vocabulary_quantization: int | None = None,
     pooling: PoolingType = PoolingType.MEAN,
 ) -> StaticModel:
@@ -51,14 +49,11 @@ def distill_from_model(
     :param pca_dims: The number of components to use for PCA.
         If this is None, we don't apply PCA.
         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
-    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
-        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
         If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
-    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :param pooling: The pooling strategy to use for creating embeddings. Can be one of:
         'mean' (default): mean over all tokens. Robust and works well in most cases.
@@ -69,13 +64,9 @@ def distill_from_model(
     :raises: ValueError if the vocabulary is empty after preprocessing.

     """
-    if use_subword is not None:
-        logger.warning(
-            "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
-        )
     quantize_to = DType(quantize_to)
     backend_tokenizer = tokenizer.backend_tokenizer
-    sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)
+    sif_coefficient, token_remove_regex = _validate_parameters(sif_coefficient, token_remove_pattern)

     if vocabulary is None:
         vocabulary = []
@@ -147,7 +138,6 @@ def distill_from_model(
         "architectures": ["StaticModel"],
         "tokenizer_name": model_name,
         "apply_pca": pca_dims,
-        "apply_zipf": apply_zipf,
         "sif_coefficient": sif_coefficient,
         "hidden_dim": embeddings.shape[1],
         "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
@@ -182,35 +172,19 @@ def distill_from_model(


 def _validate_parameters(
-    apply_zipf: bool | None,
     sif_coefficient: float | None,
     token_remove_pattern: str | None,
 ) -> tuple[float | None, re.Pattern | None]:
     """
     Validate the parameters passed to the distillation function.

-    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
-        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :return: The SIF coefficient to use.
     :raises: ValueError if the regex can't be compiled.

     """
-    if apply_zipf is not None:
-        logger.warning(
-            "The `apply_zipf` parameter is deprecated and will be removed in the next release. "
-            "Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
-            "no weighting is applied."
-        )
-        if apply_zipf and sif_coefficient is None:
-            logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
-            sif_coefficient = 1e-4
-        elif not apply_zipf:
-            logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
-            sif_coefficient = None
-
     if sif_coefficient is not None:
         if not 0 < sif_coefficient < 1.0:
             raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
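For context on what the guard above enforces: the SIF weighting controlled by `sif_coefficient` down-weights frequent tokens, with token frequency approximated by a Zipf-like distribution over vocabulary rank. A small sketch of that weighting, mirroring the computation the `post_process_embeddings` test further down verifies (the helper name here is illustrative):

    import numpy as np

    def sif_weights(vocab_size: int, sif_coefficient: float = 1e-4) -> np.ndarray:
        # Rank-based proxy for token probability (Zipf-like), as in the test below.
        inv_rank = 1 / np.arange(2, vocab_size + 2)
        proba = inv_rank / np.sum(inv_rank)
        # SIF weight a / (a + p(token)): frequent (low-rank) tokens get the smallest weights.
        return sif_coefficient / (sif_coefficient + proba)

    weights = sif_weights(30000)  # shape (30000,); applied row-wise to the embedding matrix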
@@ -230,12 +204,10 @@ def distill(
     vocabulary: list[str] | None = None,
     device: str | None = None,
     pca_dims: PCADimType = 256,
-    apply_zipf: bool | None = None,
     sif_coefficient: float | None = 1e-4,
     token_remove_pattern: str | None = r"\[unused\d+\]",
     trust_remote_code: bool = False,
     quantize_to: DType | str = DType.Float16,
-    use_subword: bool | None = None,
     vocabulary_quantization: int | None = None,
     pooling: PoolingType = PoolingType.MEAN,
 ) -> StaticModel:
@@ -254,14 +226,11 @@ def distill(
     :param pca_dims: The number of components to use for PCA.
         If this is None, we don't apply PCA.
         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
-    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
-        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
-    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
     :param pooling: The pooling strategy to use for creating embeddings. Can be one of:
         'mean' (default): mean over all tokens. Robust and works well in most cases.
@@ -283,11 +252,9 @@ def distill(
         vocabulary=vocabulary,
         device=device,
         pca_dims=pca_dims,
-        apply_zipf=apply_zipf,
         token_remove_pattern=token_remove_pattern,
         sif_coefficient=sif_coefficient,
         quantize_to=quantize_to,
-        use_subword=use_subword,
         vocabulary_quantization=vocabulary_quantization,
         pooling=pooling,
     )
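One behavioral note: because the two keywords are removed from the public signatures rather than kept as accepted no-ops, callers that still pass them will now fail at call time instead of getting a deprecation warning. A hypothetical illustration (the model name is a placeholder):

    from model2vec.distill.distillation import distill

    try:
        # apply_zipf no longer exists on the signature after this commit.
        distill(model_name="hf-org/teacher-model", apply_zipf=True)
    except TypeError as exc:
        print(f"rejected: {exc}")  # unexpected keyword argument 'apply_zipf'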

tests/test_distillation.py

Lines changed: 39 additions & 73 deletions
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 from __future__ import annotations

 import json
@@ -10,14 +11,10 @@
 from transformers import BertTokenizerFast
 from transformers.modeling_utils import PreTrainedModel

-from model2vec.distill.distillation import (
-    clean_and_create_vocabulary,
-    distill,
-    distill_from_model,
-    post_process_embeddings,
-)
-from model2vec.distill.inference import PoolingType, create_embeddings
+from model2vec.distill.distillation import distill, distill_from_model
+from model2vec.distill.inference import PoolingType, create_embeddings, post_process_embeddings
 from model2vec.model import StaticModel
+from model2vec.tokenizer import clean_and_create_vocabulary

 try:
     # For huggingface_hub>=0.25.0
@@ -30,14 +27,14 @@


 @pytest.mark.parametrize(
-    "vocabulary, pca_dims, apply_zipf",
+    "vocabulary, pca_dims, sif_coefficient",
     [
-        (None, 256, True),  # Output vocab with subwords, PCA applied
-        (["wordA", "wordB"], 4, False),  # Custom vocab with subword, PCA applied
-        (None, "auto", False),  # Subword, PCA set to 'auto'
-        (None, 1024, False),  # Subword, PCA set to high number.
-        (None, None, True),  # No PCA applied
-        (None, 0.9, True),  # PCA as float applied
+        (None, 256, 1e-4),  # Subword vocab, PCA applied, SIF on
+        (["wordA", "wordB"], 4, None),  # Custom vocab, PCA applied, SIF off
+        (None, "auto", None),  # Subword, PCA 'auto', SIF off
+        (None, 1024, None),  # Subword, PCA set high, SIF off
+        (None, None, 1e-4),  # No PCA, SIF on
+        (None, 0.9, 1e-4),  # PCA as float (variance), SIF on
     ],
 )
 @patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -49,24 +46,20 @@ def test_distill_from_model(
     mock_transformer: PreTrainedModel,
     vocabulary: list[str] | None,
     pca_dims: int | None,
-    apply_zipf: bool,
+    sif_coefficient: float | None,
 ) -> None:
     """Test distill function with different parameters."""
     # Mock the return value of model_info to avoid calling the Hugging Face API
     mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
-
-    # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
-    # mock_auto_tokenizer.return_value = mock_berttokenizer
     mock_auto_model.return_value = mock_transformer

-    # Call the distill function with the parametrized inputs
     static_model = distill_from_model(
         model=mock_transformer,
         tokenizer=mock_berttokenizer,
         vocabulary=vocabulary,
         device="cpu",
         pca_dims=pca_dims,
-        apply_zipf=apply_zipf,
+        sif_coefficient=sif_coefficient,
         token_remove_pattern=None,
     )

@@ -75,7 +68,7 @@ def test_distill_from_model(
         vocabulary=vocabulary,
         device="cpu",
         pca_dims=pca_dims,
-        apply_zipf=apply_zipf,
+        sif_coefficient=sif_coefficient,
         token_remove_pattern=None,
     )

@@ -94,11 +87,7 @@ def test_distill_removal_pattern(
     mock_transformer: PreTrainedModel,
 ) -> None:
     """Test the removal pattern."""
-    # Mock the return value of model_info to avoid calling the Hugging Face API
     mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
-
-    # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
-    # mock_auto_tokenizer.return_value = mock_berttokenizer
     mock_auto_model.return_value = mock_transformer

     # The vocab size is 30522, but we remove 998 tokens: [CLS], [SEP], and [MASK], and all [unused] tokens.
@@ -111,7 +100,6 @@ def test_distill_removal_pattern(
         device="cpu",
         token_remove_pattern=None,
     )
-
     assert len(static_model.embedding) == expected_vocab_size

     # No tokens removed, nonsensical pattern
@@ -122,12 +110,11 @@ def test_distill_removal_pattern(
         device="cpu",
         token_remove_pattern="£££££££££££££££££",
     )
-
     assert len(static_model.embedding) == expected_vocab_size

     # Weird pattern.
     with pytest.raises(ValueError):
-        static_model = distill_from_model(
+        _ = distill_from_model(
             model=mock_transformer,
             tokenizer=mock_berttokenizer,
             vocabulary=None,
@@ -137,19 +124,16 @@ def test_distill_removal_pattern(


 @pytest.mark.parametrize(
-    "vocabulary, pca_dims, apply_zipf, sif_coefficient, expected_shape",
+    "vocabulary, pca_dims, sif_coefficient, expected_shape",
     [
-        (None, 256, True, None, (29524, 256)),  # Output vocab with subwords, PCA applied
-        (None, "auto", False, None, (29524, 768)),  # Subword, PCA set to 'auto'
-        (None, "auto", True, 1e-4, (29524, 768)),  # Subword, PCA set to 'auto'
-        (None, "auto", False, 1e-4, (29524, 768)),  # Subword, PCA set to 'auto'
-        (None, "auto", True, 0, None),  # Sif too low
-        (None, "auto", True, 1, None),  # Sif too high
-        (None, "auto", False, 0, (29524, 768)),  # Sif too low, but apply_zipf is False
-        (None, "auto", False, 1, (29524, 768)),  # Sif too high, but apply_zipf is False
-        (None, 1024, False, None, (29524, 768)),  # Subword, PCA set to high number.
-        (["wordA", "wordB"], 4, False, None, (29526, 4)),  # Custom vocab with subword, PCA applied
-        (None, None, True, None, (29524, 768)),  # No PCA applied
+        (None, 256, None, (29524, 256)),  # PCA applied, SIF off
+        (None, "auto", None, (29524, 768)),  # PCA 'auto', SIF off
+        (None, "auto", 1e-4, (29524, 768)),  # PCA 'auto', SIF on
+        (None, "auto", 0, None),  # invalid SIF (too low) -> raises
+        (None, "auto", 1, None),  # invalid SIF (too high) -> raises
+        (None, 1024, None, (29524, 768)),  # PCA set high (no reduction)
+        (["wordA", "wordB"], 4, None, (29526, 4)),  # Custom vocab, PCA applied
+        (None, None, None, (29524, 768)),  # No PCA, SIF off
     ],
 )
 @patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -160,47 +144,32 @@ def test_distill(
     mock_transformer: PreTrainedModel,
     vocabulary: list[str] | None,
     pca_dims: int | None,
-    apply_zipf: bool,
     sif_coefficient: float | None,
-    expected_shape: tuple[int, int],
+    expected_shape: tuple[int, int] | None,
 ) -> None:
     """Test distill function with different parameters."""
-    # Mock the return value of model_info to avoid calling the Hugging Face API
     mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
-
-    # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
     mock_auto_model.return_value = mock_transformer

     model_name = "tests/data/test_tokenizer"

-    if (
-        apply_zipf is not None
-        and apply_zipf
-        and sif_coefficient is not None
-        and (sif_coefficient <= 0 or sif_coefficient >= 1)
-    ):
+    if sif_coefficient is not None and (sif_coefficient <= 0 or sif_coefficient >= 1):
         with pytest.raises(ValueError):
-            static_model = distill(
+            _ = distill(
                 model_name=model_name,
                 vocabulary=vocabulary,
                 device="cpu",
                 pca_dims=pca_dims,
-                apply_zipf=apply_zipf,
                 sif_coefficient=sif_coefficient,
             )
-
     else:
-        # Call the distill function with the parametrized inputs
         static_model = distill(
             model_name=model_name,
             vocabulary=vocabulary,
             device="cpu",
             pca_dims=pca_dims,
-            apply_zipf=apply_zipf,
             sif_coefficient=sif_coefficient,
         )
-
-        # Assert the model is correctly generated
         assert isinstance(static_model, StaticModel)
         assert static_model.embedding.shape == expected_shape
         assert "mock-model" in static_model.config["tokenizer_name"]
@@ -223,37 +192,36 @@ def test_missing_modelinfo(
     "embeddings, pca_dims, sif_coefficient, expected_shape",
     [
         (rng.random((1000, 768)), 256, None, (1000, 256)),  # PCA applied correctly
-        (rng.random((1000, 768)), None, None, (1000, 768)),  # No PCA applied, dimensions remain unchanged
-        (rng.random((1000, 768)), 256, 1e-4, (1000, 256)),  # PCA and Zipf applied
-        (rng.random((10, 768)), 256, 1e-4, (10, 768)),  # PCA dims higher than vocab size, no PCA applied
+        (rng.random((1000, 768)), None, None, (1000, 768)),  # No PCA applied, dimensions unchanged
+        (rng.random((1000, 768)), 256, 1e-4, (1000, 256)),  # PCA and SIF applied
+        (rng.random((10, 768)), 256, 1e-4, (10, 768)),  # PCA dims > vocab size, no PCA applied
     ],
 )
 def test__post_process_embeddings(
-    embeddings: np.ndarray, pca_dims: int, sif_coefficient: float | None, expected_shape: tuple[int, int]
+    embeddings: np.ndarray, pca_dims: int | float | None, sif_coefficient: float | None, expected_shape: tuple[int, int]
 ) -> None:
-    """Test the _post_process_embeddings function."""
+    """Test the post_process_embeddings function."""
     original_embeddings = embeddings.copy()  # Copy embeddings to compare later

-    # Test that the function raises an error if the PCA dims are larger than the number of dimensions
-    if pca_dims and pca_dims > embeddings.shape[1]:
-        with pytest.raises(ValueError):
-            post_process_embeddings(embeddings, pca_dims, None)
+    # If pca_dims > original dims and is an int, ensure function handles gracefully (warns, no exception)
+    if isinstance(pca_dims, int) and pca_dims and pca_dims > embeddings.shape[1]:
+        # The implementation logs a warning and skips reduction; no exception expected.
+        pass

     processed_embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient)

     # Assert the shape is correct
     assert processed_embeddings.shape == expected_shape

-    # If Zipf weighting is applied compare the original and processed embeddings
-    # and check the weights are applied correctly
+    # If SIF weighting is applied and no PCA reduction, check weights are applied correctly
     if sif_coefficient and pca_dims is None:
         inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
         proba = inv_rank / np.sum(inv_rank)
         sif_weights = (sif_coefficient / (sif_coefficient + proba))[:, None]

         expected_zipf_embeddings = original_embeddings * sif_weights
         assert np.allclose(processed_embeddings, expected_zipf_embeddings, rtol=1e-5), (
-            "Zipf weighting not applied correctly"
+            "SIF weighting not applied correctly"
         )

@@ -275,7 +243,7 @@ def test_clean_and_create_vocabulary(
     expected_warnings: list[str],
     caplog: LogCaptureFixture,
 ) -> None:
-    """Test the _clean_vocabulary function."""
+    """Test the clean_and_create_vocabulary helper."""
     with caplog.at_level("WARNING"):
         tokens, _ = clean_and_create_vocabulary(mock_berttokenizer, added_tokens, None)

@@ -285,8 +253,6 @@ def test_clean_and_create_vocabulary(

     # Check the warnings were logged as expected
     logged_warnings = [record.message for record in caplog.records]
-
-    # Ensure the expected warnings contain expected keywords like 'Removed', 'duplicate', or 'empty'
     for expected_warning in expected_warnings:
         assert any(expected_warning in logged_warning for logged_warning in logged_warnings)
