[Quantization] Support more than one quant-compressor #415

Open · wants to merge 11 commits into base: main
Changes from 9 commits

133 changes: 108 additions & 25 deletions src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -169,7 +169,7 @@ def from_pretrained_model(
cls,
model: Module,
sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
quantization_format: Optional[str] = None,
Contributor:

AFAICT this is the only entrypoint for this function.

Why not just adjust the upstream function infer_quantization_format to infer the mixed value, rather than supporting an extra data type (List[str]) that ideally should never actually appear?
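
For concreteness, a rough sketch of what that alternative could look like; the name infer_quantization_format matches the upstream helper being referenced, but the signature and body here are illustrative assumptions, not the actual llm-compressor implementation:

```python
from typing import Optional

from compressed_tensors.config import CompressionFormat
from torch.nn import Module


def infer_quantization_format(model: Module) -> Optional[str]:
    """Collapse per-module scheme formats into a single config-level format."""
    formats = []
    for module in model.modules():
        scheme = getattr(module, "quantization_scheme", None)
        if scheme is not None and scheme.format is not None:
            if scheme.format not in formats:
                formats.append(scheme.format)
    if not formats:
        return None
    if len(formats) > 1:
        # multiple per-module formats collapse to the single mixed tag
        return CompressionFormat.mixed_precision.value
    return formats[0]
```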

Member:

I agree with @kylesayrs on this. Also, if a list of quantization formats is passed in, do we override them to the mixed-precision format and then infer them again downstream?

dsikka (Collaborator, Author), Aug 11, 2025:

I disagree. Separation of concerns: infer_quantization_format is responsible for inferring the formats in the model, but what gets written to the config should be determined by the ModelCompressor class, which is ultimately responsible for writing the quantization config.

We don't infer again; we use the per-module format attached to each scheme to compress each module.

See the updated llmcompressor functionality: vllm-project/llm-compressor#1713
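
A minimal usage sketch of that flow (same setup as the new test_multiple_quant_compressors test added below; it assumes model already has per-module quantization_scheme attributes with format set):

```python
# model: a torch.nn.Module whose submodules each carry a quantization_scheme
# with scheme.format set (e.g. "float-quantized" on one Linear and
# "nvfp4-pack-quantized" on another), as in the new test below.
compressor = ModelCompressor.from_pretrained_model(
    model,
    sparsity_config=None,
    quantization_format=["float-quantized", "nvfp4-pack-quantized"],
)

# The config written to disk records the collapsed "mixed-precision" tag...
assert compressor.quantization_config.format == "mixed-precision"
# ...while one compressor per concrete format is kept for per-module use.
assert set(compressor.quantization_compressor) == {
    "float-quantized",
    "nvfp4-pack-quantized",
}
```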

kylesayrs (Contributor), Aug 12, 2025:

AFAICT the only reason we would need to infer the list of quantization formats used in a model is to write to the config. Since model_compressor is responsible for writing to the config, I would argue that the "infer a global quantization tag for the purposes of writing to the config" logic should exist in model_compressor.

Contributor:

If we are going to pass all available formats, why are we then re-inferring them afterwards via _fetch_unique_quantization_formats? This seems like a potential conflict in the source of truth.

Ideally, scheme.format should be the source of truth for formats.
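
One way to reconcile the two sources rather than silently preferring one, as a sketch only (assumes it runs inside __init__ after both the passed-in formats and the quantization config are available; not part of this PR):

```python
# Sketch: cross-check the explicitly passed formats against what the
# schemes in the quantization config declare, and fail loudly on mismatch.
inferred_formats = self._fetch_unique_quantization_formats()
if compression_formats is not None and set(compression_formats) != set(inferred_formats):
    raise ValueError(
        f"Passed-in compression formats {compression_formats} do not match the "
        f"formats declared on the quantization schemes: {inferred_formats}"
    )
```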

quantization_format: Optional[Union[str, List[str]]] = None,
) -> Optional["ModelCompressor"]:
"""
Given a pytorch model and optional sparsity and/or quantization configs,
@@ -182,7 +182,21 @@ def from_pretrained_model(
algorithm
:return: compressor for the configs, or None if model is not compressed
"""
# reconstruct config from schemes attached to modules

if quantization_format is not None:
# llmcompressor incorrectly passes in a CompressionFormat when
# the value string is expected - handle both cases
if isinstance(quantization_format, (str, CompressionFormat)):
quantization_format = [quantization_format]

compression_formats = quantization_format
# assume multiple compression formats means mixed-precision
# as we currently only support one compressor per precision type and scheme
if len(quantization_format) > 1:
quantization_format = CompressionFormat.mixed_precision.value
else:
quantization_format = quantization_format[0]

quantization_config = QuantizationConfig.from_pretrained(
model, format=quantization_format
)
@@ -203,6 +217,7 @@ def from_pretrained_model(
sparsity_config=sparsity_config,
quantization_config=quantization_config,
transform_config=transform_config,
compression_formats=compression_formats,
)

@staticmethod
@@ -263,30 +278,55 @@ def parse_quantization_config(

return quantization_config

def _fetch_unique_quantization_formats(self) -> List[str]:
"""
Get all unique compression formats present in a model
:return: list of quantization formats
"""
quantization_formats = []
for _, scheme in self.quantization_config.config_groups.items():
if scheme.format is not None and scheme.format not in quantization_formats:
quantization_formats.append(scheme.format)

# If empty list, fallback to using the global format
if len(quantization_formats) == 0:
quantization_formats.append(self.quantization_config.format)
kylesayrs (Contributor), Aug 13, 2025:

self.quantization_config.format is nullable afaict; please add logic and/or a typehint to account for this.
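
A minimal sketch of the kind of guard being requested (assuming the fallback should fail loudly when neither the schemes nor the global config provide a format; not part of this diff):

```python
# Sketch: only fall back to the global format when it is actually set.
if len(quantization_formats) == 0:
    if self.quantization_config.format is None:
        raise ValueError(
            "Could not determine a quantization format: no scheme defines a "
            "format and quantization_config.format is unset"
        )
    quantization_formats.append(self.quantization_config.format)
```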

return quantization_formats

def __init__(
self,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_config: Optional[QuantizationConfig] = None,
transform_config: Optional[TransformConfig] = None,
compression_formats: Optional[List[str]] = None,
):
self.sparsity_config = sparsity_config
self.quantization_config = quantization_config
self.transform_config = transform_config
self.compression_formats = compression_formats

self.sparsity_compressor = None
self.quantization_compressor: Optional[
Union[BaseQuantizationCompressor, DenseCompressor]
Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
] = None
Contributor, commenting on lines 310 to 312:

Should we rename this to self.quantization_compressors to indicate that it is now a dict? Or is there some reason we can't, because it's serialized, etc.?

# no transform compressor is required

if sparsity_config is not None:
self.sparsity_compressor = BaseCompressor.load_from_registry(
sparsity_config.format, config=sparsity_config
)

if quantization_config is not None:
self.quantization_compressor = BaseCompressor.load_from_registry(
quantization_config.format, config=quantization_config
)
if not self.compression_formats:
self.compression_formats = self._fetch_unique_quantization_formats()

self.quantization_compressor = {}
for format in self.compression_formats:
self.quantization_compressor[
format
] = BaseCompressor.load_from_registry(
format, config=quantization_config
Contributor:

Now that a compressor is local to a config_group but quantization_config is global, this no longer seems valid. Fortunately, the config is only ever referenced in Sparse24BitMaskCompressor here, so maybe this is OK for now.

)

# ----- used by hf quantizer ----- #

@@ -381,12 +421,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
targets=scheme.targets,
ignore=self.quantization_config.ignore,
)
unexpected_keys.update(
merge_names(target, param)
for target in quant_targets
for param in self.quantization_compressor.compression_param_names
if param != "weight"
)
for quant_compressor in self.quantization_compressor.values():
unexpected_keys.update(
merge_names(target, param)
for target in quant_targets
for param in quant_compressor.compression_param_names
if param != "weight"
)

return list(unexpected_keys)

@@ -424,7 +465,25 @@ def compress_model(self, model: Module):

# quantization first
if prefix in module_to_scheme:
state_dict = self.quantization_compressor.compress(
if (
not hasattr(module.quantization_scheme, "format")
or module.quantization_scheme.format is None
):
if (
self.quantization_config.format
== CompressionFormat.mixed_precision.value
):
raise ValueError(
"Compressing mixed-precision models without defining "
"per module quantization_scheme.format is currently "
"not supported"
)
format = self.quantization_config.format
else:
format = module.quantization_scheme.format

quant_compressor = self.quantization_compressor.get(format)
state_dict = quant_compressor.compress(
state_dict,
names_to_scheme=module_to_scheme,
show_progress=False,
@@ -495,12 +554,28 @@ def decompress_model(self, model: Module):

# quantization second
if prefix in module_to_scheme:
state_dict = (
self.quantization_compressor.decompress_module_from_state_dict(
prefix,
state_dict,
scheme=module_to_scheme[prefix],
)

if (
not hasattr(module.quantization_scheme, "format")
or module.quantization_scheme.format is None
):
if (
self.quantization_config.format
== CompressionFormat.mixed_precision.value
):
raise ValueError(
"Decompressing mixed-precision models without defining "
"per module quantization_scheme.format is currently not "
"supported"
)
format = self.quantization_config.format
else:
format = module.quantization_scheme.format
quant_compressor = self.quantization_compressor.get(format)
state_dict = quant_compressor.decompress_module_from_state_dict(
prefix,
state_dict,
scheme=module_to_scheme[prefix],
)

# remove any existing parameters
@@ -539,7 +614,9 @@ def compress(

if self.quantization_compressor is not None:
module_to_scheme = map_module_to_scheme(model)
state_dict = self.quantization_compressor.compress(
# Note - compress only supports one compression format atm
quant_compressor = next(iter(self.quantization_compressor.values()))
state_dict = quant_compressor.compress(
state_dict,
names_to_scheme=module_to_scheme,
show_progress=show_progress,
@@ -588,14 +665,20 @@ def decompress(self, model_path: str, model: Module):
"""
model_path = get_safetensors_folder(model_path)
sparse_decompressed = False
quant_compressor = (
next(iter(self.quantization_compressor.values()))
if self.quantization_compressor is not None
else None
)

if (
self.sparsity_compressor is not None
and self.sparsity_config.format != CompressionFormat.dense.value
):
# note - decompress only supports one compressor atm
params_to_ignore = None
if self.quantization_compressor is not None:
params_to_ignore = self.quantization_compressor.compression_param_names
if quant_compressor is not None:
params_to_ignore = quant_compressor.compression_param_names
# Sparse decompression is applied on the model_path
# The compressor will try and load any quantization parameters as well
# params_to_skip_load will skip over quantization params from being loaded
@@ -606,7 +689,7 @@
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
sparse_decompressed = True

if self.quantization_compressor is not None:
if quant_compressor is not None:
# Temporarily set quantization status to FROZEN to prevent
# quantization during apply_quantization_config. This ensures
# that the dtypes of the weights are not unintentionally updated.
@@ -629,15 +712,15 @@
# including initialization
load_weight_quantization=(
sparse_decompressed
or isinstance(self.quantization_compressor, DenseCompressor)
or isinstance(quant_compressor, DenseCompressor)
),
)

model_path_or_state_dict = (
model.state_dict() if sparse_decompressed else model_path
)

dense_gen = self.quantization_compressor.decompress(
dense_gen = quant_compressor.decompress(
model_path_or_state_dict, names_to_scheme=names_to_scheme
)
# TODO: all weight quantization params will be moved to the compressor
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
naive_quantized = "naive-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
mixed_precision = "mixed-precision"
nvfp4_pack_quantized = "nvfp4-pack-quantized"


2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_config.py
@@ -231,7 +231,7 @@ def from_pretrained(

if format is None:
if quantization_status == QuantizationStatus.COMPRESSED:
format = CompressionFormat.int_quantized.value
format = CompressionFormat.int_quantized.value # why?!
else:
format = CompressionFormat.dense.value

3 changes: 3 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -16,6 +16,7 @@
from copy import deepcopy
from typing import List, Optional

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import (
DynamicType,
QuantizationArgs,
@@ -42,12 +43,14 @@ class QuantizationScheme(BaseModel):
:param weights: quantization config for layer weights
:param input_activations: quantization config for layer inputs
:param output_activations: quantization config for layer outputs
:param format: CompressionFormat for the layer
"""

targets: List[str]
weights: Optional[QuantizationArgs] = None
input_activations: Optional[QuantizationArgs] = None
output_activations: Optional[QuantizationArgs] = None
format: Optional[str] = None

@model_validator(mode="after")
def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
51 changes: 46 additions & 5 deletions tests/test_compressors/model_compressors/test_model_compressor.py
@@ -20,8 +20,12 @@
import torch
import torch.nn as nn
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationConfig,
QuantizationScheme,
)
from safetensors.torch import save_file
from tests.testing_utils import induce_sparsity, requires_hf_quantizer
from transformers import AutoModelForCausalLM
@@ -395,7 +399,7 @@ def _get_combined_config(s_config, q_config):
)
def test_compress_model(model_stub, q_format, s_config, tmpdir):
model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format])

# compress model by eagerly compressing state dict
true_compressed = dict(compressor.compress(model))
@@ -443,7 +447,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
model_stub, torch_dtype=torch.float32
)
reference_compressor = ModelCompressor.from_pretrained_model(
cpu_model, s_config, q_format
cpu_model, s_config, [q_format]
)
# Only stores dtype because meta model does not store values
expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +463,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
module.to_empty(device="meta")

# Compress in-place on meta model
compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
compressor.compress_model(meta_model)

# Compare keys and dtypes
@@ -469,6 +473,43 @@
assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"


def test_multiple_quant_compressors():
model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
input_activations = QuantizationArgs(num_bits=8, type="float")
weights = QuantizationArgs(num_bits=8, type="float")

scheme_fp8 = QuantizationScheme(
targets=["Linear"],
weights=weights,
input_activations=input_activations,
format=CompressionFormat.float_quantized.value,
)

input_activations = QuantizationArgs(num_bits=4, type="float")
weights = QuantizationArgs(num_bits=4, type="float")

scheme_nvfp4 = QuantizationScheme(
targets=["Linear"],
weights=weights,
input_activations=input_activations,
format=CompressionFormat.nvfp4_pack_quantized.value,
)

model[0].quantization_scheme = scheme_fp8
model[0].quantization_status = "frozen"
model[1].quantization_scheme = scheme_nvfp4
model[1].quantization_status = "frozen"

formats = [scheme_fp8.format, scheme_nvfp4.format]

compressor = ModelCompressor.from_pretrained_model(model, None, formats)
assert isinstance(compressor.quantization_compressor, dict)
assert (
compressor.quantization_config.format == CompressionFormat.mixed_precision.value
)
assert all(format in compressor.quantization_compressor for format in formats)


@pytest.mark.parametrize(
"model_stub,comp_stub",
[