@@ -265,8 +265,12 @@ def _fetch_unique_quantization_formats(self) -> List[str]:
265
265
"""
266
266
quantization_formats = []
267
267
for _ , scheme in self .quantization_config .config_groups .items ():
268
- if scheme .format not in quantization_formats :
268
+ if scheme .format is not None and scheme . format not in quantization_formats :
269
269
quantization_formats .append (scheme .format )
270
+
271
+ # If empty list, fallback to using the global format
272
+ if len (quantization_formats ) == 0 :
273
+ quantization_formats .append (self .quantization_config .format )
270
274
return quantization_formats
271
275
272
276
def __init__ (
@@ -286,8 +290,8 @@ def __init__(
286
290
sparsity_config .format , config = sparsity_config
287
291
)
288
292
289
- quantization_formats = self ._fetch_unique_quantization_formats ()
290
293
if quantization_config is not None :
294
+ quantization_formats = self ._fetch_unique_quantization_formats ()
291
295
self .quantization_compressor = {}
292
296
for format in quantization_formats :
293
297
self .quantization_compressor [
@@ -552,7 +556,7 @@ def compress(
552
556
if self .quantization_compressor is not None :
553
557
module_to_scheme = map_module_to_scheme (model )
554
558
# Note - compress only supports one compression format atm
555
- quant_compressor = next (iter (self .quantization_compressor ))
559
+ quant_compressor = next (iter (self .quantization_compressor . values () ))
556
560
state_dict = quant_compressor .compress (
557
561
state_dict ,
558
562
names_to_scheme = module_to_scheme ,
@@ -608,7 +612,7 @@ def decompress(self, model_path: str, model: Module):
608
612
and self .sparsity_config .format != CompressionFormat .dense .value
609
613
):
610
614
# note - decompress only supports one compressor atm
611
- quant_compressor = next (iter (self .quantization_compressor ))
615
+ quant_compressor = next (iter (self .quantization_compressor . values () ))
612
616
params_to_ignore = None
613
617
if self .quantization_compressor is not None :
614
618
params_to_ignore = quant_compressor .compression_param_names
0 commit comments