
Commit da79779

feat: Add SIF-like coef (#174)
* feat: token mean and sif weighting
* fix error
* rename
* fix tests
* feat: update config
* reviewer comments
* set normalize to True by default
* fix tests
* fix: bug and tests
1 parent af8ba05 commit da79779

2 files changed (+138, -46 lines changed)

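Before the diff, a minimal usage sketch of the new behavior, assuming the public `distill` entry point shown in the changes below; the model name is only an illustrative placeholder:

from model2vec.distill import distill

# SIF weighting is now on by default (sif_coefficient=1e-4); pass
# sif_coefficient=None to disable weighting. The distilled model also
# normalizes its output embeddings by default (config "normalize": True).
m2v_model = distill(
    model_name="BAAI/bge-base-en-v1.5",  # example model name (assumption)
    pca_dims=256,
    sif_coefficient=1e-4,
)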
model2vec/distill/distillation.py

Lines changed: 89 additions & 27 deletions
@@ -40,7 +40,8 @@ def distill_from_model(
     vocabulary: list[str] | None = None,
     device: str | None = None,
     pca_dims: PCADimType = 256,
-    apply_zipf: bool = True,
+    apply_zipf: bool | None = None,
+    sif_coefficient: float | None = 1e-4,
     use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
 ) -> StaticModel:
@@ -60,30 +61,19 @@ def distill_from_model(
     :param pca_dims: The number of components to use for PCA.
         If this is None, we don't apply PCA.
         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
-    :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
+    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
+        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
+    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
+        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
         If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
-    :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
-    :raises: ValueError if the vocabulary contains duplicate tokens.
-    :raises: ValueError if the regex can't be compiled.
-    :raises: ValueError if the vocabulary is empty after token removal.
     :return: A StaticModel
 
     """
-    device = select_optimal_device(device)
-    if not use_subword and vocabulary is None:
-        raise ValueError(
-            "You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
-        )
-
-    if vocabulary and isinstance(tokenizer.backend_tokenizer.model, (BPE, Unigram)):
-        raise ValueError(
-            "You passed a vocabulary, but the model you are using does not use a WordPiece tokenizer. "
-            "This is not supported yet."
-            "Feel free to open an issue if this is a blocker: https://github.com/MinishLab/model2vec/issues"
-        )
+    sif_coefficient = _validate_parameters(tokenizer, vocabulary, apply_zipf, sif_coefficient, use_subword)
 
+    device = select_optimal_device(device)
     # Make a base list of tokens.
     tokens: list[str] = []
     if use_subword:
@@ -129,7 +119,7 @@ def distill_from_model(
         logger.warning("Didn't create any token embeddings as all tokens were duplicates or empty.")
 
     # Post process the embeddings by applying PCA and Zipf weighting.
-    embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, apply_zipf)
+    embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
 
     model_name = getattr(model, "name_or_path", "")
 
@@ -139,8 +129,10 @@ def distill_from_model(
         "tokenizer_name": model_name,
         "apply_pca": pca_dims,
         "apply_zipf": apply_zipf,
+        "sif_coefficient": sif_coefficient,
         "hidden_dim": embeddings.shape[1],
         "seq_length": 1000000,  # Set this to a high value since we don't have a sequence length limit.
+        "normalize": True,
     }
 
     if os.path.exists(model_name):
@@ -157,10 +149,71 @@ def distill_from_model(
         language = None
 
     return StaticModel(
-        vectors=embeddings, tokenizer=new_tokenizer, config=config, base_model_name=model_name, language=language
+        vectors=embeddings,
+        tokenizer=new_tokenizer,
+        config=config,
+        base_model_name=model_name,
+        language=language,
+        normalize=True,
     )
 
 
+def _validate_parameters(
+    tokenizer: PreTrainedTokenizerFast,
+    vocabulary: list[str] | None,
+    apply_zipf: bool | None,
+    sif_coefficient: float | None,
+    use_subword: bool,
+) -> float | None:
+    """
+    Validate the parameters passed to the distillation function.
+
+    :param tokenizer: The tokenizer to use.
+    :param vocabulary: The vocabulary to use.
+    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
+        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
+    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
+        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
+    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
+    :return: The SIF coefficient to use.
+    :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
+    :raises: ValueError if the vocabulary contains duplicate tokens.
+    :raises: ValueError if the regex can't be compiled.
+    :raises: ValueError if the vocabulary is empty after token removal.
+
+    """
+    if apply_zipf is not None:
+        logger.warning(
+            "The `apply_zipf` parameter is deprecated and will be removed in the next release. "
+            "Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
+            "no weighting is applied."
+        )
+        if apply_zipf and sif_coefficient is None:
+            logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
+            sif_coefficient = 1e-4
+        elif not apply_zipf:
+            logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
+            sif_coefficient = None
 
+    if sif_coefficient is not None:
+        if not 0 < sif_coefficient < 1.0:
+            raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
 
+    if not use_subword and vocabulary is None:
+        raise ValueError(
+            "You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
+        )
+
+    if vocabulary and isinstance(tokenizer.backend_tokenizer.model, (BPE, Unigram)):
+        raise ValueError(
+            "You passed a vocabulary, but the model you are using does not use a WordPiece tokenizer. "
+            "This is not supported yet."
+            "Feel free to open an issue if this is a blocker: https://github.com/MinishLab/model2vec/issues"
+        )
+
+    return sif_coefficient
+
+
 def _remove_tokens_and_embeddings(
     tokenizer: PreTrainedTokenizerFast, token_remove_pattern: str | None, tokens: list[str], embeddings: np.ndarray
 ) -> tuple[Tokenizer, np.ndarray]:
@@ -201,7 +254,8 @@ def distill(
     vocabulary: list[str] | None = None,
     device: str | None = None,
     pca_dims: PCADimType = 256,
-    apply_zipf: bool = True,
+    apply_zipf: bool | None = None,
+    sif_coefficient: float | None = 1e-4,
     use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
     trust_remote_code: bool = False,
@@ -221,7 +275,10 @@ def distill(
     :param pca_dims: The number of components to use for PCA.
         If this is None, we don't apply PCA.
         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
-    :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
+    :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
+        Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
+    :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
+        Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
@@ -240,11 +297,14 @@ def distill(
         apply_zipf=apply_zipf,
         use_subword=use_subword,
         token_remove_pattern=token_remove_pattern,
+        sif_coefficient=sif_coefficient,
     )
 
 
-def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply_zipf: bool) -> np.ndarray:
-    """Post process embeddings by applying PCA and Zipf weighting."""
+def _post_process_embeddings(
+    embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
+) -> np.ndarray:
+    """Post process embeddings by applying PCA and SIF weighting, estimating token frequencies through Zipf's law."""
     if pca_dims is not None:
         if pca_dims == "auto":
             pca_dims = embeddings.shape[1]
@@ -276,9 +336,11 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply_zipf: bool) -> np.ndarray:
             logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
             logger.info(f"Explained variance: {explained_variance:.3f}.")
 
-    if apply_zipf:
-        logger.info("Applying Zipf weighting")
-        embeddings *= np.log(1 + np.arange(embeddings.shape[0]))[:, None]
+    if sif_coefficient is not None:
+        logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
+        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
+        proba = inv_rank / np.sum(inv_rank)
+        embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]
 
     return embeddings
 
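The weighting above replaces the old log-rank Zipf weights with SIF weights computed from Zipf-estimated frequencies: the probability of the token at rank r is taken proportional to 1/r (ranks start at 2), and each embedding row is scaled by a / (a + p_r) with a = sif_coefficient, the smoothed inverse frequency scheme of Arora et al. (2017). A standalone sketch of the same computation; the helper name is invented for illustration:

import numpy as np

def zipf_sif_weights(vocab_size: int, sif_coefficient: float = 1e-4) -> np.ndarray:
    """Sketch mirroring the weighting in _post_process_embeddings.

    Frequencies are estimated with Zipf's law (p_r proportional to 1/r,
    ranks starting at 2), then mapped to SIF weights a / (a + p_r).
    """
    inv_rank = 1 / np.arange(2, vocab_size + 2)
    proba = inv_rank / np.sum(inv_rank)
    return sif_coefficient / (sif_coefficient + proba)

weights = zipf_sif_weights(5)
# The most frequent (lowest-rank) tokens receive the smallest weights:
assert np.all(np.diff(weights) > 0)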
tests/test_distillation.py

Lines changed: 49 additions & 19 deletions
@@ -160,22 +160,29 @@ def test_distill_removal_pattern(
 
 
 @pytest.mark.parametrize(
-    "vocabulary, use_subword, pca_dims, apply_zipf, expected_shape",
+    "vocabulary, use_subword, pca_dims, apply_zipf, sif_coefficient, expected_shape",
     [
-        (None, True, 256, True, (29528, 256)),  # Output vocab with subwords, PCA applied
+        (None, True, 256, True, None, (29528, 256)),  # Output vocab with subwords, PCA applied
         (
             ["wordA", "wordB"],
             False,
             4,
             False,
+            None,
             (7, 4),
         ),  # Custom vocab without subword, PCA applied
-        (None, True, "auto", False, (29528, 768)),  # Subword, PCA set to 'auto'
-        (None, True, 1024, False, (29528, 768)),  # Subword, PCA set to high number.
-        (["wordA", "wordB"], True, 4, False, (29530, 4)),  # Custom vocab with subword, PCA applied
-        (None, True, None, True, (29528, 768)),  # No PCA applied
-        (["wordA", "wordB"], False, 4, True, (7, 4)),  # Custom vocab without subwords, PCA and Zipf applied
-        (None, False, 256, True, None),  # use_subword = False without passing a vocabulary should raise an error
+        (None, True, "auto", False, None, (29528, 768)),  # Subword, PCA set to 'auto'
+        (None, True, "auto", True, 1e-4, (29528, 768)),  # Subword, PCA set to 'auto'
+        (None, True, "auto", False, 1e-4, (29528, 768)),  # Subword, PCA set to 'auto'
+        (None, True, "auto", True, 0, None),  # SIF coefficient too low
+        (None, True, "auto", True, 1, None),  # SIF coefficient too high
+        (None, True, "auto", False, 0, (29528, 768)),  # SIF too low, but apply_zipf is False
+        (None, True, "auto", False, 1, (29528, 768)),  # SIF too high, but apply_zipf is False
+        (None, True, 1024, False, None, (29528, 768)),  # Subword, PCA set to high number.
+        (["wordA", "wordB"], True, 4, False, None, (29530, 4)),  # Custom vocab with subword, PCA applied
+        (None, True, None, True, None, (29528, 768)),  # No PCA applied
+        (["wordA", "wordB"], False, 4, True, None, (7, 4)),  # Custom vocab without subwords, PCA and Zipf applied
+        (None, False, 256, True, None, None),  # use_subword = False without passing a vocabulary should raise an error
     ],
 )
 @patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -188,6 +195,7 @@ def test_distill(
     use_subword: bool,
     pca_dims: int | None,
     apply_zipf: bool,
+    sif_coefficient: float | None,
     expected_shape: tuple[int, int],
 ) -> None:
     """Test distill function with different parameters."""
@@ -208,7 +216,25 @@ def test_distill(
             pca_dims=pca_dims,
             apply_zipf=apply_zipf,
             use_subword=use_subword,
+            sif_coefficient=sif_coefficient,
         )
+    elif (
+        apply_zipf is not None
+        and apply_zipf
+        and sif_coefficient is not None
+        and (sif_coefficient <= 0 or sif_coefficient >= 1)
+    ):
+        with pytest.raises(ValueError):
+            static_model = distill(
+                model_name=model_name,
+                vocabulary=vocabulary,
+                device="cpu",
+                pca_dims=pca_dims,
+                apply_zipf=apply_zipf,
+                use_subword=use_subword,
+                sif_coefficient=sif_coefficient,
+            )
+
     else:
         # Call the distill function with the parametrized inputs
         static_model = distill(
@@ -218,6 +244,7 @@ def test_distill(
             pca_dims=pca_dims,
             apply_zipf=apply_zipf,
             use_subword=use_subword,
+            sif_coefficient=sif_coefficient,
         )
 
         # Assert the model is correctly generated
@@ -240,35 +267,38 @@ def test_missing_modelinfo(
 
 
 @pytest.mark.parametrize(
-    "embeddings, pca_dims, apply_zipf, expected_shape",
+    "embeddings, pca_dims, sif_coefficient, expected_shape",
     [
-        (rng.random((1000, 768)), 256, False, (1000, 256)),  # PCA applied correctly
-        (rng.random((1000, 768)), None, False, (1000, 768)),  # No PCA applied, dimensions remain unchanged
-        (rng.random((1000, 768)), 256, True, (1000, 256)),  # PCA and Zipf applied
-        (rng.random((10, 768)), 256, False, (10, 768)),  # PCA dims higher than vocab size, no PCA applied
+        (rng.random((1000, 768)), 256, None, (1000, 256)),  # PCA applied correctly
+        (rng.random((1000, 768)), None, None, (1000, 768)),  # No PCA applied, dimensions remain unchanged
+        (rng.random((1000, 768)), 256, 1e-4, (1000, 256)),  # PCA and Zipf applied
+        (rng.random((10, 768)), 256, 1e-4, (10, 768)),  # PCA dims higher than vocab size, no PCA applied
     ],
 )
 def test__post_process_embeddings(
-    embeddings: np.ndarray, pca_dims: int, apply_zipf: bool, expected_shape: tuple[int, int]
+    embeddings: np.ndarray, pca_dims: int, sif_coefficient: float | None, expected_shape: tuple[int, int]
 ) -> None:
     """Test the _post_process_embeddings function."""
     original_embeddings = embeddings.copy()  # Copy embeddings to compare later
 
     # Test that the function raises an error if the PCA dims are larger than the number of dimensions
     if pca_dims and pca_dims > embeddings.shape[1]:
         with pytest.raises(ValueError):
-            _post_process_embeddings(embeddings, pca_dims, False)
+            _post_process_embeddings(embeddings, pca_dims, None)
 
-    processed_embeddings = _post_process_embeddings(embeddings, pca_dims, apply_zipf)
+    processed_embeddings = _post_process_embeddings(embeddings, pca_dims, sif_coefficient)
 
     # Assert the shape is correct
     assert processed_embeddings.shape == expected_shape
 
     # If Zipf weighting is applied, compare the original and processed embeddings
     # and check the weights are applied correctly
-    if apply_zipf and pca_dims is None:
-        zipf_weights = np.log(1 + np.arange(embeddings.shape[0]))[:, None]
-        expected_zipf_embeddings = original_embeddings * zipf_weights
+    if sif_coefficient and pca_dims is None:
+        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
+        proba = inv_rank / np.sum(inv_rank)
+        sif_weights = (sif_coefficient / (sif_coefficient + proba))[:, None]
+
+        expected_zipf_embeddings = original_embeddings * sif_weights
         assert np.allclose(
             processed_embeddings, expected_zipf_embeddings, rtol=1e-5
         ), "Zipf weighting not applied correctly"

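For clarity, the resolution order implemented by `_validate_parameters` in the first file, condensed into a sketch (the function name `resolve_sif` is invented for illustration; logging and the tokenizer/vocabulary checks are omitted):

# Sketch of the deprecated-flag resolution, mirroring the diff above.
def resolve_sif(apply_zipf: bool | None, sif_coefficient: float | None) -> float | None:
    if apply_zipf is not None:  # deprecated parameter was passed explicitly
        if apply_zipf and sif_coefficient is None:
            sif_coefficient = 1e-4  # True restores the default coefficient
        elif not apply_zipf:
            sif_coefficient = None  # False disables weighting outright
    if sif_coefficient is not None and not 0 < sif_coefficient < 1.0:
        raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
    return sif_coefficient

assert resolve_sif(None, 1e-4) == 1e-4   # new default path
assert resolve_sif(False, 1e-4) is None  # legacy False wins over the coefficient
assert resolve_sif(True, None) == 1e-4   # legacy True re-enables the default

This matches the test parametrization: an out-of-range coefficient raises only when apply_zipf is truthy, because apply_zipf=False resets the coefficient to None before the range check.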