[Quantization] Support more than one quant-compressor #415

Open: wants to merge 12 commits into main
Changes from 11 commits
134 changes: 109 additions & 25 deletions src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -169,7 +169,7 @@ def from_pretrained_model(
cls,
model: Module,
sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
quantization_format: Optional[str] = None,
quantization_format: Optional[Union[str, List[str]]] = None,
) -> Optional["ModelCompressor"]:
"""
Given a pytorch model and optional sparsity and/or quantization configs,
@@ -182,7 +182,22 @@ def from_pretrained_model(
algorithm
:return: compressor for the configs, or None if model is not compressed
"""
# reconstruct config from schemes attached to modules

compression_formats = None
if quantization_format is not None:
# llmcompressor incorrectly passes in a CompressionFormat when
# the value string is expected - handle both cases
if isinstance(quantization_format, (str, CompressionFormat)):
quantization_format = [quantization_format]

compression_formats = quantization_format
@kylesayrs (Contributor) commented on Aug 13, 2025:
FYI this parsing logic is duplicated in from_pretrained_model and decompress_model.
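
One possible way to deduplicate this is a shared helper that both from_pretrained_model and decompress_model could call before deciding whether to collapse to mixed-precision. The helper name (normalize_compression_formats) and its placement are assumptions for illustration, not part of this PR:

    from typing import List, Optional, Union

    from compressed_tensors.config import CompressionFormat


    def normalize_compression_formats(
        quantization_format: Union[str, CompressionFormat, List[Union[str, CompressionFormat]], None],
    ) -> Optional[List[str]]:
        """Coerce a single str/CompressionFormat or a list of them into a list
        of format value strings; return None when no format was provided."""
        if quantization_format is None:
            return None
        if isinstance(quantization_format, (str, CompressionFormat)):
            quantization_format = [quantization_format]
        return [
            fmt.value if isinstance(fmt, CompressionFormat) else fmt
            for fmt in quantization_format
        ]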

# assume multiple compression formats means mixed-precision
# as we currently only support one compressor per precision type and scheme
if len(quantization_format) > 1:
quantization_format = CompressionFormat.mixed_precision.value
else:
quantization_format = quantization_format[0]

quantization_config = QuantizationConfig.from_pretrained(
model, format=quantization_format
)
@@ -203,6 +218,7 @@ def from_pretrained_model(
sparsity_config=sparsity_config,
quantization_config=quantization_config,
transform_config=transform_config,
compression_formats=compression_formats,
)

@staticmethod
@@ -263,30 +279,55 @@ def parse_quantization_config(

return quantization_config

def _fetch_unique_quantization_formats(self) -> List[str]:
"""
Get all unique compression formats present in a model
:return: list of quantization formats
"""
quantization_formats = []
for _, scheme in self.quantization_config.config_groups.items():
if scheme.format is not None and scheme.format not in quantization_formats:
quantization_formats.append(scheme.format)

# If empty list, fallback to using the global format
if len(quantization_formats) == 0:
quantization_formats.append(self.quantization_config.format)
@kylesayrs (Contributor) commented on Aug 13, 2025:
self.quantization_config.format is nullable afaict, please add logic and/or typehint to account for this
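
A minimal null-safe sketch of the fallback this comment asks for; the helper name, its signature, and the choice to raise when no format is available anywhere are assumptions rather than behavior defined by this PR:

    from typing import List, Optional


    def resolve_quantization_formats(
        scheme_formats: List[Optional[str]], global_format: Optional[str]
    ) -> List[str]:
        """Collect unique non-None per-scheme formats, falling back to the
        global config format; raise if neither source provides a format."""
        formats = [fmt for fmt in dict.fromkeys(scheme_formats) if fmt is not None]
        if formats:
            return formats
        if global_format is None:
            raise ValueError(
                "No per-scheme format is set and the global quantization format "
                "is None; at least one compression format is required"
            )
        return [global_format]

As written in the hunk above, an empty list falls back to self.quantization_config.format, which may itself be None; a guard like this surfaces that case explicitly.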

return quantization_formats

def __init__(
self,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_config: Optional[QuantizationConfig] = None,
transform_config: Optional[TransformConfig] = None,
compression_formats: Optional[List[str]] = None,
):
self.sparsity_config = sparsity_config
self.quantization_config = quantization_config
self.transform_config = transform_config
self.compression_formats = compression_formats

self.sparsity_compressor = None
self.quantization_compressor: Optional[
Union[BaseQuantizationCompressor, DenseCompressor]
Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
Contributor comment:
is there a reason this can't be renamed to indicate it is a map of compressors instead of a single compressor?

] = None
# no transform compressor is required

if sparsity_config is not None:
self.sparsity_compressor = BaseCompressor.load_from_registry(
sparsity_config.format, config=sparsity_config
)

if quantization_config is not None:
self.quantization_compressor = BaseCompressor.load_from_registry(
quantization_config.format, config=quantization_config
)
if not self.compression_formats:
self.compression_formats = self._fetch_unique_quantization_formats()

self.quantization_compressor = {}
for format in self.compression_formats:
self.quantization_compressor[
format
] = BaseCompressor.load_from_registry(
format, config=quantization_config
)

# ----- used by hf quantizer ----- #

@@ -381,12 +422,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
targets=scheme.targets,
ignore=self.quantization_config.ignore,
)
unexpected_keys.update(
merge_names(target, param)
for target in quant_targets
for param in self.quantization_compressor.compression_param_names
if param != "weight"
)
for quant_compressor in self.quantization_compressor.values():
unexpected_keys.update(
merge_names(target, param)
for target in quant_targets
for param in quant_compressor.compression_param_names
if param != "weight"
)

return list(unexpected_keys)

@@ -424,7 +466,25 @@ def compress_model(self, model: Module):

# quantization first
if prefix in module_to_scheme:
state_dict = self.quantization_compressor.compress(
if (
not hasattr(module.quantization_scheme, "format")
or module.quantization_scheme.format is None
):
if (
self.quantization_config.format
== CompressionFormat.mixed_precision.value
):
raise ValueError(
"Compressing mixed-precision models without defining "
"per module quantization_scheme.format is currently "
"not supported"
)
format = self.quantization_config.format
else:
format = module.quantization_scheme.format

quant_compressor = self.quantization_compressor.get(format)
state_dict = quant_compressor.compress(
state_dict,
names_to_scheme=module_to_scheme,
show_progress=False,
@@ -495,12 +555,28 @@ def decompress_model(self, model: Module):

# quantization second
if prefix in module_to_scheme:
state_dict = (
self.quantization_compressor.decompress_module_from_state_dict(
prefix,
state_dict,
scheme=module_to_scheme[prefix],
)

if (
not hasattr(module.quantization_scheme, "format")
or module.quantization_scheme.format is None
):
if (
self.quantization_config.format
== CompressionFormat.mixed_precision.value
):
raise ValueError(
"Decompressing mixed-precision models without defining "
"per module quantization_scheme.format is currently not "
"supported"
)
format = self.quantization_config.format
else:
format = module.quantization_scheme.format
quant_compressor = self.quantization_compressor.get(format)
state_dict = quant_compressor.decompress_module_from_state_dict(
prefix,
state_dict,
scheme=module_to_scheme[prefix],
)

# remove any existing parameters
@@ -539,7 +615,9 @@ def compress(

if self.quantization_compressor is not None:
module_to_scheme = map_module_to_scheme(model)
state_dict = self.quantization_compressor.compress(
# Note - compress only supports one compression format atm
quant_compressor = next(iter(self.quantization_compressor.values()))
state_dict = quant_compressor.compress(
state_dict,
names_to_scheme=module_to_scheme,
show_progress=show_progress,
@@ -588,14 +666,20 @@ def decompress(self, model_path: str, model: Module):
"""
model_path = get_safetensors_folder(model_path)
sparse_decompressed = False
quant_compressor = (
next(iter(self.quantization_compressor.values()))
if self.quantization_compressor is not None
else None
)

if (
self.sparsity_compressor is not None
and self.sparsity_config.format != CompressionFormat.dense.value
):
# note - decompress only supports one compressor atm
params_to_ignore = None
if self.quantization_compressor is not None:
params_to_ignore = self.quantization_compressor.compression_param_names
if quant_compressor is not None:
params_to_ignore = quant_compressor.compression_param_names
# Sparse decompression is applied on the model_path
# The compressor will try and load any quantization parameters as well
# params_to_skip_load will skip over quantization params from being loaded
@@ -606,7 +690,7 @@ def decompress(self, model_path: str, model: Module):
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
sparse_decompressed = True

if self.quantization_compressor is not None:
if quant_compressor is not None:
# Temporarily set quantization status to FROZEN to prevent
# quantization during apply_quantization_config. This ensures
# that the dtypes of the weights are not unintentionally updated.
@@ -629,15 +713,15 @@ def decompress(self, model_path: str, model: Module):
# including initialization
load_weight_quantization=(
sparse_decompressed
or isinstance(self.quantization_compressor, DenseCompressor)
or isinstance(quant_compressor, DenseCompressor)
),
)

model_path_or_state_dict = (
model.state_dict() if sparse_decompressed else model_path
)

dense_gen = self.quantization_compressor.decompress(
dense_gen = quant_compressor.decompress(
model_path_or_state_dict, names_to_scheme=names_to_scheme
)
# TODO: all weight quantization params will be moved to the compressor
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
naive_quantized = "naive-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
mixed_precision = "mixed-precision"
nvfp4_pack_quantized = "nvfp4-pack-quantized"


3 changes: 3 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -16,6 +16,7 @@
from copy import deepcopy
from typing import List, Optional

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import (
DynamicType,
QuantizationArgs,
@@ -42,12 +43,14 @@ class QuantizationScheme(BaseModel):
:param weights: quantization config for layer weights
:param input_activations: quantization config for layer inputs
:param output_activations: quantization config for layer outputs
:param format: CompressionFormat for the layer
"""

targets: List[str]
weights: Optional[QuantizationArgs] = None
input_activations: Optional[QuantizationArgs] = None
output_activations: Optional[QuantizationArgs] = None
format: Optional[str] = None

@model_validator(mode="after")
def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
51 changes: 46 additions & 5 deletions tests/test_compressors/model_compressors/test_model_compressor.py
@@ -20,8 +20,12 @@
import torch
import torch.nn as nn
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationConfig,
QuantizationScheme,
)
from safetensors.torch import save_file
from tests.testing_utils import induce_sparsity, requires_hf_quantizer
from transformers import AutoModelForCausalLM
@@ -395,7 +399,7 @@ def _get_combined_config(s_config, q_config):
)
def test_compress_model(model_stub, q_format, s_config, tmpdir):
model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])

# compress model by eagerly compressing state dict
true_compressed = dict(compressor.compress(model))
@@ -443,7 +447,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
model_stub, torch_dtype=torch.float32
)
reference_compressor = ModelCompressor.from_pretrained_model(
cpu_model, s_config, q_format
cpu_model, s_config, [q_format]
)
# Only stores dtype because meta model does not store values
expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +463,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
module.to_empty(device="meta")

# Compress in-place on meta model
compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
compressor.compress_model(meta_model)

# Compare keys and dtypes
@@ -469,6 +473,43 @@ def test_compress_model_meta(model_stub, q_format, s_config):
assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"


def test_multiple_quant_compressors():
model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
input_activations = QuantizationArgs(num_bits=8, type="float")
weights = QuantizationArgs(num_bits=8, type="float")

scheme_fp8 = QuantizationScheme(
targets=["Linear"],
weights=weights,
input_activations=input_activations,
format=CompressionFormat.float_quantized.value,
)

input_activations = QuantizationArgs(num_bits=4, type="float")
weights = QuantizationArgs(num_bits=4, type="float")

scheme_nvfp4 = QuantizationScheme(
targets=["Linear"],
weights=weights,
input_activations=input_activations,
format=CompressionFormat.nvfp4_pack_quantized.value,
)

model[0].quantization_scheme = scheme_fp8
model[0].quantization_status = "frozen"
model[1].quantization_scheme = scheme_nvfp4
model[1].quantization_status = "frozen"

formats = [scheme_fp8.format, scheme_nvfp4.format]

compressor = ModelCompressor.from_pretrained_model(model, None, formats)
assert isinstance(compressor.quantization_compressor, dict)
assert (
compressor.quantization_config.format == CompressionFormat.mixed_precision.value
)
assert all(format in compressor.quantization_compressor for format in formats)


@pytest.mark.parametrize(
"model_stub,comp_stub",
[
6 changes: 5 additions & 1 deletion tests/test_quantization/test_quant_scheme.py
@@ -26,24 +26,27 @@ def test_basic_scheme():
assert scheme.weights == weights
assert scheme.input_activations is None
assert scheme.output_activations is None
assert scheme.format is None


def test_full_scheme():
targets = ["Linear"]
weights = QuantizationArgs()
input_activations = QuantizationArgs(num_bits=4)
input_activations = QuantizationArgs(num_bits=8)
output_activations = QuantizationArgs(num_bits=8, type="float", symmetric=False)

scheme = QuantizationScheme(
targets=targets,
weights=weights,
input_activations=input_activations,
output_activations=output_activations,
format="float-quantized",
)
assert scheme.targets == targets
assert scheme.weights == weights
assert scheme.input_activations == input_activations
assert scheme.output_activations == output_activations
assert scheme.format == "float-quantized"


def test_needs_targets():
@@ -57,3 +60,4 @@ def test_defaults():
assert output.weights is None
assert output.input_activations is None
assert output.output_activations is None
assert output.format is None