Merged

40 commits
df873fb
squashed/rebased
brian-dellabetta Aug 21, 2025
2a58648
cleanup
brian-dellabetta Aug 21, 2025
7cdd1cd
remove TODO
brian-dellabetta Aug 21, 2025
78d274d
more clenaup
brian-dellabetta Aug 21, 2025
606f177
cleanup
brian-dellabetta Aug 21, 2025
829b7cb
formatting
brian-dellabetta Aug 21, 2025
7f2c5de
formatting
brian-dellabetta Aug 21, 2025
712a731
merge main
brian-dellabetta Aug 28, 2025
b515c1b
resolve redundant merge code
brian-dellabetta Aug 28, 2025
0e11f93
style fixes
brian-dellabetta Aug 28, 2025
80844ea
cleanup / test fixes
brian-dellabetta Sep 3, 2025
cbac9f2
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 4, 2025
f62b70c
test fixes
brian-dellabetta Sep 4, 2025
ac1ce1c
formatting/touchups
brian-dellabetta Sep 9, 2025
0da8730
stylefix
brian-dellabetta Sep 9, 2025
5744d73
stylefixes
brian-dellabetta Sep 9, 2025
304615e
stylefixes
brian-dellabetta Sep 9, 2025
76f81b9
remaining test fixes
brian-dellabetta Sep 9, 2025
5bf957f
revert extraneous test change
brian-dellabetta Sep 9, 2025
360a9fb
remove test running code
brian-dellabetta Sep 9, 2025
cc27568
remove infer_quantization_status
brian-dellabetta Sep 9, 2025
fc2e102
lifecycle updates for overwriting config
brian-dellabetta Sep 10, 2025
14a359f
remove compress_quantized_weight, test fixes, remove sparseml references
brian-dellabetta Sep 10, 2025
dd87f23
merge main
brian-dellabetta Sep 10, 2025
595228c
drop frozen scale_dtype post-merge
brian-dellabetta Sep 10, 2025
49e4e92
formatting
brian-dellabetta Sep 10, 2025
fb32778
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 11, 2025
6ba47e5
clear previously initialized qparams
brian-dellabetta Sep 11, 2025
72e5f3d
remove apply_quantization_status
brian-dellabetta Sep 11, 2025
88b8865
stylefix
brian-dellabetta Sep 11, 2025
fb6aa9a
add ALL_QPARAM_KEYS var
brian-dellabetta Sep 11, 2025
d2903a1
multi-apply quantization config test
brian-dellabetta Sep 15, 2025
5776c86
multi-apply test cleanup
brian-dellabetta Sep 15, 2025
5e5ffb5
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 17, 2025
98a97e5
ALL_QPARAM_NAMES
brian-dellabetta Sep 17, 2025
02d5e78
stylefixes
brian-dellabetta Sep 17, 2025
7d8c5a4
exclude sparsity param names
brian-dellabetta Sep 17, 2025
01af659
QuantizationMetadata class
brian-dellabetta Sep 18, 2025
3fdd125
stylefix
brian-dellabetta Sep 18, 2025
b789adf
llm-compressor test fix
brian-dellabetta Sep 18, 2025
@@ -17,7 +17,6 @@
import operator
import os
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union

@@ -51,6 +50,7 @@
get_safetensors_folder,
has_offloaded_params,
merge_names,
patch_attr,
register_offload_parameter,
update_parameter_data,
)
@@ -201,9 +201,11 @@ def from_pretrained_model(
sparsity_config=sparsity_config,
quantization_config=quantization_config,
transform_config=transform_config,
compression_formats=[quantization_format]
if isinstance(quantization_format, str)
else quantization_format,
compression_formats=(
[quantization_format]
if isinstance(quantization_format, str)
else quantization_format
),
)

@staticmethod
@@ -700,8 +702,10 @@ def decompress(self, model_path: str, model: Module):
# that the dtypes of the weights are not unintentionally updated.
# The status is restored after quantization params are loaded.

with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
with patch_attr(
self.quantization_config,
"quantization_status",
QuantizationStatus.FROZEN,
):
apply_quantization_config(model, self.quantization_config)
names_to_scheme: Set[QuantizationScheme] = {
@@ -889,23 +893,3 @@ def new_dtype_byte_size(dtype):
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8


@contextmanager
def override_quantization_status(
config: QuantizationConfig, status: QuantizationStatus
):
"""
Within this context, the quantization status will be set to the
supplied status. After the context exits, the original status
will be restored.

:param config: the quantization config to override
:param status: the status to temporarily set
"""
original_status = config.quantization_status
config.quantization_status = status
try:
yield
finally:
config.quantization_status = original_status
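Note: the file-local `override_quantization_status` context manager removed above is superseded by the generic `patch_attr` utility imported earlier in this file. As a reference point, here is a minimal sketch of what such an attribute-patching helper can look like — the library's actual `patch_attr` may differ in details (for example, how it handles attributes that did not previously exist):

```python
from contextlib import contextmanager


@contextmanager
def patch_attr(obj, attr_name: str, new_value):
    """Temporarily set `obj.<attr_name>` to `new_value`, restoring the
    original state when the context exits."""
    had_attr = hasattr(obj, attr_name)
    original = getattr(obj, attr_name, None)
    setattr(obj, attr_name, new_value)
    try:
        yield
    finally:
        if had_attr:
            setattr(obj, attr_name, original)
        else:
            delattr(obj, attr_name)


# Usage mirroring the decompress() change above:
# with patch_attr(config, "quantization_status", QuantizationStatus.FROZEN):
#     apply_quantization_config(model, config)
```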
36 changes: 17 additions & 19 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -35,7 +35,6 @@
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
KV_CACHE_TARGETS,
infer_quantization_status,
is_kv_cache_quant_scheme,
)
from compressed_tensors.utils.helpers import deprecated, replace_module
@@ -154,20 +153,21 @@ def apply_quantization_config(

# replace with run compressed if applicable
# FUTURE: move this to model compressor
if isinstance(submodule, torch.nn.Linear) and run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=format,
)
replace_module(model, name, compressed_linear)

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
if (
run_compressed
and isinstance(submodule, torch.nn.Linear)
and config.format != CompressionFormat.dense.value
):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=config.format,
)
replace_module(model, name, compressed_linear)

# apply current quantization status to each targeted submodule
apply_quantization_status(submodule, config.quantization_status)


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
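For orientation, a usage sketch of the entry point touched in this hunk. The model stub and scheme fields below are illustrative assumptions, not part of this PR; the relevant behavioral change is that each matched submodule now gets the config's quantization status applied to it directly, rather than a single model-wide `apply_quantization_status` call at the end:

```python
from transformers import AutoModelForCausalLM
from compressed_tensors.quantization import (
    QuantizationConfig,
    apply_quantization_config,
)

# Hypothetical W8A8 config targeting Linear layers; all values are illustrative.
config = QuantizationConfig(
    config_groups={
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "type": "int", "symmetric": True},
            "input_activations": {"num_bits": 8, "type": "int", "symmetric": True},
        }
    },
    quantization_status="initialized",
)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Initializes scales/zero points on every targeted Linear submodule
apply_quantization_config(model, config, run_compressed=False)
```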
@@ -214,9 +214,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
:param status: status to update the module to
"""

current_status = infer_quantization_status(model)

if status >= QuantizationStatus.INITIALIZED > current_status:
if status >= QuantizationStatus.INITIALIZED:
force_zero_point_init = status != QuantizationStatus.COMPRESSED

# When decompressing, we set the scale_dtype as the model's dtype
@@ -234,7 +232,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
)
)

if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
if status >= QuantizationStatus.COMPRESSED:
model.apply(compress_quantized_weights)


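The simplified `status >= QuantizationStatus.INITIALIZED` / `>= QuantizationStatus.COMPRESSED` checks rely on the status enum being ordered by lifecycle position, so a later status implies every earlier stage has been reached. A rough sketch of that kind of ordered enum, assuming the library implements comparison roughly this way:

```python
from enum import Enum


class Status(Enum):
    # Lifecycle order matters: later members compare >= earlier ones.
    INITIALIZED = "initialized"
    CALIBRATION = "calibration"
    FROZEN = "frozen"
    COMPRESSED = "compressed"

    def __ge__(self, other: "Status") -> bool:
        members = list(type(self))
        return members.index(self) >= members.index(other)


assert Status.COMPRESSED >= Status.INITIALIZED        # compressing implies initialized
assert not (Status.INITIALIZED >= Status.COMPRESSED)  # but not the other way around
```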
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/compressed.py
@@ -42,7 +42,8 @@ def compress_quantized_weights(module: Module):
# no quantization scheme or weights not quantized, nothing to do
return

if scheme is QuantizationStatus.COMPRESSED:
status = getattr(module, "quantization_status", None)
if status is QuantizationStatus.COMPRESSED:
# module is already compressed, nothing to do
return

16 changes: 0 additions & 16 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -32,7 +32,6 @@


__all__ = [
"infer_quantization_status",
"is_module_quantized",
"is_model_quantized",
"module_type",
@@ -234,21 +233,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
return q_min, q_max


def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa
"""
Checks the quantization status of a model. Assumes all modules in the model have
the same status, so only the first quantized model is checked.

:param model: model to check quantization status for
:return: quantization status if the model is quantized, otherwise None
"""
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def is_module_quantized(module: Module) -> bool:
"""
Check if a module is quantized, based on the existence of a non-empty quantization
4 changes: 2 additions & 2 deletions src/compressed_tensors/utils/match.py
@@ -136,12 +136,12 @@ def match_targets(
if isinstance(module, InternalModule):
return []

# The order of the output `matches` list matters, the are arranged from most
# The order of the output `matches` list matters, they are arranged from most
# specific to least specific, and this order will be used when merging configs.
# The entries are sorted in the following order:
# 1. matches on exact strings
# 2. matches on regex patterns
# 3. matches on module names
# 3. matches on module names (e.g. "Linear")

targets = sorted(targets, key=lambda x: ("re:" in x, x))
matched_targets = []
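The sort key above places plain target strings ahead of `re:`-prefixed regex patterns (`False` sorts before `True`), with ties broken alphabetically. A quick illustration with made-up target names:

```python
targets = [
    "re:.*q_proj",
    "model.layers.0.self_attn.q_proj",
    "re:.*k_proj",
    "lm_head",
]

ordered = sorted(targets, key=lambda x: ("re:" in x, x))
print(ordered)
# ['lm_head', 'model.layers.0.self_attn.q_proj', 're:.*k_proj', 're:.*q_proj']
```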
21 changes: 10 additions & 11 deletions tests/test_compressors/model_compressors/test_model_compressor.py
@@ -118,16 +118,14 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None):
self.linear = nn.Linear(in_features, out_features, bias=False)

# Set the weights of the linear layer
self.linear.weight = nn.Parameter(weights, requires_grad=False)
self.linear.weight = nn.Parameter(weights.detach().clone())

# Attach weight_scale and weight_zero_point as parameters
if weight_scale is not None:
self.linear.weight_scale = nn.Parameter(
torch.tensor(weight_scale), requires_grad=False
)
self.linear.weight_scale = nn.Parameter(weight_scale.detach().clone())
if weight_zero_point is not None:
self.linear.weight_zero_point = nn.Parameter(
torch.tensor(weight_zero_point), requires_grad=False
weight_zero_point.detach().clone()
)

def forward(self, x):
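Side note on the parameter construction above: calling `torch.tensor(...)` on something that is already a tensor copies it but emits a UserWarning recommending `sourceTensor.clone().detach()` instead, which appears to be the motivation for the `detach().clone()` pattern in this test helper. A small standalone illustration:

```python
import torch
import torch.nn as nn

weight_scale = torch.tensor([0.02])

# torch.tensor(weight_scale) would warn about copy-constructing from a tensor;
# detach().clone() makes an independent copy with no autograd history, no warning.
param = nn.Parameter(weight_scale.detach().clone())
```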
@@ -443,9 +441,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
)
def test_compress_model_meta(model_stub, q_format, s_config):
# Load model on CPU to get expected compressed state_dict
cpu_model = AutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.float32
)
cpu_model = AutoModelForCausalLM.from_pretrained(model_stub)
reference_compressor = ModelCompressor.from_pretrained_model(
cpu_model, s_config, [q_format]
)
@@ -455,7 +451,6 @@ def test_compress_model_meta(model_stub, q_format, s_config):
# Load model on meta device
meta_model = AutoModelForCausalLM.from_pretrained(
model_stub,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
)
for module in meta_model.modules():
@@ -566,8 +561,12 @@ def test_decompress_model(model_stub, comp_stub):
# equivalent to decompressing from disk
assert decompressed.keys() == true_decompressed.keys()
for key in decompressed.keys():
assert decompressed[key].dtype == true_decompressed[key].dtype
assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
assert (
decompressed[key].dtype == true_decompressed[key].dtype
), f"{key} dtypes not equal"
assert torch.all(
decompressed[key] == true_decompressed[key]
), f"{key} values not equal"


def remove_empty_weight_zero_points(state_dict):
Expand Down