@@ -180,7 +180,7 @@ def from_pretrained_model(
180
180
# assume multiple compression formats means mixed-precision
181
181
# as we currently only support one compressor per precision type and scheme
182
182
if len (quantization_format ) > 1 :
183
- quantization_format = CompressionFormat .mixed_precision
183
+ quantization_format = CompressionFormat .mixed_precision . value
184
184
else :
185
185
quantization_format = quantization_format [0 ]
186
186
@@ -258,15 +258,15 @@ def parse_quantization_config(
258
258
259
259
return quantization_config
260
260
261
- def _fetch_unique_quantization_formats (self ):
261
+ def _fetch_unique_quantization_formats (self ) -> List [ str ] :
262
262
"""
263
- Get all unique compression formats used in
264
- model
263
+ Get all unique compression formats present in a model
264
+ :return: list of quantization formats
265
265
"""
266
266
quantization_formats = []
267
267
for _ , scheme in self .quantization_config .config_groups .items ():
268
268
if scheme .format not in quantization_formats :
269
- quantization_formats .append (scheme )
269
+ quantization_formats .append (scheme . format )
270
270
return quantization_formats
271
271
272
272
def __init__ (
@@ -287,7 +287,6 @@ def __init__(
287
287
)
288
288
289
289
quantization_formats = self ._fetch_unique_quantization_formats ()
290
-
291
290
if quantization_config is not None :
292
291
self .quantization_compressor = {}
293
292
for format in quantization_formats :
@@ -694,6 +693,7 @@ def update_config(self, save_directory: str):
694
693
config_data [QUANTIZATION_CONFIG_NAME ][
695
694
COMPRESSION_VERSION_NAME
696
695
] = compressed_tensors .__version__
696
+
697
697
if self .quantization_config is not None :
698
698
self .quantization_config .quant_method = DEFAULT_QUANTIZATION_METHOD
699
699
else :
0 commit comments