@@ -276,8 +276,12 @@ def _fetch_unique_quantization_formats(self) -> List[str]:
276
276
"""
277
277
quantization_formats = []
278
278
for _ , scheme in self .quantization_config .config_groups .items ():
279
- if scheme .format not in quantization_formats :
279
+ if scheme .format is not None and scheme . format not in quantization_formats :
280
280
quantization_formats .append (scheme .format )
281
+
282
+ # If empty list, fallback to using the global format
283
+ if len (quantization_formats ) == 0 :
284
+ quantization_formats .append (self .quantization_config .format )
281
285
return quantization_formats
282
286
283
287
def __init__ (
@@ -301,8 +305,8 @@ def __init__(
301
305
sparsity_config .format , config = sparsity_config
302
306
)
303
307
304
- quantization_formats = self ._fetch_unique_quantization_formats ()
305
308
if quantization_config is not None :
309
+ quantization_formats = self ._fetch_unique_quantization_formats ()
306
310
self .quantization_compressor = {}
307
311
for format in quantization_formats :
308
312
self .quantization_compressor [
@@ -567,7 +571,7 @@ def compress(
567
571
if self .quantization_compressor is not None :
568
572
module_to_scheme = map_module_to_scheme (model )
569
573
# Note - compress only supports one compression format atm
570
- quant_compressor = next (iter (self .quantization_compressor ))
574
+ quant_compressor = next (iter (self .quantization_compressor . values () ))
571
575
state_dict = quant_compressor .compress (
572
576
state_dict ,
573
577
names_to_scheme = module_to_scheme ,
@@ -623,7 +627,7 @@ def decompress(self, model_path: str, model: Module):
623
627
and self .sparsity_config .format != CompressionFormat .dense .value
624
628
):
625
629
# note - decompress only supports one compressor atm
626
- quant_compressor = next (iter (self .quantization_compressor ))
630
+ quant_compressor = next (iter (self .quantization_compressor . values () ))
627
631
params_to_ignore = None
628
632
if self .quantization_compressor is not None :
629
633
params_to_ignore = quant_compressor .compression_param_names
0 commit comments