@@ -185,7 +185,7 @@ def from_pretrained_model(
185
185
# assume multiple compression formats means mixed-precision
186
186
# as we currently only support one compressor per precision type and scheme
187
187
if len (quantization_format ) > 1 :
188
- quantization_format = CompressionFormat .mixed_precision
188
+ quantization_format = CompressionFormat .mixed_precision . value
189
189
else :
190
190
quantization_format = quantization_format [0 ]
191
191
@@ -269,15 +269,15 @@ def parse_quantization_config(
269
269
270
270
return quantization_config
271
271
272
- def _fetch_unique_quantization_formats (self ):
272
+ def _fetch_unique_quantization_formats (self ) -> List [ str ] :
273
273
"""
274
- Get all unique compression formats used in
275
- model
274
+ Get all unique compression formats present in a model
275
+ :return: list of quantization formats
276
276
"""
277
277
quantization_formats = []
278
278
for _ , scheme in self .quantization_config .config_groups .items ():
279
279
if scheme .format not in quantization_formats :
280
- quantization_formats .append (scheme )
280
+ quantization_formats .append (scheme . format )
281
281
return quantization_formats
282
282
283
283
def __init__ (
@@ -302,7 +302,6 @@ def __init__(
302
302
)
303
303
304
304
quantization_formats = self ._fetch_unique_quantization_formats ()
305
-
306
305
if quantization_config is not None :
307
306
self .quantization_compressor = {}
308
307
for format in quantization_formats :
0 commit comments