Skip to content

Commit b201266

Browse files
committed
fix
1 parent c6136b2 commit b201266

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,12 @@ def _fetch_unique_quantization_formats(self) -> List[str]:
276276
"""
277277
quantization_formats = []
278278
for _, scheme in self.quantization_config.config_groups.items():
279-
if scheme.format not in quantization_formats:
279+
if scheme.format is not None and scheme.format not in quantization_formats:
280280
quantization_formats.append(scheme.format)
281+
282+
# If the list is empty (no config group specifies a format), fall back to the global format
283+
if len(quantization_formats) == 0:
284+
quantization_formats.append(self.quantization_config.format)
281285
return quantization_formats
282286

283287
def __init__(
@@ -301,8 +305,8 @@ def __init__(
301305
sparsity_config.format, config=sparsity_config
302306
)
303307

304-
quantization_formats = self._fetch_unique_quantization_formats()
305308
if quantization_config is not None:
309+
quantization_formats = self._fetch_unique_quantization_formats()
306310
self.quantization_compressor = {}
307311
for format in quantization_formats:
308312
self.quantization_compressor[
@@ -567,7 +571,7 @@ def compress(
567571
if self.quantization_compressor is not None:
568572
module_to_scheme = map_module_to_scheme(model)
569573
# Note - compress only supports one compression format atm
570-
quant_compressor = next(iter(self.quantization_compressor))
574+
quant_compressor = next(iter(self.quantization_compressor.values()))
571575
state_dict = quant_compressor.compress(
572576
state_dict,
573577
names_to_scheme=module_to_scheme,
@@ -623,7 +627,7 @@ def decompress(self, model_path: str, model: Module):
623627
and self.sparsity_config.format != CompressionFormat.dense.value
624628
):
625629
# note - decompress only supports one compressor atm
626-
quant_compressor = next(iter(self.quantization_compressor))
630+
quant_compressor = next(iter(self.quantization_compressor.values()))
627631
params_to_ignore = None
628632
if self.quantization_compressor is not None:
629633
params_to_ignore = quant_compressor.compression_param_names

0 commit comments

Comments
 (0)