Commit 5a8578d

fix: load faster, make quantization better (#279)
* fix: load faster, make quantization better
* tests: add test
1 parent 7bf0bf0 commit 5a8578d

4 files changed (+148, -117 lines)

model2vec/model.py

Lines changed: 5 additions & 0 deletions
@@ -566,6 +566,11 @@ def _loading_helper(
         language=metadata.get("language"),
     )
 
+    # If no quantization or dimensionality reduction is requested,
+    # return the model as is.
+    if not any([vocabulary_quantization, quantize_to, dimensionality]):
+        return model
+
     return quantize_model(
         model=model,
         vocabulary_quantization=vocabulary_quantization,
model2vec/quantization.py

Lines changed: 26 additions & 7 deletions
@@ -15,6 +15,14 @@ class DType(str, Enum):
     Int8 = "int8"
 
 
+dtype_map = {
+    DType.Float16: np.float16,
+    DType.Float32: np.float32,
+    DType.Float64: np.float64,
+    DType.Int8: np.int8,
+}
+
+
 def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
     """
     Quantize embeddings to a specified data type to reduce memory usage.
@@ -24,17 +32,28 @@ def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
     :return: The quantized embeddings.
     :raises ValueError: If the quantization type is not valid.
     """
-    if quantize_to == DType.Float16:
-        return embeddings.astype(np.float16)
-    elif quantize_to == DType.Float32:
-        return embeddings.astype(np.float32)
-    elif quantize_to == DType.Float64:
-        return embeddings.astype(np.float64)
+    mapped_dtype = dtype_map[quantize_to]
+    if embeddings.dtype == mapped_dtype:
+        # Don't do anything if they match
+        return embeddings
+
+    # Handle float types
+    if quantize_to in {DType.Float16, DType.Float32, DType.Float64}:
+        return embeddings.astype(mapped_dtype)
     elif quantize_to == DType.Int8:
         # Normalize to [-128, 127] range for int8
         # We normalize to -127 to 127 to keep symmetry.
         scale = np.max(np.abs(embeddings)) / 127.0
-        quantized = np.round(embeddings / scale).astype(np.int8)
+        # Turn into float16 to minimize memory usage during computation;
+        # we copy once.
+        buf = embeddings.astype(np.float16, copy=True)
+        # Divide by the scale
+        np.divide(buf, scale, out=buf)
+        # Round to int, copy to the buffer
+        np.rint(buf, out=buf)
+        # Clip to int8 range and convert to int8
+        np.clip(buf, -127, 127, out=buf)
+        quantized = buf.astype(np.int8)
        return quantized
     else:
         raise ValueError("Not a valid enum member of DType.")
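
Two things change here: dtype_map replaces the per-dtype if/elif chain and enables the early no-op return when the embeddings already have the requested dtype, and the int8 branch now works in a single float16 buffer that every operation writes back into, instead of allocating a fresh float array at each step. A self-contained sketch of that in-place technique (the array shape and seed are illustrative, not from the repo):

import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((1000, 256)).astype(np.float32)

# Symmetric scale: the largest magnitude maps to +/-127.
scale = np.max(np.abs(embeddings)) / 127.0

# One float16 copy; divide, round, and clip all reuse the same buffer.
buf = embeddings.astype(np.float16, copy=True)
np.divide(buf, scale, out=buf)
np.rint(buf, out=buf)
np.clip(buf, -127, 127, out=buf)
quantized = buf.astype(np.int8)

assert quantized.dtype == np.int8
assert quantized.min() >= -127 and quantized.max() <= 127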

tests/test_model.py

Lines changed: 7 additions & 0 deletions
@@ -212,6 +212,13 @@ def test_load_pretrained_quantized(
     assert loaded_model.embedding.dtype == np.float32
     assert loaded_model.embedding.shape == mock_vectors.shape
 
+    # Load the model back from the same path
+    loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float64")
+    # Assert that the loaded model has the same properties as the original one
+    assert loaded_model.embedding.dtype == np.float64
+    # Should not copy if same as original.
+    assert loaded_model.embedding is loaded_model.embedding
+
 
 def test_load_pretrained_dim(
     tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
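
The added assertions pin down the no-copy contract introduced in quantization.py: re-quantizing to a dtype the embeddings already have should hand back the same array rather than a copy. A quick sketch of that contract exercised against quantize_embeddings directly (assuming the module layout shown in this diff):

import numpy as np

from model2vec.quantization import DType, quantize_embeddings

emb = np.zeros((4, 8), dtype=np.float32)
# Same dtype requested: the very same array object comes back, no copy.
assert quantize_embeddings(emb, DType.Float32) is emb
# Different dtype requested: a converted copy is returned.
assert quantize_embeddings(emb, DType.Float16).dtype == np.float16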
