Merged
Commits (40 total; changes shown from 33 commits)
df873fb
squashed/rebased
brian-dellabetta Aug 21, 2025
2a58648
cleanup
brian-dellabetta Aug 21, 2025
7cdd1cd
remove TODO
brian-dellabetta Aug 21, 2025
78d274d
more clenaup
brian-dellabetta Aug 21, 2025
606f177
cleanup
brian-dellabetta Aug 21, 2025
829b7cb
formatting
brian-dellabetta Aug 21, 2025
7f2c5de
formatting
brian-dellabetta Aug 21, 2025
712a731
merge main
brian-dellabetta Aug 28, 2025
b515c1b
resolve redundant merge code
brian-dellabetta Aug 28, 2025
0e11f93
style fixes
brian-dellabetta Aug 28, 2025
80844ea
cleanup / test fixes
brian-dellabetta Sep 3, 2025
cbac9f2
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 4, 2025
f62b70c
test fixes
brian-dellabetta Sep 4, 2025
ac1ce1c
formatting/touchups
brian-dellabetta Sep 9, 2025
0da8730
stylefix
brian-dellabetta Sep 9, 2025
5744d73
stylefixes
brian-dellabetta Sep 9, 2025
304615e
stylefixes
brian-dellabetta Sep 9, 2025
76f81b9
remaining test fixes
brian-dellabetta Sep 9, 2025
5bf957f
revert extraneous test change
brian-dellabetta Sep 9, 2025
360a9fb
remove test running code
brian-dellabetta Sep 9, 2025
cc27568
remove infer_quantization_status
brian-dellabetta Sep 9, 2025
fc2e102
lifecycle updates for overwriting config
brian-dellabetta Sep 10, 2025
14a359f
remove compress_quantized_weight, test fixes, remove sparseml references
brian-dellabetta Sep 10, 2025
dd87f23
merge main
brian-dellabetta Sep 10, 2025
595228c
drop frozen scale_dtype post-merge
brian-dellabetta Sep 10, 2025
49e4e92
formatting
brian-dellabetta Sep 10, 2025
fb32778
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 11, 2025
6ba47e5
clear previously initialized qparams
brian-dellabetta Sep 11, 2025
72e5f3d
remove apply_quantization_status
brian-dellabetta Sep 11, 2025
88b8865
stylefix
brian-dellabetta Sep 11, 2025
fb6aa9a
add ALL_QPARAM_KEYS var
brian-dellabetta Sep 11, 2025
d2903a1
multi-apply quantization config test
brian-dellabetta Sep 15, 2025
5776c86
multi-apply test cleanup
brian-dellabetta Sep 15, 2025
5e5ffb5
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 17, 2025
98a97e5
ALL_QPARAM_NAMES
brian-dellabetta Sep 17, 2025
02d5e78
stylefixes
brian-dellabetta Sep 17, 2025
7d8c5a4
exclude sparsity param names
brian-dellabetta Sep 17, 2025
01af659
QuantizationMetadata class
brian-dellabetta Sep 18, 2025
3fdd125
stylefix
brian-dellabetta Sep 18, 2025
b789adf
llm-compressor test fix
brian-dellabetta Sep 18, 2025
2 changes: 1 addition & 1 deletion examples/quantize_and_pack_int4.ipynb
@@ -144,7 +144,7 @@
"outputs": [],
"source": [
"quantization_config_dict = {\n",
"\t\"quant_method\": \"sparseml\",\n",
"\t\"quant_method\": \"compressed-tensors\",\n",
"\t\"format\": \"pack-quantized\",\n",
"\t\"global_compression_ratio\": None,\n",
"\t\"config_groups\": {\n",
src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -17,7 +17,6 @@
import operator
import os
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union

@@ -50,6 +49,7 @@
get_offloaded_device,
get_safetensors_folder,
has_offloaded_params,
patch_attr,
register_offload_parameter,
update_parameter_data,
)
@@ -200,9 +200,11 @@ def from_pretrained_model(
sparsity_config=sparsity_config,
quantization_config=quantization_config,
transform_config=transform_config,
compression_formats=[quantization_format]
if isinstance(quantization_format, str)
else quantization_format,
compression_formats=(
[quantization_format]
if isinstance(quantization_format, str)
else quantization_format
),
)

@staticmethod
@@ -594,8 +596,10 @@ def decompress(self, model_path: str, model: Module):
# that the dtypes of the weights are not unintentionally updated.
# The status is restored after quantization params are loaded.

with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
with patch_attr(
self.quantization_config,
"quantization_status",
QuantizationStatus.FROZEN,
):
apply_quantization_config(model, self.quantization_config)
names_to_scheme: Set[QuantizationScheme] = {
@@ -787,23 +791,3 @@ def new_dtype_byte_size(dtype):
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8


@contextmanager
def override_quantization_status(
config: QuantizationConfig, status: QuantizationStatus
):
"""
Within this context, the quantization status will be set to the
supplied status. After the context exits, the original status
will be restored.

:param config: the quantization config to override
:param status: the status to temporarily set
"""
original_status = config.quantization_status
config.quantization_status = status
try:
yield
finally:
config.quantization_status = original_status
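The removed override_quantization_status helper above is superseded by the generic patch_attr utility now imported from compressed_tensors.utils. A minimal sketch of what such an attribute-patching context manager looks like, assuming the usual save/patch/restore pattern (the in-repo implementation may differ, e.g. in how it handles attributes that did not previously exist):

```python
from contextlib import contextmanager
from typing import Any


@contextmanager
def patch_attr(obj: Any, attr_name: str, value: Any):
    """Temporarily set `obj.attr_name` to `value`, restoring the original on exit."""
    had_attr = hasattr(obj, attr_name)
    original = getattr(obj, attr_name, None)
    setattr(obj, attr_name, value)
    try:
        yield
    finally:
        if had_attr:
            # restore whatever was there before entering the context
            setattr(obj, attr_name, original)
        else:
            # the attribute was created by the patch, so remove it again
            delattr(obj, attr_name)
```

As used in the decompress hunk above, `with patch_attr(self.quantization_config, "quantization_status", QuantizationStatus.FROZEN): ...` pins the status to FROZEN only for the duration of apply_quantization_config.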
71 changes: 21 additions & 50 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -21,9 +21,6 @@

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.lifecycle.compressed import (
compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
@@ -35,7 +32,6 @@
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
KV_CACHE_TARGETS,
infer_quantization_status,
is_kv_cache_quant_scheme,
)
from compressed_tensors.utils.helpers import deprecated, replace_module
@@ -49,7 +45,6 @@
__all__ = [
"load_pretrained_quantization_parameters",
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
]

@@ -154,20 +149,27 @@ def apply_quantization_config(

# replace with run compressed if applicable
# FUTURE: move this to model compressor
if isinstance(submodule, torch.nn.Linear) and run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=format,
)
replace_module(model, name, compressed_linear)

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
if (
run_compressed
and isinstance(submodule, torch.nn.Linear)
and config.format != CompressionFormat.dense.value
):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=config.format,
)
replace_module(model, name, compressed_linear)

else:
initialize_module_for_quantization(
submodule,
force_zero_point=config.quantization_status
!= QuantizationStatus.COMPRESSED,
)

submodule.quantization_status = config.quantization_status


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
@@ -206,29 +208,6 @@ def process_kv_cache_config(
return config


def apply_quantization_status(model: Module, status: QuantizationStatus):
"""
Applies in place the quantization lifecycle up to the given status

:param model: model to apply quantization to
:param status: status to update the module to
"""

current_status = infer_quantization_status(model)

if status >= QuantizationStatus.INITIALIZED > current_status:
force_zero_point_init = status != QuantizationStatus.COMPRESSED

model.apply(
lambda module: initialize_module_for_quantization(
module, force_zero_point=force_zero_point_init
)
)

if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
model.apply(compress_quantized_weights)


@deprecated(
message="This function is deprecated and will be removed in a future release."
"Please use `match_targets` from `compressed_tensors.utils.match` instead."
@@ -254,14 +233,6 @@ def find_name_or_class_matches(
return match_targets(name, module, targets)


def _infer_status(model: Module) -> Optional[QuantizationStatus]:
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def _load_quant_args_from_mapping(
base_name: str, module_name: str, module: Module, mapping: Dict
):
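With apply_quantization_status removed, apply_quantization_config is now the single entry point: it attaches schemes, optionally swaps in CompressedLinear when running compressed, initializes quantization parameters, and stamps quantization_status on each targeted submodule in one pass. A minimal usage sketch under assumed defaults (the toy model and the W8A8 preset-by-name config group are illustrative, not part of this PR):

```python
import torch
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationStatus,
    apply_quantization_config,
)

# Toy model with a single Linear layer targeted by the config below
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# "W8A8" is a predefined scheme name mapped to its target module types,
# as described in the QuantizationConfig docstring
config = QuantizationConfig(
    config_groups={"W8A8": ["Linear"]},
    quantization_status=QuantizationStatus.INITIALIZED,
)

# One call replaces the old apply_quantization_config + apply_quantization_status
# pair: qparams are initialized and the status is recorded per submodule
apply_quantization_config(model, config)
assert model[0].quantization_status == QuantizationStatus.INITIALIZED
```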
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/compressed.py
@@ -42,7 +42,8 @@ def compress_quantized_weights(module: Module):
# no quantization scheme or weights not quantized, nothing to do
return

if scheme is QuantizationStatus.COMPRESSED:
status = getattr(module, "quantization_status", None)
if status is QuantizationStatus.COMPRESSED:
# module is already compressed, nothing to do
return

36 changes: 33 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -33,6 +33,7 @@
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
from compressed_tensors.utils import (
delete_offload_parameter,
disable_hf_hook,
get_execution_device,
register_offload_parameter,
@@ -44,6 +45,7 @@
"initialize_module_for_quantization",
"is_attention_module",
"KVCacheScaleType",
"ALL_QPARAM_KEYS",
]


@@ -55,16 +57,29 @@ class KVCacheScaleType(Enum):
VALUE = "v_scale"


ALL_QPARAM_KEYS = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
f"{base_name}_{suffix}"
for base_name in ("input", "weight", "output")
for suffix in (
"global_scale",
"scale",
"zero_point",
"g_idx",
)
]


def initialize_module_for_quantization(
module: Module,
scheme: Optional[QuantizationScheme] = None,
force_zero_point: bool = True,
):
"""
attaches appropriate scales, zero points, and observers to a layer
given its target quantization scheme
Attaches appropriate scales, zero points, and observers to a layer
given its target quantization scheme.

apply to full model with `model.apply(initialize_module_for_quantization)`
Previously initialized scales and zero points will be removed from
module if they no longer apply to the scheme

:param module: module to set for calibration
:param scheme: scheme to use for quantization. if None is provided,
@@ -79,6 +94,8 @@
# no scheme passed and layer not targeted for quantization - skip
return

_clear_all_qparams(module)

if is_attention_module(module):
# quantized actions based on calltime status
_initialize_attn_scales(module)
@@ -134,6 +151,19 @@ def is_attention_module(module: Module):
)


def _clear_all_qparams(
module: Module,
):
"""
Clear all previously registered quantization parameters from module

:param module: module to clear qparams from
"""
for key in ALL_QPARAM_KEYS:
if hasattr(module, key):
delete_offload_parameter(module, key)


def _initialize_scale_zero_point(
module: Module,
base_name: str,
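Because initialize_module_for_quantization now calls _clear_all_qparams before registering new parameters, re-applying a different quantization config no longer leaves stale scales or zero points behind. An illustrative sketch of that behavior (the scheme choices here are hypothetical, picked only to show a parameter being dropped on re-initialization):

```python
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)

layer = torch.nn.Linear(16, 16)

# First pass: an asymmetric weight scheme registers both weight_scale and
# weight_zero_point on the layer
asym = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=False),
)
initialize_module_for_quantization(layer, scheme=asym)
assert hasattr(layer, "weight_zero_point")

# Second pass: a symmetric scheme with force_zero_point=False needs no zero
# point; clearing ALL_QPARAM_KEYS first removes the stale one from pass 1
sym = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)
initialize_module_for_quantization(layer, scheme=sym, force_zero_point=False)
assert not hasattr(layer, "weight_zero_point")
```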
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/quant_config.py
@@ -113,8 +113,8 @@ class QuantizationConfig(BaseModel):
:param config_groups: dict of QuantizationSchemes specifying the quantization
settings for each quantized layer. A group could also be a reference to
a predefined scheme name, mapped to a list of its target layers/classes
:param quant_method: a constant used to differentiate sparseML quantization from
other quantization configs
:param quant_method: a constant used to differentiate compressed-tensors
quantization from other quantization configs
:param format: specifies how the quantized model is stored on disk
:quantization_status: specifies the current status of all quantized layers. It is
assumed all layers are in the same state.
16 changes: 0 additions & 16 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -32,7 +32,6 @@


__all__ = [
"infer_quantization_status",
"is_module_quantized",
"is_model_quantized",
"module_type",
@@ -234,21 +233,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
return q_min, q_max


def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa
"""
Checks the quantization status of a model. Assumes all modules in the model have
the same status, so only the first quantized model is checked.

:param model: model to check quantization status for
:return: quantization status if the model is quantized, otherwise None
"""
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def is_module_quantized(module: Module) -> bool:
"""
Check if a module is quantized, based on the existence of a non-empty quantization
3 changes: 0 additions & 3 deletions src/compressed_tensors/utils/helpers.py
@@ -74,9 +74,6 @@ def infer_compressor_from_model_config(
return compressor


# TODO: There is already the same function in
# SparseML, should be moved to a shared location
# in the future
def fix_fsdp_module_name(name: str) -> str:
"""
Remove FSDP wrapper prefixes from a module name
4 changes: 2 additions & 2 deletions src/compressed_tensors/utils/match.py
@@ -136,12 +136,12 @@ def match_targets(
if isinstance(module, InternalModule):
return []

# The order of the output `matches` list matters, the are arranged from most
# The order of the output `matches` list matters, they are arranged from most
# specific to least specific, and this order will be used when merging configs.
# The entries are sorted in the following order:
# 1. matches on exact strings
# 2. matches on regex patterns
# 3. matches on module names
# 3. matches on module names (e.g. "Linear")

targets = sorted(targets, key=lambda x: ("re:" in x, x))
matched_targets = []
21 changes: 10 additions & 11 deletions tests/test_compressors/model_compressors/test_model_compressor.py
@@ -118,16 +118,14 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None):
self.linear = nn.Linear(in_features, out_features, bias=False)

# Set the weights of the linear layer
self.linear.weight = nn.Parameter(weights, requires_grad=False)
self.linear.weight = nn.Parameter(weights.detach().clone())

# Attach weight_scale and weight_zero_point as parameters
if weight_scale is not None:
self.linear.weight_scale = nn.Parameter(
torch.tensor(weight_scale), requires_grad=False
)
self.linear.weight_scale = nn.Parameter(weight_scale.detach().clone())
if weight_zero_point is not None:
self.linear.weight_zero_point = nn.Parameter(
torch.tensor(weight_zero_point), requires_grad=False
weight_zero_point.detach().clone()
)

def forward(self, x):
@@ -388,9 +386,7 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
)
def test_compress_model_meta(model_stub, q_format, s_config):
# Load model on CPU to get expected compressed state_dict
cpu_model = AutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.float32
)
cpu_model = AutoModelForCausalLM.from_pretrained(model_stub)
reference_compressor = ModelCompressor.from_pretrained_model(
cpu_model, s_config, [q_format]
)
@@ -400,7 +396,6 @@ def test_compress_model_meta(model_stub, q_format, s_config):
# Load model on meta device
meta_model = AutoModelForCausalLM.from_pretrained(
model_stub,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
)
for module in meta_model.modules():
@@ -511,8 +506,12 @@ def test_decompress_model(model_stub, comp_stub):
# equivalent to decompressing from disk
assert decompressed.keys() == true_decompressed.keys()
for key in decompressed.keys():
assert decompressed[key].dtype == true_decompressed[key].dtype
assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
assert (
decompressed[key].dtype == true_decompressed[key].dtype
), f"{key} dtypes not equal"
assert torch.all(
decompressed[key] == true_decompressed[key]
), f"{key} values not equal"


def remove_empty_weight_zero_points(state_dict):