
Commit 22636cf

handle mixed-precision case
1 parent cb1c427 commit 22636cf

2 files changed: +60 additions, -21 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 58 additions & 19 deletions
@@ -164,7 +164,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[List[str]] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -179,10 +179,14 @@ def from_pretrained_model(
         """
         # assume multiple compression formats means mixed-precision
         # as we currently only support one compressor per precision type and scheme
-        if len(quantization_format) > 1:
-            quantization_format = CompressionFormat.mixed_precision.value
-        else:
-            quantization_format = quantization_format[0]
+        if quantization_format is not None:
+            if isinstance(quantization_format, str):
+                quantization_format = [quantization_format]
+
+            if len(quantization_format) > 1:
+                quantization_format = CompressionFormat.mixed_precision.value
+            else:
+                quantization_format = quantization_format[0]

         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
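
As a reading aid, here is a minimal standalone sketch of the normalization this hunk introduces. `normalize_format` is a made-up name, and it assumes `CompressionFormat.mixed_precision.value == "mixed-precision"` along with example format strings from the library:

# Hypothetical sketch of the normalization above; not library code.
# Assumes CompressionFormat.mixed_precision.value == "mixed-precision".
from typing import List, Optional, Union

def normalize_format(
    quantization_format: Optional[Union[str, List[str]]],
) -> Optional[str]:
    if quantization_format is None:
        return None  # unquantized model: nothing to normalize
    if isinstance(quantization_format, str):
        quantization_format = [quantization_format]
    if len(quantization_format) > 1:
        # multiple formats imply mixed precision, since only one
        # compressor per precision type and scheme is supported
        return "mixed-precision"
    return quantization_format[0]

assert normalize_format(None) is None
assert normalize_format("pack-quantized") == "pack-quantized"
assert normalize_format(["pack-quantized"]) == "pack-quantized"
assert normalize_format(["int-quantized", "pack-quantized"]) == "mixed-precision"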
@@ -393,12 +397,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )

         return list(unexpected_keys)
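
The rewritten block iterates over `self.quantization_compressor.values()`, treating it as a mapping from format string to compressor (the old generator read `compression_param_names` off the container itself). A rough self-contained illustration of the key collection, assuming `merge_names` joins names with a dot; every name below is invented:

# Rough illustration only; `compressors` stands in for the format -> compressor
# dict, and SimpleNamespace fakes a compressor's compression_param_names.
from types import SimpleNamespace

def collect_unexpected(compressors, quant_targets):
    unexpected = set()
    for compressor in compressors.values():
        unexpected.update(
            f"{target}.{param}"  # merge_names is assumed to join with "."
            for target in quant_targets
            for param in compressor.compression_param_names
            if param != "weight"
        )
    return sorted(unexpected)

fake = {
    "int-quantized": SimpleNamespace(compression_param_names=["weight", "weight_scale"]),
    "pack-quantized": SimpleNamespace(compression_param_names=["weight", "weight_packed"]),
}
print(collect_unexpected(fake, ["model.layers.0.self_attn.q_proj"]))
# -> ['model.layers.0.self_attn.q_proj.weight_packed', 'model.layers.0.self_attn.q_proj.weight_scale']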

@@ -436,9 +441,24 @@ def compress_model(self, model: Module):

             # quantization first
             if prefix in module_to_scheme:
-                quant_compressor = self.quantization_compressor.get(
-                    module.quantization_scheme.format
-                )
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if (
+                        self.quantization_config.format
+                        == CompressionFormat.mixed_precision.value
+                    ):
+                        raise ValueError(
+                            "Compressing mixed-precision models without defining "
+                            "per module quantization_scheme.format is currently "
+                            "not supported"
+                        )
+                    format = self.quantization_config.format
+                else:
+                    format = module.quantization_scheme.format
+
+                quant_compressor = self.quantization_compressor.get(format)
                 state_dict = quant_compressor.compress(
                     state_dict,
                     names_to_scheme=module_to_scheme,
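
The resolution logic above reads: prefer the module's own `quantization_scheme.format`, fall back to the config-level format, and refuse the fallback when that format is mixed-precision, since mixed-precision names no single compressor. A sketch as a hypothetical helper (not library code), under the same assumption about the format string:

# Hypothetical helper equivalent to the branch above; not part of the library.
def resolve_format(quantization_scheme, config_format: str) -> str:
    if getattr(quantization_scheme, "format", None) is None:
        if config_format == "mixed-precision":  # CompressionFormat.mixed_precision.value
            # a mixed-precision config format names no single compressor,
            # so each module must carry its own format
            raise ValueError(
                "Compressing mixed-precision models without defining "
                "per module quantization_scheme.format is currently not supported"
            )
        return config_format  # single global format: safe fallback
    return quantization_scheme.format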
@@ -510,9 +530,24 @@ def decompress_model(self, model: Module):

             # quantization second
             if prefix in module_to_scheme:
-                quant_compressor = self.quantization_compressor.get(
-                    module.quantization_scheme.format
-                )
+
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if (
+                        self.quantization_config.format
+                        == CompressionFormat.mixed_precision.value
+                    ):
+                        raise ValueError(
+                            "Decompressing mixed-precision models without defining "
+                            "per module quantization_scheme.format is currently not "
+                            "supported"
+                        )
+                    format = self.quantization_config.format
+                else:
+                    format = module.quantization_scheme.format
+                quant_compressor = self.quantization_compressor.get(format)
                 state_dict = quant_compressor.decompress_module_from_state_dict(
                     prefix,
                     state_dict,
@@ -606,15 +641,19 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )

         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
             # note - decompress only supports one compressor atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
             params_to_ignore = None
-            if self.quantization_compressor is not None:
+            if quant_compressor is not None:
                 params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 2 additions & 2 deletions
@@ -443,7 +443,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         model_stub, torch_dtype=torch.float32
     )
     reference_compressor = ModelCompressor.from_pretrained_model(
-        cpu_model, s_config, q_format
+        cpu_model, s_config, [q_format]
     )
     # Only stores dtype because meta model does not store values
     expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +459,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         module.to_empty(device="meta")

     # Compress in-place on meta model
-    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
     compressor.compress_model(meta_model)

     # Compare keys and dtypes
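
Taken together, the caller-facing behavior is roughly the following. Illustrative only: `model` is a placeholder for a quantized module, the format strings are examples, and the import path is assumed:

# Illustrative usage; `model` stands in for an nn.Module with quantized modules.
from compressed_tensors.compressors import ModelCompressor

# a bare string still works after this change...
c1 = ModelCompressor.from_pretrained_model(model, quantization_format="pack-quantized")

# ...as does a single-element list, which the updated tests now pass explicitly...
c2 = ModelCompressor.from_pretrained_model(model, quantization_format=["pack-quantized"])

# ...while two or more formats collapse to the mixed-precision config format,
# and each module is then expected to declare its own quantization_scheme.format
c3 = ModelCompressor.from_pretrained_model(
    model, quantization_format=["int-quantized", "pack-quantized"]
)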
