
Commit fac7e4a

[Quantization] Support more than one quant-compressor (#415)
* support more than one quant compressor
* clean-up; add mixed-precision format
* update
* update
* fix
* handle mixed-precision case
* update
* update quant scheme tests
* add tests
* update
* Update quant_config.py
* clean-up
1 parent 33c52de commit fac7e4a

6 files changed: +154 additions, -31 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 93 additions & 25 deletions
@@ -169,7 +169,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -182,7 +182,6 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
@@ -203,6 +202,9 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=[quantization_format]
+            if isinstance(quantization_format, str)
+            else quantization_format,
         )

     @staticmethod
@@ -263,30 +265,61 @@ def parse_quantization_config(

         return quantization_config

+    def _fetch_unique_quantization_formats(self) -> List[str]:
+        """
+        Get all unique compression formats present in a model.
+        :return: list of quantization formats
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format is not None and scheme.format not in quantization_formats:
+                quantization_formats.append(scheme.format)
+
+        if (
+            len(quantization_formats) == 0
+            and self.quantization_config.format
+            != CompressionFormat.mixed_precision.value
+        ):
+            quantization_formats.append(self.quantization_config.format)
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats

         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None
         # no transform compressor is required

         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
+
         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            # If a list of compression_format is not provided, we resolve the
+            # relevant quantization formats using the config groups from the config
+            # and if those are not defined, we fall-back to the global quantization format
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
+            self.quantization_compressor = {}
+            for format in self.compression_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
                    format, config=quantization_config
                )

     # ----- used by hf quantizer ----- #

@@ -381,12 +414,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )

         return list(unexpected_keys)

@@ -424,7 +458,21 @@ def compress_model(self, model: Module):

                 # quantization first
                 if prefix in module_to_scheme:
-                    state_dict = self.quantization_compressor.compress(
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if len(self.compression_formats) > 1:
+                            raise ValueError(
+                                "Applying multiple compressors without defining "
+                                "per module formats is not supported "
+                            )
+                        format = self.compression_formats[0]
+                    else:
+                        format = module.quantization_scheme.format
+
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.compress(
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
@@ -495,12 +543,24 @@ def decompress_model(self, model: Module):

                 # quantization second
                 if prefix in module_to_scheme:
-                    state_dict = (
-                        self.quantization_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
+
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if len(self.compression_formats) > 1:
+                            raise ValueError(
+                                "Applying multiple compressors without defining "
+                                "per module formats is not supported "
+                            )
+                        format = self.compression_formats[0]
+                    else:
+                        format = module.quantization_scheme.format
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
                     )

                 # remove any existing parameters
@@ -539,7 +599,9 @@ def compress(

         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note - compress only supports one compression format atm
+            quant_compressor = next(iter(self.quantization_compressor.values()))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
@@ -588,14 +650,20 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )

         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # note - decompress only supports one compressor atm
             params_to_ignore = None
-            if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+            if quant_compressor is not None:
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
@@ -606,7 +674,7 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True

-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
@@ -629,15 +697,15 @@ def decompress(self, model_path: str, model: Module):
             # including initialization
                 load_weight_quantization=(
                     sparse_decompressed
-                    or isinstance(self.quantization_compressor, DenseCompressor)
+                    or isinstance(quant_compressor, DenseCompressor)
                 ),
             )

             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
             )

-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = quant_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             # TODO: all weight quantization params will be moved to the compressor
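
The per-module fallback that compress_model and decompress_model now both perform can be read in isolation: prefer the scheme's own format, otherwise require a single global format. The helper below is an illustrative restatement, not a function added by this commit.

from typing import List, Optional


def resolve_module_format(
    scheme_format: Optional[str], compression_formats: List[str]
) -> str:
    # Prefer the per-module format attached to the quantization scheme
    if scheme_format is not None:
        return scheme_format
    # Without per-module formats, more than one global compressor is ambiguous
    if len(compression_formats) > 1:
        raise ValueError(
            "Applying multiple compressors without defining "
            "per module formats is not supported"
        )
    return compression_formats[0]


# Untagged modules fall back to the single global format; tagged modules
# pick the compressor registered under their own format.
assert resolve_module_format(None, ["float-quantized"]) == "float-quantized"
assert resolve_module_format(
    "nvfp4-pack-quantized", ["float-quantized", "nvfp4-pack-quantized"]
) == "nvfp4-pack-quantized"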

src/compressed_tensors/config/base.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
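
The new member sits alongside the existing format values and is referenced like any of them; a minimal check, assuming a compressed-tensors build that already includes this commit:

from compressed_tensors.config import CompressionFormat

# "mixed-precision" is a config-level marker; per-module schemes still name
# a concrete format such as "float-quantized" or "nvfp4-pack-quantized".
assert CompressionFormat.mixed_precision.value == "mixed-precision"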

src/compressed_tensors/quantization/quant_config.py

Lines changed: 6 additions & 0 deletions
@@ -234,6 +234,12 @@ def from_pretrained(
                 format = CompressionFormat.int_quantized.value
             else:
                 format = CompressionFormat.dense.value
+        elif isinstance(format, list):
+            format = (
+                CompressionFormat.mixed_precision.value
+                if len(format) > 1
+                else format[0]
+            )

         return QuantizationConfig(
             config_groups=config_groups,
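
With the new elif branch, a list of formats collapses to the mixed-precision marker only when it actually contains more than one entry. A standalone restatement (collapse_formats is illustrative, not part of the API, and assumes a build that includes this commit):

from compressed_tensors.config import CompressionFormat


def collapse_formats(format):
    # Mirrors the branch added to QuantizationConfig.from_pretrained above
    if isinstance(format, list):
        return (
            CompressionFormat.mixed_precision.value
            if len(format) > 1
            else format[0]
        )
    return format


assert collapse_formats(["float-quantized", "nvfp4-pack-quantized"]) == "mixed-precision"
assert collapse_formats(["pack-quantized"]) == "pack-quantized"
assert collapse_formats("pack-quantized") == "pack-quantized"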

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@
 from copy import deepcopy
 from typing import List, Optional

+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
@@ -42,12 +43,14 @@ class QuantizationScheme(BaseModel):
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
+    :param format: CompressionFormat for the layer
     """

     targets: List[str]
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+    format: Optional[str] = None

     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
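
A short usage sketch of the new field, mirroring the fp8 scheme used in the tests below; the bit-widths and target are taken from that test rather than prescribed by the API:

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# A scheme can now carry the format its layers should be compressed with,
# instead of relying solely on the global config-level format.
fp8_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="float"),
    input_activations=QuantizationArgs(num_bits=8, type="float"),
    format="float-quantized",
)
assert fp8_scheme.format == "float-quantized"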

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 46 additions & 5 deletions
@@ -20,8 +20,12 @@
 import torch
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+)
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -395,7 +399,7 @@ def _get_combined_config(s_config, q_config):
 )
 def test_compress_model(model_stub, q_format, s_config, tmpdir):
     model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
-    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])

     # compress model by eagerly compressing state dict
     true_compressed = dict(compressor.compress(model))
@@ -443,7 +447,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         model_stub, torch_dtype=torch.float32
     )
     reference_compressor = ModelCompressor.from_pretrained_model(
-        cpu_model, s_config, q_format
+        cpu_model, s_config, [q_format]
     )
     # Only stores dtype because meta model does not store values
     expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +463,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         module.to_empty(device="meta")

     # Compress in-place on meta model
-    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
     compressor.compress_model(meta_model)

     # Compare keys and dtypes
@@ -469,6 +473,43 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"


+def test_multiple_quant_compressors():
+    model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
+    input_activations = QuantizationArgs(num_bits=8, type="float")
+    weights = QuantizationArgs(num_bits=8, type="float")
+
+    scheme_fp8 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.float_quantized.value,
+    )
+
+    input_activations = QuantizationArgs(num_bits=4, type="float")
+    weights = QuantizationArgs(num_bits=4, type="float")
+
+    scheme_nvfp4 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.nvfp4_pack_quantized.value,
+    )
+
+    model[0].quantization_scheme = scheme_fp8
+    model[0].quantization_status = "frozen"
+    model[1].quantization_scheme = scheme_nvfp4
+    model[1].quantization_status = "frozen"
+
+    formats = [scheme_fp8.format, scheme_nvfp4.format]
+
+    compressor = ModelCompressor.from_pretrained_model(model, None, formats)
+    assert isinstance(compressor.quantization_compressor, dict)
+    assert (
+        compressor.quantization_config.format == CompressionFormat.mixed_precision.value
+    )
+    assert all(format in compressor.quantization_compressor for format in formats)
+
+
 @pytest.mark.parametrize(
     "model_stub,comp_stub",
     [

tests/test_quantization/test_quant_scheme.py

Lines changed: 5 additions & 1 deletion
@@ -26,24 +26,27 @@ def test_basic_scheme():
     assert scheme.weights == weights
     assert scheme.input_activations is None
     assert scheme.output_activations is None
+    assert scheme.format is None


 def test_full_scheme():
     targets = ["Linear"]
     weights = QuantizationArgs()
-    input_activations = QuantizationArgs(num_bits=4)
+    input_activations = QuantizationArgs(num_bits=8)
     output_activations = QuantizationArgs(num_bits=8, type="float", symmetric=False)

     scheme = QuantizationScheme(
         targets=targets,
         weights=weights,
         input_activations=input_activations,
         output_activations=output_activations,
+        format="float-quantized",
     )
     assert scheme.targets == targets
     assert scheme.weights == weights
     assert scheme.input_activations == input_activations
     assert scheme.output_activations == output_activations
+    assert scheme.format is "float-quantized"


 def test_needs_targets():
@@ -57,3 +60,4 @@ def test_defaults():
     assert output.weights is None
     assert output.input_activations is None
     assert output.output_activations is None
+    assert output.format is None
