Merged
40 commits
df873fb
squashed/rebased
brian-dellabetta Aug 21, 2025
2a58648
cleanup
brian-dellabetta Aug 21, 2025
7cdd1cd
remove TODO
brian-dellabetta Aug 21, 2025
78d274d
more clenaup
brian-dellabetta Aug 21, 2025
606f177
cleanup
brian-dellabetta Aug 21, 2025
829b7cb
formatting
brian-dellabetta Aug 21, 2025
7f2c5de
formatting
brian-dellabetta Aug 21, 2025
712a731
merge main
brian-dellabetta Aug 28, 2025
b515c1b
resolve redundant merge code
brian-dellabetta Aug 28, 2025
0e11f93
style fixes
brian-dellabetta Aug 28, 2025
80844ea
cleanup / test fixes
brian-dellabetta Sep 3, 2025
cbac9f2
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 4, 2025
f62b70c
test fixes
brian-dellabetta Sep 4, 2025
ac1ce1c
formatting/touchups
brian-dellabetta Sep 9, 2025
0da8730
stylefix
brian-dellabetta Sep 9, 2025
5744d73
stylefixes
brian-dellabetta Sep 9, 2025
304615e
stylefixes
brian-dellabetta Sep 9, 2025
76f81b9
remaining test fixes
brian-dellabetta Sep 9, 2025
5bf957f
revert extraneous test change
brian-dellabetta Sep 9, 2025
360a9fb
remove test running code
brian-dellabetta Sep 9, 2025
cc27568
remove infer_quantization_status
brian-dellabetta Sep 9, 2025
fc2e102
lifecycle updates for overwriting config
brian-dellabetta Sep 10, 2025
14a359f
remove compress_quantized_weight, test fixes, remove sparseml references
brian-dellabetta Sep 10, 2025
dd87f23
merge main
brian-dellabetta Sep 10, 2025
595228c
drop frozen scale_dtype post-merge
brian-dellabetta Sep 10, 2025
49e4e92
formatting
brian-dellabetta Sep 10, 2025
fb32778
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 11, 2025
6ba47e5
clear previously initialized qparams
brian-dellabetta Sep 11, 2025
72e5f3d
remove apply_quantization_status
brian-dellabetta Sep 11, 2025
88b8865
stylefix
brian-dellabetta Sep 11, 2025
fb6aa9a
add ALL_QPARAM_KEYS var
brian-dellabetta Sep 11, 2025
d2903a1
multi-apply quantization config test
brian-dellabetta Sep 15, 2025
5776c86
multi-apply test cleanup
brian-dellabetta Sep 15, 2025
5e5ffb5
Merge branch 'main' into bdellabe/scoped-quant-status
brian-dellabetta Sep 17, 2025
98a97e5
ALL_QPARAM_NAMES
brian-dellabetta Sep 17, 2025
02d5e78
stylefixes
brian-dellabetta Sep 17, 2025
7d8c5a4
exclude sparsity param names
brian-dellabetta Sep 17, 2025
01af659
QuantizationMetadata class
brian-dellabetta Sep 18, 2025
3fdd125
stylefix
brian-dellabetta Sep 18, 2025
b789adf
llm-compressor test fix
brian-dellabetta Sep 18, 2025
2 changes: 1 addition & 1 deletion examples/quantize_and_pack_int4.ipynb
@@ -144,7 +144,7 @@
"outputs": [],
"source": [
"quantization_config_dict = {\n",
"\t\"quant_method\": \"sparseml\",\n",
"\t\"quant_method\": \"compressed-tensors\",\n",
"\t\"format\": \"pack-quantized\",\n",
"\t\"global_compression_ratio\": None,\n",
"\t\"config_groups\": {\n",
@@ -17,7 +17,6 @@
import operator
import os
import re
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union

@@ -50,6 +49,7 @@
get_offloaded_device,
get_safetensors_folder,
has_offloaded_params,
patch_attr,
register_offload_parameter,
update_parameter_data,
)
@@ -200,9 +200,11 @@ def from_pretrained_model(
sparsity_config=sparsity_config,
quantization_config=quantization_config,
transform_config=transform_config,
compression_formats=[quantization_format]
if isinstance(quantization_format, str)
else quantization_format,
compression_formats=(
[quantization_format]
if isinstance(quantization_format, str)
else quantization_format
),
)

@staticmethod
@@ -594,8 +596,10 @@ def decompress(self, model_path: str, model: Module):
# that the dtypes of the weights are not unintentionally updated.
# The status is restored after quantization params are loaded.

with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
with patch_attr(
self.quantization_config,
"quantization_status",
QuantizationStatus.FROZEN,
):
apply_quantization_config(model, self.quantization_config)
names_to_scheme: Set[QuantizationScheme] = {
@@ -787,23 +791,3 @@ def new_dtype_byte_size(dtype):
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8


@contextmanager
def override_quantization_status(
config: QuantizationConfig, status: QuantizationStatus
):
"""
Within this context, the quantization status will be set to the
supplied status. After the context exits, the original status
will be restored.

:param config: the quantization config to override
:param status: the status to temporarily set
"""
original_status = config.quantization_status
config.quantization_status = status
try:
yield
finally:
config.quantization_status = original_status
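Note: the removed `override_quantization_status` helper was a single-purpose context manager; `patch_attr` from `compressed_tensors.utils` (imported above) covers the same need generically. A minimal sketch of the pattern, assuming `patch_attr` behaves as a temporary attribute override; the function below is illustrative, not the library implementation:

```python
from contextlib import contextmanager


@contextmanager
def patch_attr_sketch(obj, attr_name, new_value):
    """Temporarily set obj.<attr_name> to new_value, restoring the original on exit.

    Illustrative stand-in for compressed_tensors.utils.patch_attr.
    """
    original = getattr(obj, attr_name)
    setattr(obj, attr_name, new_value)
    try:
        yield
    finally:
        setattr(obj, attr_name, original)


# Usage mirroring the new decompress() code path:
# with patch_attr(self.quantization_config, "quantization_status", QuantizationStatus.FROZEN):
#     apply_quantization_config(model, self.quantization_config)
```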
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/__init__.py
@@ -17,5 +17,6 @@

from .quant_args import *
from .quant_config import *
from .quant_metadata import *
from .quant_scheme import *
from .lifecycle import *
71 changes: 21 additions & 50 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -21,9 +21,6 @@

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.lifecycle.compressed import (
compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
@@ -35,7 +32,6 @@
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
KV_CACHE_TARGETS,
infer_quantization_status,
is_kv_cache_quant_scheme,
)
from compressed_tensors.utils.helpers import deprecated, replace_module
@@ -49,7 +45,6 @@
__all__ = [
"load_pretrained_quantization_parameters",
"apply_quantization_config",
"apply_quantization_status",
"find_name_or_class_matches",
]

@@ -154,20 +149,27 @@ def apply_quantization_config(

# replace with run compressed if applicable
# FUTURE: move this to model compressor
if isinstance(submodule, torch.nn.Linear) and run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=format,
)
replace_module(model, name, compressed_linear)

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
if (
run_compressed
and isinstance(submodule, torch.nn.Linear)
and config.format != CompressionFormat.dense.value
):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=config.format,
)
replace_module(model, name, compressed_linear)

else:
initialize_module_for_quantization(
submodule,
force_zero_point=config.quantization_status
!= QuantizationStatus.COMPRESSED,
)

submodule.quantization_status = config.quantization_status


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
@@ -206,29 +208,6 @@ def process_kv_cache_config(
return config


def apply_quantization_status(model: Module, status: QuantizationStatus):
"""
Applies in place the quantization lifecycle up to the given status

:param model: model to apply quantization to
:param status: status to update the module to
"""

current_status = infer_quantization_status(model)

if status >= QuantizationStatus.INITIALIZED > current_status:
force_zero_point_init = status != QuantizationStatus.COMPRESSED

model.apply(
lambda module: initialize_module_for_quantization(
module, force_zero_point=force_zero_point_init
)
)

if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
model.apply(compress_quantized_weights)


@deprecated(
message="This function is deprecated and will be removed in a future release."
"Please use `match_targets` from `compressed_tensors.utils.match` instead."
@@ -254,14 +233,6 @@ def find_name_or_class_matches(
return match_targets(name, module, targets)


def _infer_status(model: Module) -> Optional[QuantizationStatus]:
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def _load_quant_args_from_mapping(
base_name: str, module_name: str, module: Module, mapping: Dict
):
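With `apply_quantization_status` removed, `apply_quantization_config` is the single entry point: it either swaps a targeted `torch.nn.Linear` for a `CompressedLinear` (when `run_compressed` is set and the format is not dense) or initializes quantization parameters in place, then stamps `module.quantization_status`. A minimal usage sketch; the config contents below are illustrative, not taken from this PR:

```python
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Illustrative config: one group quantizing Linear weights to 8 bits.
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)

apply_quantization_config(model, config)

# Each targeted module should now carry its scheme, qparams, and status.
linear = model[0]
print(hasattr(linear, "weight_scale"), linear.quantization_status)
```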
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/compressed.py
@@ -42,7 +42,8 @@ def compress_quantized_weights(module: Module):
# no quantization scheme or weights not quantized, nothing to do
return

if scheme is QuantizationStatus.COMPRESSED:
status = getattr(module, "quantization_status", None)
if status is QuantizationStatus.COMPRESSED:
# module is already compressed, nothing to do
return

30 changes: 14 additions & 16 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -16,21 +16,22 @@
import logging
import math
import warnings
from enum import Enum
from typing import Optional

import torch
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.quant_args import (
from compressed_tensors.quantization import (
FP8_E4M3_DATA,
ActivationOrdering,
KVCacheScaleType,
QuantizationArgs,
QuantizationMetadata,
QuantizationScheme,
QuantizationStatus,
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
from compressed_tensors.utils import (
disable_hf_hook,
@@ -43,28 +44,23 @@
__all__ = [
"initialize_module_for_quantization",
"is_attention_module",
"KVCacheScaleType",
]


_LOGGER = logging.getLogger(__name__)


class KVCacheScaleType(Enum):
KEY = "k_scale"
VALUE = "v_scale"


def initialize_module_for_quantization(
module: Module,
scheme: Optional[QuantizationScheme] = None,
force_zero_point: bool = True,
):
"""
attaches appropriate scales, zero points, and observers to a layer
given its target quantization scheme
Attaches appropriate scales, zero points, and observers to a layer
given its target quantization scheme.

apply to full model with `model.apply(initialize_module_for_quantization)`
Previously initialized scales and zero points will be removed from
module if they no longer apply to the scheme

:param module: module to set for calibration
:param scheme: scheme to use for quantization. if None is provided,
@@ -79,6 +75,8 @@ def initialize_module_for_quantization(
# no scheme passed and layer not targeted for quantization - skip
return

QuantizationMetadata.clear_all_qparams(module)

if is_attention_module(module):
# quantized actions based on calltime status
_initialize_attn_scales(module)
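Clearing previously initialized qparams at the top of `initialize_module_for_quantization` (via `QuantizationMetadata.clear_all_qparams`) is what makes re-applying a different config safe: scales or zero points registered for an earlier scheme no longer linger on the module. A sketch of that flow with illustrative schemes, assuming symmetric weight quantization skips zero-point registration when `force_zero_point=False`:

```python
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    initialize_module_for_quantization,
)

layer = torch.nn.Linear(16, 16)

# First pass: asymmetric weights -> weight_scale and weight_zero_point registered.
initialize_module_for_quantization(
    layer,
    scheme=QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8, symmetric=False),
    ),
)
print(hasattr(layer, "weight_zero_point"))  # expected: True

# Second pass: symmetric weights. Previously initialized qparams are cleared
# before re-registration, so the stale zero point should not survive.
initialize_module_for_quantization(
    layer,
    scheme=QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8, symmetric=True),
    ),
    force_zero_point=False,
)
print(hasattr(layer, "weight_zero_point"))  # expected: False
```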
7 changes: 4 additions & 3 deletions src/compressed_tensors/quantization/quant_config.py
@@ -113,8 +113,8 @@ class QuantizationConfig(BaseModel):
:param config_groups: dict of QuantizationSchemes specifying the quantization
settings for each quantized layer. A group could also be a reference to
a predefined scheme name, mapped to a list of its target layers/classes
:param quant_method: a constant used to differentiate sparseML quantization from
other quantization configs
:param quant_method: a constant used to differentiate compressed-tensors
quantization from other quantization configs
:param format: specifies how the quantized model is stored on disk
:quantization_status: specifies the current status of all quantized layers. It is
assumed all layers are in the same state.
@@ -185,7 +185,8 @@ def from_pretrained(
ignore[layer_type] = []
ignore[layer_type].append(name)
else:
quantization_status = submodule.quantization_status
if hasattr(submodule, "quantization_status"):
quantization_status = submodule.quantization_status
scheme = submodule.quantization_scheme
quantization_type_names.add(layer_type)

62 changes: 62 additions & 0 deletions src/compressed_tensors/quantization/quant_metadata.py
@@ -0,0 +1,62 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum

from compressed_tensors.utils import delete_offload_parameter
from torch.nn import Module


__all__ = ["QuantizationMetadata", "KVCacheScaleType"]


class KVCacheScaleType(Enum):
KEY = "k_scale"
VALUE = "v_scale"


class QuantizationMetadata:
"""
Container class for metadata related to quantization
"""

@staticmethod
def all_qparam_names():
"""
All quantization parameter names that might be registered
onto a module during lifecycle (excluding serialized parameters)
"""
return [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
f"{base_name}_{suffix}"
for base_name in ("input", "weight", "output")
for suffix in (
"global_scale",
"scale",
"zero_point",
"g_idx",
)
]

@classmethod
def clear_all_qparams(cls, module: Module):
"""
Remove all parameters related to quantization that might have
been registered onto a module previously in lifecycle (excluding
serialized parameters)

:param module: Module to clear
"""
for key in cls.all_qparam_names():
if hasattr(module, key):
delete_offload_parameter(module, key)
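A quick usage sketch for the new helper; the parameter registered below only simulates a qparam left over from an earlier lifecycle pass:

```python
import torch
from compressed_tensors.quantization import QuantizationMetadata

layer = torch.nn.Linear(8, 8)

# Simulate a qparam left over from a previous initialization pass.
layer.register_parameter(
    "weight_scale", torch.nn.Parameter(torch.ones(1), requires_grad=False)
)

QuantizationMetadata.clear_all_qparams(layer)
print(hasattr(layer, "weight_scale"))  # expected: False
```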
16 changes: 0 additions & 16 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -32,7 +32,6 @@


__all__ = [
"infer_quantization_status",
"is_module_quantized",
"is_model_quantized",
"module_type",
@@ -234,21 +233,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
return q_min, q_max


def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa
"""
Checks the quantization status of a model. Assumes all modules in the model have
the same status, so only the first quantized model is checked.

:param model: model to check quantization status for
:return: quantization status if the model is quantized, otherwise None
"""
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def is_module_quantized(module: Module) -> bool:
"""
Check if a module is quantized, based on the existence of a non-empty quantization
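For callers that previously relied on `infer_quantization_status`, the equivalent check is now a direct attribute read, since `apply_quantization_config` stamps `quantization_status` on each targeted module. A minimal stand-in, mirroring the removed helper:

```python
from typing import Optional

from compressed_tensors.quantization import QuantizationStatus
from torch.nn import Module


def first_quantization_status(model: Module) -> Optional[QuantizationStatus]:
    """Return the status of the first module that has one, else None.

    Illustrative stand-in for the removed infer_quantization_status helper.
    """
    for module in model.modules():
        status = getattr(module, "quantization_status", None)
        if status is not None:
            return status
    return None
```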