Skip to content

Commit cb1c427

Browse files
committed
fix
1 parent 8f514c4 commit cb1c427

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,12 @@ def _fetch_unique_quantization_formats(self) -> List[str]:
265265
"""
266266
quantization_formats = []
267267
for _, scheme in self.quantization_config.config_groups.items():
268-
if scheme.format not in quantization_formats:
268+
if scheme.format is not None and scheme.format not in quantization_formats:
269269
quantization_formats.append(scheme.format)
270+
271+
# If empty list, fallback to using the global format
272+
if len(quantization_formats) == 0:
273+
quantization_formats.append(self.quantization_config.format)
270274
return quantization_formats
271275

272276
def __init__(
@@ -286,8 +290,8 @@ def __init__(
286290
sparsity_config.format, config=sparsity_config
287291
)
288292

289-
quantization_formats = self._fetch_unique_quantization_formats()
290293
if quantization_config is not None:
294+
quantization_formats = self._fetch_unique_quantization_formats()
291295
self.quantization_compressor = {}
292296
for format in quantization_formats:
293297
self.quantization_compressor[
@@ -552,7 +556,7 @@ def compress(
552556
if self.quantization_compressor is not None:
553557
module_to_scheme = map_module_to_scheme(model)
554558
# Note - compress only supports one compression format atm
555-
quant_compressor = next(iter(self.quantization_compressor))
559+
quant_compressor = next(iter(self.quantization_compressor.values()))
556560
state_dict = quant_compressor.compress(
557561
state_dict,
558562
names_to_scheme=module_to_scheme,
@@ -608,7 +612,7 @@ def decompress(self, model_path: str, model: Module):
608612
and self.sparsity_config.format != CompressionFormat.dense.value
609613
):
610614
# note - decompress only supports one compressor atm
611-
quant_compressor = next(iter(self.quantization_compressor))
615+
quant_compressor = next(iter(self.quantization_compressor.values()))
612616
params_to_ignore = None
613617
if self.quantization_compressor is not None:
614618
params_to_ignore = quant_compressor.compression_param_names

0 commit comments

Comments
 (0)