
Commit 49fc262

Added properties, updated config saving, updated tests

1 parent fb3b2f7 commit 49fc262

File tree: 3 files changed, +34 −11 lines

  model2vec/hf_utils.py
  model2vec/model.py
  tests/test_model.py

model2vec/hf_utils.py (9 additions, 4 deletions)

@@ -54,13 +54,18 @@ def save_pretrained(
     save_file(model_weights, folder_path / "model.safetensors")
     tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
 
-    # Add embedding dtype to config
-    config["embedding_dtype"] = np.dtype(embeddings.dtype).name
-    json.dump(config, open(folder_path / "config.json", "w"), indent=4)
+    # Create a copy of config and add dtype and vocab quantization
+    cfg = dict(config)
+    cfg["embedding_dtype"] = np.dtype(embeddings.dtype).name
+    if mapping is not None:
+        cfg["vocabulary_quantization"] = int(embeddings.shape[0])
+    else:
+        cfg.pop("vocabulary_quantization", None)
+    json.dump(cfg, open(folder_path / "config.json", "w"), indent=4)
 
     # Create modules.json
     modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}]
-    if config.get("normalize"):
+    if cfg.get("normalize"):
         # If normalize=True, add sentence_transformers.models.Normalize
         modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"})
     json.dump(modules, open(folder_path / "modules.json", "w"), indent=4)
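The behavioral point of this hunk is that save_pretrained no longer mutates the config dict it was given: the derived keys are written to a copy. A minimal standalone sketch of the same pattern follows; the write_config helper and its argument names are illustrative, not the library's actual API.

import json
from pathlib import Path

import numpy as np


def write_config(config: dict, embeddings: np.ndarray, mapping, folder: Path) -> None:
    """Illustrative helper: persist config without mutating the caller's dict."""
    cfg = dict(config)  # shallow copy, so the caller's config stays untouched
    cfg["embedding_dtype"] = np.dtype(embeddings.dtype).name
    if mapping is not None:
        # With a token mapping present, each embedding row is a cluster centroid,
        # so the row count doubles as the number of clusters.
        cfg["vocabulary_quantization"] = int(embeddings.shape[0])
    else:
        # Drop any stale key left over from a previously quantized config.
        cfg.pop("vocabulary_quantization", None)
    with open(folder / "config.json", "w") as f:
        json.dump(cfg, f, indent=4)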

model2vec/model.py (12 additions, 4 deletions)

@@ -111,6 +111,17 @@ def normalize(self, value: bool) -> None:
             )
         self.config["normalize"] = value
 
+    @property
+    def embedding_dtype(self) -> str:
+        """Get the dtype (precision) of the embedding matrix."""
+        return np.dtype(self.embedding.dtype).name
+
+    @property
+    def vocabulary_quantization(self) -> int | None:
+        """Get the number of clusters used for vocabulary quantization, if applicable."""
+        is_quantized = (self.token_mapping is not None) or (len(self.embedding) != len(self.tokens))
+        return int(self.embedding.shape[0]) if is_quantized else None
+
     def save_pretrained(self, path: PathLike, model_name: str | None = None, subfolder: str | None = None) -> None:
         """
         Save the pretrained model.
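A quick usage sketch for the two new properties; the model name below is only an example, not something this commit depends on.

from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # example model
print(model.embedding_dtype)          # e.g. "float32"
print(model.vocabulary_quantization)  # None unless the vocabulary was quantized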
@@ -493,8 +504,6 @@ def quantize_model(
     :return: A new StaticModel with the quantized embeddings.
     :raises: ValueError if the model is already quantized.
     """
-    from model2vec.quantization import quantize_and_reduce_dim
-
     token_mapping: np.ndarray | None
     weights: np.ndarray | None
     if vocabulary_quantization is not None:
@@ -506,7 +515,6 @@ def quantize_model(
         embeddings, token_mapping, weights = quantize_vocabulary(
             n_clusters=vocabulary_quantization, weights=model.weights, embeddings=model.embedding
         )
-        model.config["vocabulary_quantization"] = vocabulary_quantization
     else:
         embeddings = model.embedding
         token_mapping = model.token_mapping
@@ -521,7 +529,7 @@ def quantize_model(
     return StaticModel(
         vectors=embeddings,
         tokenizer=model.tokenizer,
-        config=model.config,
+        config=dict(model.config),
         weights=weights,
         token_mapping=token_mapping,
         normalize=model.normalize,
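Dropping the model.config[...] write and passing dict(model.config) into the returned StaticModel together mean that quantizing no longer has side effects on the source model. A hedged sketch of that guarantee; the import path and exact call signature of quantize_model are assumed from this file, with only the vocabulary_quantization keyword confirmed by the hunk above.

from model2vec.model import StaticModel, quantize_model  # import path assumed

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # example model
quantized = quantize_model(model, vocabulary_quantization=256)

# Assuming the source model was not already vocabulary-quantized, its config
# is untouched; only the new model reflects the quantization.
assert "vocabulary_quantization" not in model.config
print(quantized.vocabulary_quantization)  # 256, via the new property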

tests/test_model.py (13 additions, 3 deletions)

@@ -180,7 +180,9 @@ def test_load_pretrained(
     # Assert that the loaded model has the same properties as the original one
     np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
     assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
-    assert loaded_model.config == mock_config
+    for k, v in mock_config.items():
+        assert loaded_model.config.get(k) == v
+    assert "embedding_dtype" in loaded_model.config
 
 
 def test_load_pretrained_quantized(
@@ -198,19 +200,22 @@ def test_load_pretrained_quantized(
     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.int8
     assert loaded_model.embedding.shape == mock_vectors.shape
+    assert loaded_model.embedding_dtype == "int8"
 
     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float16")
 
     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.float16
     assert loaded_model.embedding.shape == mock_vectors.shape
+    assert loaded_model.embedding_dtype == "float16"
 
     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32")
     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.float32
     assert loaded_model.embedding.shape == mock_vectors.shape
+    assert loaded_model.embedding_dtype == "float32"
 
     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float64")
@@ -234,15 +239,19 @@ def test_load_pretrained_dim(
     # Assert that the loaded model has the same properties as the original one
     np.testing.assert_array_equal(loaded_model.embedding, mock_vectors[:, :2])
     assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
-    assert loaded_model.config == mock_config
+    for k, v in mock_config.items():
+        assert loaded_model.config.get(k) == v
+    assert "embedding_dtype" in loaded_model.config
 
     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, dimensionality=None)
 
     # Assert that the loaded model has the same properties as the original one
     np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
     assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
-    assert loaded_model.config == mock_config
+    for k, v in mock_config.items():
+        assert loaded_model.config.get(k) == v
+    assert "embedding_dtype" in loaded_model.config
 
     # Load the model back from the same path
     with pytest.raises(ValueError):
@@ -267,6 +276,7 @@ def test_load_pretrained_vocabulary_quantized(
     assert loaded_model.weights is not None
    assert loaded_model.weights.shape == (5,)
     assert loaded_model.token_mapping is not None
+    assert loaded_model.vocabulary_quantization == 3
     assert len(loaded_model.token_mapping) == mock_tokenizer.get_vocab_size()
     assert len(loaded_model.token_mapping) == len(loaded_model.weights)
     assert loaded_model.encode("word1 word2").shape == (2,)
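The relaxed config assertions reflect that saving now enriches the config with derived keys, so a loaded config is a superset of what was passed in. If the pattern recurs, it could be factored into a small helper; the one below is hypothetical, not part of the test suite.

def assert_config_superset(loaded: dict, original: dict) -> None:
    """Hypothetical helper: the loaded config must contain every original entry."""
    for key, value in original.items():
        assert loaded.get(key) == value
    assert "embedding_dtype" in loaded  # added by save_pretrained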
