Skip to content

Commit 158720a

Browse files
committed
fix compression format handling
1 parent 22636cf commit 158720a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -177,12 +177,15 @@ def from_pretrained_model(
177177
algorithm
178178
:return: compressor for the configs, or None if model is not compressed
179179
"""
180-
# assume multiple compression formats means mixed-precision
181-
# as we currently only support one compressor per precision type and scheme
180+
182181
if quantization_format is not None:
183-
if isinstance(quantization_format, str):
182+
# llmcompressor incorrectly passes in a CompressionFormat when
183+
# the value string is expected - handle both cases
184+
if isinstance(quantization_format, (str, CompressionFormat)):
184185
quantization_format = [quantization_format]
185186

187+
# assume multiple compression formats means mixed-precision
188+
# as we currently only support one compressor per precision type and scheme
186189
if len(quantization_format) > 1:
187190
quantization_format = CompressionFormat.mixed_precision.value
188191
else:
@@ -191,7 +194,6 @@ def from_pretrained_model(
191194
quantization_config = QuantizationConfig.from_pretrained(
192195
model, format=quantization_format
193196
)
194-
195197
if isinstance(sparsity_config, str): # we passed in a sparsity format
196198
sparsity_config = SparsityCompressionConfig.load_from_registry(
197199
sparsity_config

0 commit comments

Comments
 (0)