Commit 6bf7171

handle mixed-precision case
1 parent b201266 commit 6bf7171

File tree

2 files changed: +60 -21 lines changed


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 58 additions & 19 deletions
@@ -169,7 +169,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[List[str]] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -184,10 +184,14 @@ def from_pretrained_model(
         """
         # assume multiple compression formats means mixed-precision
         # as we currently only support one compressor per precision type and scheme
-        if len(quantization_format) > 1:
-            quantization_format = CompressionFormat.mixed_precision.value
-        else:
-            quantization_format = quantization_format[0]
+        if quantization_format is not None:
+            if isinstance(quantization_format, str):
+                quantization_format = [quantization_format]
+
+            if len(quantization_format) > 1:
+                quantization_format = CompressionFormat.mixed_precision.value
+            else:
+                quantization_format = quantization_format[0]
 
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
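
Taken on its own, the new normalization accepts None, a single format string, or a list of formats, and reduces each to one config-level value. A minimal standalone sketch of that behavior (CompressionFormat is stubbed here; "pack-quantized" and "float-quantized" are example formats, and mixed_precision is assumed to carry the value "mixed-precision"):

from enum import Enum
from typing import List, Optional, Union

class CompressionFormat(Enum):
    # stand-in for the real enum in compressed_tensors.config
    mixed_precision = "mixed-precision"

def normalize_format(
    quantization_format: Optional[Union[str, List[str]]],
) -> Optional[str]:
    if quantization_format is None:
        return None  # nothing to normalize
    if isinstance(quantization_format, str):
        quantization_format = [quantization_format]  # promote str to list
    if len(quantization_format) > 1:
        # multiple formats are collapsed into one mixed-precision marker
        return CompressionFormat.mixed_precision.value
    return quantization_format[0]  # a single entry passes through unchanged

assert normalize_format(None) is None
assert normalize_format("pack-quantized") == "pack-quantized"
assert normalize_format(["pack-quantized", "float-quantized"]) == "mixed-precision"

Beyond accepting a bare string, the None check also removes a latent crash: the previous code called len(quantization_format) even though the parameter defaults to None.
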
@@ -408,12 +412,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )
 
         return list(unexpected_keys)
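
For context, self.quantization_compressor is a mapping from format string to compressor (one per precision), so the non-weight side-car parameters have to be gathered across every compressor in the map. A rough self-contained sketch of that gathering, using hypothetical stub compressors and example parameter names:

# Hypothetical stubs: only compression_param_names matters for this walk.
class _StubCompressor:
    def __init__(self, names):
        self.compression_param_names = names

quantization_compressor = {  # format string -> compressor
    "pack-quantized": _StubCompressor(["weight", "weight_scale", "weight_zero_point"]),
    "float-quantized": _StubCompressor(["weight", "weight_scale"]),
}
quant_targets = ["model.layers.0.self_attn.q_proj"]

unexpected_keys = {
    f"{target}.{param}"  # merge_names joins prefix and name with a dot
    for quant_compressor in quantization_compressor.values()
    for target in quant_targets
    for param in quant_compressor.compression_param_names
    if param != "weight"  # the weight tensor itself is expected on disk
}
print(sorted(unexpected_keys))
# ['model.layers.0.self_attn.q_proj.weight_scale',
#  'model.layers.0.self_attn.q_proj.weight_zero_point']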

@@ -451,9 +456,24 @@ def compress_model(self, model: Module):
 
             # quantization first
             if prefix in module_to_scheme:
-                quant_compressor = self.quantization_compressor.get(
-                    module.quantization_scheme.format
-                )
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if (
+                        self.quantization_config.format
+                        == CompressionFormat.mixed_precision.value
+                    ):
+                        raise ValueError(
+                            "Compressing mixed-precision models without defining "
+                            "per module quantization_scheme.format is currently "
+                            "not supported"
+                        )
+                    format = self.quantization_config.format
+                else:
+                    format = module.quantization_scheme.format
+
+                quant_compressor = self.quantization_compressor.get(format)
                 state_dict = quant_compressor.compress(
                     state_dict,
                     names_to_scheme=module_to_scheme,
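
The resolution order introduced here: prefer the module's own quantization_scheme.format; otherwise fall back to the model-wide config format; and refuse the fallback when that model-wide format is the ambiguous mixed-precision marker. A condensed sketch of just that decision as a hypothetical resolve_format helper operating on plain strings:

MIXED = "mixed-precision"  # assumed value of CompressionFormat.mixed_precision

def resolve_format(module_format, config_format, verb="Compressing"):
    """Pick the compressor format for a single module."""
    if module_format is not None:
        return module_format  # a per-module format always wins
    if config_format == MIXED:
        # mixed-precision is a marker, not a usable per-module default
        raise ValueError(
            f"{verb} mixed-precision models without defining per module "
            "quantization_scheme.format is currently not supported"
        )
    return config_format  # single-format models fall back to the config

assert resolve_format("pack-quantized", MIXED) == "pack-quantized"
assert resolve_format(None, "float-quantized") == "float-quantized"
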
@@ -525,9 +545,24 @@ def decompress_model(self, model: Module):
 
             # quantization second
             if prefix in module_to_scheme:
-                quant_compressor = self.quantization_compressor.get(
-                    module.quantization_scheme.format
-                )
+
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if (
+                        self.quantization_config.format
+                        == CompressionFormat.mixed_precision.value
+                    ):
+                        raise ValueError(
+                            "Decompressing mixed-precision models without defining "
+                            "per module quantization_scheme.format is currently not "
+                            "supported"
+                        )
+                    format = self.quantization_config.format
+                else:
+                    format = module.quantization_scheme.format
+                quant_compressor = self.quantization_compressor.get(format)
                 state_dict = quant_compressor.decompress_module_from_state_dict(
                     prefix,
                     state_dict,
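
The decompression path applies the identical resolution, differing only in the verb of its error message, so the resolve_format sketch above covers both branches through its verb parameter; the duplicated block is the kind of code such a shared helper would fold away.
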
@@ -621,15 +656,19 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )
 
         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
             # note - decompress only supports one compressor atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
             params_to_ignore = None
-            if self.quantization_compressor is not None:
+            if quant_compressor is not None:
                 params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
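
Putting the pieces together, mixed-precision handling is driven entirely by what is passed to from_pretrained_model. A hedged end-to-end sketch (model construction is elided; the format names are examples, and the import path follows the file shown above):

from compressed_tensors.compressors.model_compressors.model_compressor import (
    ModelCompressor,
)

# model: a quantized torch.nn.Module whose modules carry a
# quantization_scheme with a per-module format already set
compressor = ModelCompressor.from_pretrained_model(
    model,
    sparsity_config=None,
    quantization_format=["pack-quantized", "float-quantized"],
)
# Two formats make the top-level config format "mixed-precision"; each
# module is then compressed by the compressor matching its own
# quantization_scheme.format.
compressor.compress_model(model)
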

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 2 additions & 2 deletions
@@ -443,7 +443,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         model_stub, torch_dtype=torch.float32
     )
     reference_compressor = ModelCompressor.from_pretrained_model(
-        cpu_model, s_config, q_format
+        cpu_model, s_config, [q_format]
     )
     # Only stores dtype because meta model does not store values
     expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +459,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         module.to_empty(device="meta")
 
     # Compress in-place on meta model
-    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
     compressor.compress_model(meta_model)
 
     # Compare keys and dtypes
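
Because the updated signature accepts either a bare string or a list, passing q_format unwrapped would also be normalized to the same single format; wrapping it in a list here keeps the tests exercising the list-input path that mixed-precision callers use.
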
