@@ -15,6 +15,14 @@ class DType(str, Enum):
1515 Int8 = "int8"
1616
1717
# Lookup table: each supported DType member -> the NumPy scalar type that
# quantize_embeddings casts the embeddings to.
dtype_map = {
    DType.Float16: np.float16,
    DType.Float32: np.float32,
    DType.Float64: np.float64,
    DType.Int8: np.int8,
}
24+
25+
def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
    """
    Quantize embeddings to a specified data type to reduce memory usage.

    If the embeddings already have the requested dtype, the input array itself
    is returned without copying.

    For int8, values are scaled symmetrically so that the largest absolute
    value maps to 127, then rounded and clipped to [-127, 127].

    :param embeddings: The embeddings to quantize, as a numpy array.
    :param quantize_to: The data type to quantize to.
    :return: The quantized embeddings.
    :raises ValueError: If the quantization type is not valid.
    """
    try:
        mapped_dtype = dtype_map[quantize_to]
    except (KeyError, TypeError) as err:
        # Keep the documented contract: an unknown target is a ValueError,
        # not a KeyError leaking out of the lookup table.
        raise ValueError("Not a valid enum member of DType.") from err

    if embeddings.dtype == mapped_dtype:
        # Don't do anything if they match
        return embeddings

    # Handle float types
    if quantize_to in {DType.Float16, DType.Float32, DType.Float64}:
        return embeddings.astype(mapped_dtype)
    elif quantize_to == DType.Int8:
        if embeddings.size == 0:
            # np.max raises on a zero-size array; there is nothing to scale.
            return embeddings.astype(np.int8)

        # Normalize to [-128, 127] range for int8
        # We normalize to -127 to 127 to keep symmetry.
        scale = np.max(np.abs(embeddings)) / 127.0
        # Turn into float16 to minimize memory usage during computation;
        # this is the only full-size copy, all later steps work in place.
        buf = embeddings.astype(np.float16, copy=True)
        # An all-zero input gives scale == 0; dividing would produce NaN,
        # and the values are already zero, so skip the division.
        if scale != 0:
            np.divide(buf, scale, out=buf)
        # Round to the nearest integer, in place
        np.rint(buf, out=buf)
        # Clip to the symmetric int8 range and convert to int8
        np.clip(buf, -127, 127, out=buf)
        return buf.astype(np.int8)
    else:
        raise ValueError("Not a valid enum member of DType.")
0 commit comments