From 5a0ec31b66f74dd2998b423f8512b0e455e6a5e9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Aug 2025 21:41:20 -0400 Subject: [PATCH 01/42] quantization done Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 127 ++++++++++++++++ src/compressed_tensors/modeling/kvcache.py | 138 ++++++++++++++++++ .../quantization/lifecycle/apply.py | 113 +++++++------- .../transform/factory/base.py | 35 ++++- .../transform/factory/hadamard.py | 1 - .../transform/transform_args.py | 6 + .../transform/utils/matrix.py | 34 ++--- 7 files changed, 368 insertions(+), 86 deletions(-) create mode 100644 src/compressed_tensors/modeling/attention.py create mode 100644 src/compressed_tensors/modeling/kvcache.py diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 000000000..299816ef7 --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, Optional + +import torch +from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache +from compressed_tensors.quantization import QuantizationScheme, forward_quantize +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +_original_impl = "eager" # mutable + + +class QuantizedAttentionImpl(torch.nn.Module): + def __init__(self, attn_module: torch.nn.Module): + super().__init__() + self.attn_module_container = [attn_module] # avoid circular reference + self.quantization_enabled = False + + def forward( + self, + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, + ): + # quantization always gets applied last after hooks, in the same way that + # quantized `wrapped_forward` always applies quantization last + # because it does not use hooks + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + if scheme is not None and self.quantization_enabled: + if scheme.input_activations is not None: + query = forward_quantize(module, query, "q", scheme.input_activations) + + if scheme.weights is not None: + raise ValueError("") + + if scheme.output_activations is not None: + raise NotImplementedError("") + + return ALL_ATTENTION_FUNCTIONS[_original_impl]( + module, query, key, value, attention_mask, scaling, dropout, **kwargs + ) + + +def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs): + if hasattr(module, "impl"): + return module.impl(module, *args, **kwargs) + else: + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) + + +def initialize_hooked_attention( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True +): + if not 
hasattr(module, "impl"): + module.register_module("impl", QuantizedAttentionImpl(module)) + + if model.config._attn_implementation != "ct_hooked_attention": + # assumes only one model at a time + global _original_impl + _original_impl = model.config._attn_implementation + + AttentionInterface.register("ct_hooked_attention", ct_hooked_attention) + model.config._attn_implementation = "ct_hooked_attention" + + if quantize: + # initialize q scale + + impl: QuantizedAttentionImpl = getattr(module, "impl") + impl.quantization_enabled = True + + initialize_hooked_kv_cache(module, quantize=True) + + +def register_query_hook(module: torch.nn.Module, func: Callable) -> RemovableHandle: + """ + Registers a forward pre-hook on `module.impl` that replaces the `query` argument + with `func(mod, query)` (handles both positional and keyword forms). + """ + impl = getattr(module, "impl") + + def _hook(mod: torch.nn.Module, args, kwargs): + # Keyword case + if "query" in kwargs: + kwargs["query"] = func(mod, kwargs["query"]) + return args, kwargs + + # Positional case: find the index of `query` in impl.forward + sig = inspect.signature(mod.forward) + param_names = tuple(sig.parameters.keys()) + try: + idx = param_names.index("query") + except ValueError: + # No `query` parameter; nothing to do + return args, kwargs + + if idx < len(args): + args = list(args) + args[idx] = func(mod, args[idx]) + return tuple(args), kwargs + + # Not present explicitly (maybe defaulted) + return args, kwargs + + return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py new file mode 100644 index 000000000..573b110f5 --- /dev/null +++ b/src/compressed_tensors/modeling/kvcache.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +from compressed_tensors.quantization import ( + KVCacheScaleType, + QuantizationScheme, + forward_quantize, +) +from torch import Tensor +from torch.utils.hooks import RemovableHandle +from transformers import DynamicCache + + +class QuantizedKVCache(DynamicCache, torch.nn.Module): + def __init__(self, attn_module: torch.nn.Module): + DynamicCache.__init__(self) + torch.nn.Module.__init__(self) + self.attn_module_container = [attn_module] # avoid nn.Module circular reference + self.use_cache = False + self.quantization_enabled = False + + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + return self(*args, **kwargs) + + def forward( + self, + key_states: Tensor, + value_states: Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Tensor, Tensor]: + # quantization always gets applied last after hooks, in the same way that + # quantized `wrapped_forward` always applies quantization last + # because it does not use hooks + module = self.attn_module_container[0] + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + + if scheme is not None and self.quantization_enabled: + if scheme.input_activations is not None: + key_states = forward_quantize( + module, key_states, "k", scheme.input_activations + ) + value_states = forward_quantize( + module, value_states, "v", scheme.input_activations + ) + + if scheme.weights is not None: + raise ValueError("") + + if scheme.output_activations is not None: + raise NotImplementedError("") + + if self.use_cache: + return super().update(key_states, value_states, layer_idx, cache_kwargs) + else: + return key_states, value_states + + +def initialize_hooked_kv_cache(module: torch.nn.Module, quantize: bool = False): + if not hasattr(module, "kv_cache"): + module.register_module("kv_cache", QuantizedKVCache(module)) + module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) + + if quantize: + # initialize k scale + # initialize v scale + kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + kv_cache.quantization_enabled = True + + +def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): + kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + kwargs["past_key_value"] = kv_cache + + # use cache if cache is enabled, but this is typically not used during calibration + kv_cache.use_cache = kwargs.get("use_cache", False) + + return args, kwargs + + +def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + + def _hook(mod: torch.nn.Module, args, kwargs): + # If passed as keyword, this is easy. + if "key_states" in kwargs: + kwargs["key_states"] = hook(mod, kwargs["key_states"]) + return args, kwargs + + # Otherwise, find where key_states would be in positional args. + sig = inspect.signature(mod.forward) + param_names = tuple(sig.parameters.keys()) + try: + idx = param_names.index("key_states") + except ValueError: + # forward has no key_states parameter; do nothing + return args, kwargs + + # If the position exists in args, replace it. + if idx < len(args): + args = list(args) + args[idx] = hook(mod, args[idx]) + return tuple(args), kwargs + + # Not present positionally and not in kwargs (maybe defaulted) — do nothing. 
+ return args, kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook( + module: torch.nn.Module, func: Callable, **kwargs +) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + + def hook(module: torch.nn.Module, args, kwargs): + signature = inspect.signature(module.forward) + bound_args = signature.bind_partial(*args, **kwargs) + return func(module, bound_args.arguments["value_states"]) + + return kv_cache.register_forward_pre_hook(hook, with_kwargs=True) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7afd2aba6..71bc05f8f 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -14,14 +14,12 @@ import logging import re -from collections import OrderedDict, defaultdict from copy import deepcopy from typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType from typing import Set, Union import torch -from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.lifecycle.compressed import ( compress_quantized_weights, ) @@ -39,11 +37,12 @@ infer_quantization_status, is_kv_cache_quant_scheme, ) -from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module +from compressed_tensors.utils import match_named_modules, replace_module from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder from safetensors import safe_open from torch.nn import Module +from transformers import PreTrainedModel __all__ = [ @@ -116,7 +115,9 @@ def load_pretrained_quantization_parameters( def apply_quantization_config( - model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False + model: PreTrainedModel, + config: Union[QuantizationConfig, None], + run_compressed: bool = False, ) -> Dict[str, QuantizationScheme]: """ Initializes the model for quantization in-place based on the given config. @@ -130,68 +131,60 @@ def apply_quantization_config( # Workaround for when HF Quantizer passes None, see PR #180 if config is None: return dict() - - # remove reference to the original `config` - # argument. This function can mutate it, and we'd - # like to keep the original `config` as it is. 
config = deepcopy(config) - # build mapping of targets to schemes for easier matching - # use ordered dict to preserve target ordering in config - target_to_scheme = OrderedDict() + + # preprocessing for kv cache quantization + # TODO: KV cache-only uses this, attention uses standard targets + # perhaps the kv_cache targets have their own matching loop config = process_quantization_config(config) - names_to_scheme = dict() + for scheme in config.config_groups.values(): - for target in scheme.targets: - target_to_scheme[target] = scheme + for name, module in match_named_modules(model, scheme.targets): + # apply status + setattr(module, "quantization_status", config.quantization_status) + + if isinstance(module, torch.nn.Linear): + # can remove after meta model compression lands + force_zero_point_init = ( + config.quantization_status != QuantizationStatus.COMPRESSED + ) + scale_dtype = None + if config.quantization_status == QuantizationStatus.FROZEN: + if hasattr(model, "dtype"): + scale_dtype = model.dtype + + # add quantization parameters, wrap forward + setattr(module, "quantization_scheme", scheme) + initialize_module_for_quantization( + module, + force_zero_point=force_zero_point_init, + scale_dtype=scale_dtype, + ) - if run_compressed: - from compressed_tensors.linear.compressed_linear import CompressedLinear + # hopefully we can remove this soon + # avoid circular dep + from compressed_tensors.linear.compressed_linear import CompressedLinear + + if run_compressed: + compressed_linear = CompressedLinear.from_linear( + module, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) + + elif name.endswith("self_attn"): + # avoid circular dep + from compressed_tensors.modeling.attention import ( + initialize_hooked_attention, + ) - # list of submodules to ignore - ignored_submodules = defaultdict(list) - # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in model.named_modules(): - # potentially fix module name to remove FSDP wrapper prefix - name = fix_fsdp_module_name(name) - if matches := find_name_or_class_matches(name, submodule, config.ignore): - for match in matches: - ignored_submodules[match].append(name) - continue # layer matches ignore list, continue - - targets = find_name_or_class_matches(name, submodule, target_to_scheme) - - if targets: - # mark modules to be quantized by adding - # quant scheme to the matching layers - scheme = _scheme_from_targets(target_to_scheme, targets, name) - if run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - names_to_scheme[name] = submodule.quantization_scheme - - if config.ignore is not None and ignored_submodules is not None: - if set(config.ignore) - set(ignored_submodules): - _LOGGER.warning( - "Some layers that were to be ignored were " - "not found in the model: " - f"{set(config.ignore) - set(ignored_submodules)}" - ) + initialize_hooked_attention(model, module, quantize=True) + + else: + raise ValueError(f"Cannot quantize unknown module type {type(module)}") - # apply current quantization status across all targeted layers - 
apply_quantization_status(model, config.quantization_status) - return names_to_scheme + return {} # hopefully can remove soon def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index a77447093..5f9b7d082 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -18,6 +18,14 @@ import torch import torch.nn.utils.parametrize as P +from compressed_tensors.modeling.attention import ( + initialize_hooked_attention, + register_query_hook, +) +from compressed_tensors.modeling.kvcache import ( + initialize_hooked_kv_cache, + register_key_hook, +) from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -36,6 +44,7 @@ from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.nn import Module, Parameter +from transformers import PreTrainedModel __all__ = ["TransformFactory", "TransformBase"] @@ -84,7 +93,7 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBas """ raise NotImplementedError() - def apply_to_model(self, model: Module): + def apply_to_model(self, model: PreTrainedModel): """ Create transforms and apply them to the model @@ -92,11 +101,13 @@ def apply_to_model(self, model: Module): """ for arg in self.scheme.apply: for _, module in match_named_modules(model, arg.targets, arg.ignore): - self._apply_to_module(module, arg) + self._apply_to_module(module, arg, model) self._update_tied_weights() - def _apply_to_module(self, module: Module, args: TransformArgs): + def _apply_to_module( + self, module: Module, args: TransformArgs, model: PreTrainedModel + ): """ Create transforms and apply them to the module @@ -154,9 +165,25 @@ def output_hook(_, _input, output): module.register_forward_hook(output_hook) + elif args.location == TransformLocation.Q_ATTN: + initialize_hooked_attention(model, module, quantize=False) + + def query_hook(_, query_states): + return transform(query_states) + + register_query_hook(module, query_hook) + # other locations such as q_attn and k_attn have not been implemented + elif args.location == TransformLocation.K_CACHE: + initialize_hooked_kv_cache(module, quantize=False) + + def key_hook(_, key_states): + return transform(key_states) + + register_key_hook(module, key_hook) + else: - raise NotImplementedError() + assert False def _update_tied_weights(self): """ diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 42611967c..e777c96fa 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -51,7 +51,6 @@ def create_transform(self, module: Module, args: TransformArgs): :param module: parent module that transform will be applied to :param args: defines how the transform will be applied to the module """ - assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = self.scheme.precision device = get_offloaded_device(module) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index d3f469579..18252bbef 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -45,6 +45,12 @@ class TransformLocation(str, Enum): 
K_CACHE = "k_cache" Q_ATTN = "q_attn" + def is_online(self) -> bool: + return self not in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ) + class TransformArgs(BaseModel, use_enum_values=True): """ diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index f353f8a2e..0875ea73b 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -34,6 +34,8 @@ def get_transform_size( :param head_dim: size of head when transform is applied to mha :return: size of matrix """ + size = None + if isinstance(module, torch.nn.Linear): if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT): size = module.in_features @@ -44,11 +46,13 @@ def get_transform_size( size = module.num_embeddings else: size = module.embedding_dim - else: - raise NotImplementedError(f"Transforms on {type(module)} are not supported") + elif head_dim is None: + raise NotImplementedError( + f"Transforms on {type(module)} are not supported without head_dim" + ) if head_dim is not None: - if size % head_dim != 0: + if size is not None and size % head_dim != 0: raise ValueError( f"{head_dim} must divide {size} for {type(module)} at {location}" ) @@ -105,11 +109,11 @@ def apply_transform_weight( assert transform_weight.shape[0] == transform_weight.shape[1] - if module_type == torch.nn.Linear: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) + if TransformLocation(location).is_online(): + return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.WEIGHT_INPUT: + if module_type == torch.nn.Linear: + if location == TransformLocation.WEIGHT_INPUT: # equivalent to (transform_weight @ value.T).T return _multihead_matmul(value, transform_weight.T) @@ -117,26 +121,14 @@ def apply_transform_weight( # equivalent to (value.T @ transform_weight).T return _multihead_matmul(transform_weight.T, value) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - # similar derivation to torch.nn.Linear, but `y = (x W)` elif module_type == torch.nn.Embedding: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) - - elif location == TransformLocation.WEIGHT_INPUT: - return _multihead_matmul( - transform_weight, - value, - ) + if location == TransformLocation.WEIGHT_INPUT: + return _multihead_matmul(transform_weight, value) elif location == TransformLocation.WEIGHT_OUTPUT: return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - raise NotImplementedError( f"Applying transforms to {module_type} {location} is not supported" ) From e7b1338035f91985fb6c405049244baf9e00e208 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Aug 2025 23:28:50 -0400 Subject: [PATCH 02/42] fix kv cache passing Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 28 ++++++++----- src/compressed_tensors/modeling/kvcache.py | 41 +++++++++++-------- .../transform/factory/base.py | 2 +- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index 299816ef7..8033da286 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -30,7 +30,7 @@ class QuantizedAttentionImpl(torch.nn.Module): def __init__(self, 
attn_module: torch.nn.Module): super().__init__() self.attn_module_container = [attn_module] # avoid circular reference - self.quantization_enabled = False + self._quantization_enabled = False def forward( self, @@ -38,9 +38,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, + *args, **kwargs, ): # quantization always gets applied last after hooks, in the same way that @@ -49,7 +47,8 @@ def forward( scheme: Optional[QuantizationScheme] = getattr( module, "quantization_scheme", None ) - if scheme is not None and self.quantization_enabled: + assert not self._quantization_enabled + if scheme is not None and self._quantization_enabled: if scheme.input_activations is not None: query = forward_quantize(module, query, "q", scheme.input_activations) @@ -60,9 +59,17 @@ def forward( raise NotImplementedError("") return ALL_ATTENTION_FUNCTIONS[_original_impl]( - module, query, key, value, attention_mask, scaling, dropout, **kwargs + module, + query, + key, + value, + *args, + **kwargs, ) + def enable_quantization(self): + self._quantization_enabled = True + def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs): if hasattr(module, "impl"): @@ -76,7 +83,6 @@ def initialize_hooked_attention( ): if not hasattr(module, "impl"): module.register_module("impl", QuantizedAttentionImpl(module)) - if model.config._attn_implementation != "ct_hooked_attention": # assumes only one model at a time global _original_impl @@ -85,13 +91,13 @@ def initialize_hooked_attention( AttentionInterface.register("ct_hooked_attention", ct_hooked_attention) model.config._attn_implementation = "ct_hooked_attention" - if quantize: + impl: QuantizedAttentionImpl = getattr(module, "impl") + if quantize and not impl._quantization_enabled: # initialize q scale - impl: QuantizedAttentionImpl = getattr(module, "impl") - impl.quantization_enabled = True + impl.enable_quantization() - initialize_hooked_kv_cache(module, quantize=True) + initialize_hooked_kv_cache(model, module, quantize=quantize) def register_query_hook(module: torch.nn.Module, func: Callable) -> RemovableHandle: diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 573b110f5..d65a28fc4 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -23,16 +23,15 @@ ) from torch import Tensor from torch.utils.hooks import RemovableHandle -from transformers import DynamicCache +from transformers import Cache, PretrainedConfig, PreTrainedModel -class QuantizedKVCache(DynamicCache, torch.nn.Module): +class QuantizedKVCache(torch.nn.Module): def __init__(self, attn_module: torch.nn.Module): - DynamicCache.__init__(self) - torch.nn.Module.__init__(self) + super().__init__() self.attn_module_container = [attn_module] # avoid nn.Module circular reference - self.use_cache = False - self.quantization_enabled = False + self.past_key_value: Optional[Cache] = None + self._quantization_enabled = False def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: return self(*args, **kwargs) @@ -41,8 +40,8 @@ def forward( self, key_states: Tensor, value_states: Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, + *args, + **kwargs, ) -> Tuple[Tensor, Tensor]: # quantization always gets applied last after hooks, in the same way that # quantized `wrapped_forward` always applies quantization last @@ -52,7 +51,8 @@ def forward( module, 
"quantization_scheme", None ) - if scheme is not None and self.quantization_enabled: + assert not self._quantization_enabled + if scheme is not None and self._quantization_enabled: if scheme.input_activations is not None: key_states = forward_quantize( module, key_states, "k", scheme.input_activations @@ -67,31 +67,36 @@ def forward( if scheme.output_activations is not None: raise NotImplementedError("") - if self.use_cache: - return super().update(key_states, value_states, layer_idx, cache_kwargs) + if self.past_key_value is not None: + ret = self.past_key_value.update(key_states, value_states, *args, **kwargs) + self.past_key_value = None + return ret else: return key_states, value_states + def enable_quantization(self): + self._quantization_enabled = True -def initialize_hooked_kv_cache(module: torch.nn.Module, quantize: bool = False): + +def initialize_hooked_kv_cache( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False +): if not hasattr(module, "kv_cache"): module.register_module("kv_cache", QuantizedKVCache(module)) module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) - if quantize: + kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + if quantize and not kv_cache._quantization_enabled: # initialize k scale # initialize v scale - kv_cache: QuantizedKVCache = getattr(module, "kv_cache") - kv_cache.quantization_enabled = True + kv_cache.enable_quantization() def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + kv_cache.past_key_value = kwargs.get("past_key_value", None) kwargs["past_key_value"] = kv_cache - # use cache if cache is enabled, but this is typically not used during calibration - kv_cache.use_cache = kwargs.get("use_cache", False) - return args, kwargs diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 5f9b7d082..9843a2549 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -175,7 +175,7 @@ def query_hook(_, query_states): # other locations such as q_attn and k_attn have not been implemented elif args.location == TransformLocation.K_CACHE: - initialize_hooked_kv_cache(module, quantize=False) + initialize_hooked_kv_cache(model, module, quantize=False) def key_hook(_, key_states): return transform(key_states) From 36d7f2cf5e442b39402fc2203b258028bb334f4e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 19 Aug 2025 23:36:55 -0400 Subject: [PATCH 03/42] slightly cleaner, validated with r3 Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/transform_args.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index 18252bbef..75c816492 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -76,9 +76,6 @@ def wrap_singleton(cls, value): return value def is_online(self) -> bool: - return self.location not in ( - TransformLocation.WEIGHT_INPUT, - TransformLocation.WEIGHT_OUTPUT, - ) + return TransformLocation(self.location).is_online() model_config = ConfigDict(extra="forbid") From e64dbd2c52dc5b533116b20d4471bc8cb9ba7002 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 00:00:55 -0400 Subject: [PATCH 04/42] qparam initialization Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py 
| 66 ++++++++------ src/compressed_tensors/modeling/kvcache.py | 87 ++++++++++--------- .../transform/factory/base.py | 2 +- 3 files changed, 90 insertions(+), 65 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index 8033da286..f80b8facb 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -13,16 +13,24 @@ # limitations under the License. import inspect -from typing import Callable, Dict, Optional +from typing import Callable, Optional import torch from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache from compressed_tensors.quantization import QuantizationScheme, forward_quantize +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain from torch.utils.hooks import RemovableHandle from transformers import AttentionInterface, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"] + + +IMPL_ATTR = "impl" _original_impl = "eager" # mutable @@ -30,7 +38,7 @@ class QuantizedAttentionImpl(torch.nn.Module): def __init__(self, attn_module: torch.nn.Module): super().__init__() self.attn_module_container = [attn_module] # avoid circular reference - self._quantization_enabled = False + self._qparams_initialized = False def forward( self, @@ -44,19 +52,16 @@ def forward( # quantization always gets applied last after hooks, in the same way that # quantized `wrapped_forward` always applies quantization last # because it does not use hooks - scheme: Optional[QuantizationScheme] = getattr( - module, "quantization_scheme", None + quant_args: Optional[QuantizationScheme] = getattr_chain( + module, "quantization_scheme.input_activations", None + ) + quant_enabled: Optional[QuantizationScheme] = getattr( + module, "quantization_enabled", True ) - assert not self._quantization_enabled - if scheme is not None and self._quantization_enabled: - if scheme.input_activations is not None: - query = forward_quantize(module, query, "q", scheme.input_activations) - - if scheme.weights is not None: - raise ValueError("") - if scheme.output_activations is not None: - raise NotImplementedError("") + # apply quantization if applicable + if quant_args is not None and quant_enabled and self._qparams_initialized: + query = forward_quantize(module, query, "q", quant_args) return ALL_ATTENTION_FUNCTIONS[_original_impl]( module, @@ -67,12 +72,25 @@ def forward( **kwargs, ) - def enable_quantization(self): - self._quantization_enabled = True + def initialize_qparams_once(self, module: torch.nn.Module): + assert module is self.attn_module_container[0] + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + + if not self._qparams_initialized and scheme.input_activations is not None: + _initialize_scale_zero_point(module, "q", scheme.input_activations) + self._qparams_initialized = True + + if scheme.weights is not None: + raise ValueError("") + + if scheme.output_activations is not None: + raise NotImplementedError("") def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs): - if hasattr(module, "impl"): + if hasattr(module, IMPL_ATTR): return module.impl(module, *args, **kwargs) else: return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) @@ -81,8 +99,8 @@ def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs): def 
initialize_hooked_attention( model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True ): - if not hasattr(module, "impl"): - module.register_module("impl", QuantizedAttentionImpl(module)) + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module)) if model.config._attn_implementation != "ct_hooked_attention": # assumes only one model at a time global _original_impl @@ -91,13 +109,11 @@ def initialize_hooked_attention( AttentionInterface.register("ct_hooked_attention", ct_hooked_attention) model.config._attn_implementation = "ct_hooked_attention" - impl: QuantizedAttentionImpl = getattr(module, "impl") - if quantize and not impl._quantization_enabled: - # initialize q scale - - impl.enable_quantization() + impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR) + if quantize: + impl.initialize_qparams_once(module) - initialize_hooked_kv_cache(model, module, quantize=quantize) + initialize_hooked_kv_cache(module, quantize=quantize) def register_query_hook(module: torch.nn.Module, func: Callable) -> RemovableHandle: @@ -105,7 +121,7 @@ def register_query_hook(module: torch.nn.Module, func: Callable) -> RemovableHan Registers a forward pre-hook on `module.impl` that replaces the `query` argument with `func(mod, query)` (handles both positional and keyword forms). """ - impl = getattr(module, "impl") + impl = getattr(module, IMPL_ATTR) def _hook(mod: torch.nn.Module, args, kwargs): # Keyword case diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index d65a28fc4..8683fd8e1 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -13,17 +13,23 @@ # limitations under the License. import inspect -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Callable, Optional, Tuple import torch -from compressed_tensors.quantization import ( - KVCacheScaleType, - QuantizationScheme, - forward_quantize, +from compressed_tensors.quantization import QuantizationScheme, forward_quantize +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, ) +from compressed_tensors.utils import getattr_chain from torch import Tensor from torch.utils.hooks import RemovableHandle -from transformers import Cache, PretrainedConfig, PreTrainedModel +from transformers import Cache + + +__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"] + + +KV_CACHE_ATTR = "kv_cache" class QuantizedKVCache(torch.nn.Module): @@ -31,7 +37,7 @@ def __init__(self, attn_module: torch.nn.Module): super().__init__() self.attn_module_container = [attn_module] # avoid nn.Module circular reference self.past_key_value: Optional[Cache] = None - self._quantization_enabled = False + self._qparams_initialized = False def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: return self(*args, **kwargs) @@ -47,26 +53,19 @@ def forward( # quantized `wrapped_forward` always applies quantization last # because it does not use hooks module = self.attn_module_container[0] - scheme: Optional[QuantizationScheme] = getattr( - module, "quantization_scheme", None + quant_args: Optional[QuantizationScheme] = getattr_chain( + module, "quantization_scheme.input_activations", None + ) + quant_enabled: Optional[QuantizationScheme] = getattr( + module, "quantization_enabled", True ) - assert not self._quantization_enabled - if scheme is not None and self._quantization_enabled: - if scheme.input_activations is not None: - key_states = forward_quantize( - 
module, key_states, "k", scheme.input_activations - ) - value_states = forward_quantize( - module, value_states, "v", scheme.input_activations - ) - - if scheme.weights is not None: - raise ValueError("") - - if scheme.output_activations is not None: - raise NotImplementedError("") + # apply quantization if applicable + if quant_args is not None and quant_enabled and self._qparams_initialized: + key_states = forward_quantize(module, key_states, "k", quant_args) + value_states = forward_quantize(module, value_states, "v", quant_args) + # use existing cache from `kv_cache_attention_hook` if applicable if self.past_key_value is not None: ret = self.past_key_value.update(key_states, value_states, *args, **kwargs) self.past_key_value = None @@ -74,26 +73,36 @@ def forward( else: return key_states, value_states - def enable_quantization(self): - self._quantization_enabled = True + def initialize_qparams_once(self, module: torch.nn.Module): + assert module is self.attn_module_container[0] + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + + if not self._qparams_initialized and scheme.input_activations is not None: + _initialize_scale_zero_point(module, "k", scheme.input_activations) + _initialize_scale_zero_point(module, "v", scheme.input_activations) + self._qparams_initialized = True + + if scheme.weights is not None: + raise ValueError("") + + if scheme.output_activations is not None: + raise NotImplementedError("") -def initialize_hooked_kv_cache( - model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False -): - if not hasattr(module, "kv_cache"): - module.register_module("kv_cache", QuantizedKVCache(module)) +def initialize_hooked_kv_cache(module: torch.nn.Module, quantize: bool = False): + if not hasattr(module, KV_CACHE_ATTR): + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module)) module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) - kv_cache: QuantizedKVCache = getattr(module, "kv_cache") - if quantize and not kv_cache._quantization_enabled: - # initialize k scale - # initialize v scale - kv_cache.enable_quantization() + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + if quantize: + kv_cache.initialize_qparams_once(module) def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): - kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) kv_cache.past_key_value = kwargs.get("past_key_value", None) kwargs["past_key_value"] = kv_cache @@ -101,7 +110,7 @@ def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: - kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) def _hook(mod: torch.nn.Module, args, kwargs): # If passed as keyword, this is easy. 
@@ -133,7 +142,7 @@ def _hook(mod: torch.nn.Module, args, kwargs): def register_value_hook( module: torch.nn.Module, func: Callable, **kwargs ) -> RemovableHandle: - kv_cache: QuantizedKVCache = getattr(module, "kv_cache") + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) def hook(module: torch.nn.Module, args, kwargs): signature = inspect.signature(module.forward) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 9843a2549..5f9b7d082 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -175,7 +175,7 @@ def query_hook(_, query_states): # other locations such as q_attn and k_attn have not been implemented elif args.location == TransformLocation.K_CACHE: - initialize_hooked_kv_cache(model, module, quantize=False) + initialize_hooked_kv_cache(module, quantize=False) def key_hook(_, key_states): return transform(key_states) From 1a01dc37d0a99480b9230f66b2b2618a4527ce5a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 00:05:11 -0400 Subject: [PATCH 05/42] add markers Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 6 ++++++ src/compressed_tensors/modeling/kvcache.py | 19 ++++++------------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index f80b8facb..d2c2c141d 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -89,6 +89,9 @@ def initialize_qparams_once(self, module: torch.nn.Module): raise NotImplementedError("") +# ----- initialize ----- # + + def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs): if hasattr(module, IMPL_ATTR): return module.impl(module, *args, **kwargs) @@ -116,6 +119,9 @@ def initialize_hooked_attention( initialize_hooked_kv_cache(module, quantize=quantize) +# ----- hooks ----- # + + def register_query_hook(module: torch.nn.Module, func: Callable) -> RemovableHandle: """ Registers a forward pre-hook on `module.impl` that replaces the `query` argument diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 8683fd8e1..8c77b5702 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -91,6 +91,9 @@ def initialize_qparams_once(self, module: torch.nn.Module): raise NotImplementedError("") +# ----- initialize ----- # + + def initialize_hooked_kv_cache(module: torch.nn.Module, quantize: bool = False): if not hasattr(module, KV_CACHE_ATTR): module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module)) @@ -109,6 +112,9 @@ def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): return args, kwargs +# ----- hooks ----- # + + def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) @@ -137,16 +143,3 @@ def _hook(mod: torch.nn.Module, args, kwargs): return args, kwargs return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) - - -def register_value_hook( - module: torch.nn.Module, func: Callable, **kwargs -) -> RemovableHandle: - kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) - - def hook(module: torch.nn.Module, args, kwargs): - signature = inspect.signature(module.forward) - bound_args = signature.bind_partial(*args, **kwargs) - return func(module, bound_args.arguments["value_states"]) - - 
return kv_cache.register_forward_pre_hook(hook, with_kwargs=True) From 585550bd67ea660f500fd70e800fcd54f1b813cc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 00:15:52 -0400 Subject: [PATCH 06/42] cleanup Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 16 ++++--------- src/compressed_tensors/modeling/kvcache.py | 24 ++++++++------------ 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index d2c2c141d..f393a8adb 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -49,20 +49,14 @@ def forward( *args, **kwargs, ): - # quantization always gets applied last after hooks, in the same way that - # quantized `wrapped_forward` always applies quantization last - # because it does not use hooks - quant_args: Optional[QuantizationScheme] = getattr_chain( - module, "quantization_scheme.input_activations", None - ) - quant_enabled: Optional[QuantizationScheme] = getattr( - module, "quantization_enabled", True - ) - - # apply quantization if applicable + # quantization + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) if quant_args is not None and quant_enabled and self._qparams_initialized: query = forward_quantize(module, query, "q", quant_args) + # original attention return ALL_ATTENTION_FUNCTIONS[_original_impl]( module, query, diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 8c77b5702..dfb75462e 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -49,29 +49,23 @@ def forward( *args, **kwargs, ) -> Tuple[Tensor, Tensor]: - # quantization always gets applied last after hooks, in the same way that - # quantized `wrapped_forward` always applies quantization last - # because it does not use hooks + # quantization module = self.attn_module_container[0] - quant_args: Optional[QuantizationScheme] = getattr_chain( - module, "quantization_scheme.input_activations", None - ) - quant_enabled: Optional[QuantizationScheme] = getattr( - module, "quantization_enabled", True - ) - - # apply quantization if applicable + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) if quant_args is not None and quant_enabled and self._qparams_initialized: key_states = forward_quantize(module, key_states, "k", quant_args) value_states = forward_quantize(module, value_states, "v", quant_args) - # use existing cache from `kv_cache_attention_hook` if applicable + # original cache if self.past_key_value is not None: ret = self.past_key_value.update(key_states, value_states, *args, **kwargs) - self.past_key_value = None - return ret else: - return key_states, value_states + ret = (key_states, value_states) + + self.past_key_value = None + return ret def initialize_qparams_once(self, module: torch.nn.Module): assert module is self.attn_module_container[0] From f49524a162522dad0810a9289393f14fc6c294ad Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 12:42:14 -0400 Subject: [PATCH 07/42] add narrow match Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/match.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git 
a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 30b85fc5e..e10454c6a 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -256,3 +256,13 @@ def _match_class(module: torch.nn.Module, target: str) -> bool: ) for cls in module.__class__.__mro__ ) + + +def is_narrow_match(model: torch.nn.Module, targets: Iterable[str], name: str) -> bool: + module = model.get_submodule(name) + parent_name = name.rsplit(".", 1)[0] + parent = model.get_submodule(parent_name) + + return is_match(name, module, targets) and not is_match( + parent_name, parent, targets + ) From 73743124ef0a0678d8ca11a932d2295fe4290001 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 12:43:39 -0400 Subject: [PATCH 08/42] better quant matching Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 120 +++++++++--------- .../quantization/lifecycle/initialize.py | 75 ++++++----- src/compressed_tensors/utils/match.py | 39 ++++++ 3 files changed, 134 insertions(+), 100 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 71bc05f8f..2daeaecf6 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -14,32 +14,36 @@ import logging import re +from collections import OrderedDict from copy import deepcopy from typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType from typing import Set, Union import torch -from compressed_tensors.quantization.lifecycle.compressed import ( - compress_quantized_weights, -) +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.quant_config import QuantizationConfig, QuantizationStatus + from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, + is_attention_module ) -from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.quant_config import ( - QuantizationConfig, - QuantizationStatus, -) -from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.lifecycle.compressed import compress_quantized_weights from compressed_tensors.quantization.utils import ( KV_CACHE_TARGETS, infer_quantization_status, is_kv_cache_quant_scheme, ) -from compressed_tensors.utils import match_named_modules, replace_module -from compressed_tensors.utils.offload import update_parameter_data -from compressed_tensors.utils.safetensors_load import get_safetensors_folder +from compressed_tensors.utils import ( + get_safetensors_folder, + is_narrow_match, + match_named_modules, + match_targets, + replace_module, + update_parameter_data, +) from safetensors import safe_open from torch.nn import Module from transformers import PreTrainedModel @@ -133,58 +137,60 @@ def apply_quantization_config( return dict() config = deepcopy(config) - # preprocessing for kv cache quantization - # TODO: KV cache-only uses this, attention uses standard targets - # perhaps the kv_cache targets have their own matching loop - config = process_quantization_config(config) - - for scheme in config.config_groups.values(): - for name, module in match_named_modules(model, scheme.targets): - # apply status - 
setattr(module, "quantization_status", config.quantization_status) - - if isinstance(module, torch.nn.Linear): - # can remove after meta model compression lands - force_zero_point_init = ( - config.quantization_status != QuantizationStatus.COMPRESSED - ) - scale_dtype = None - if config.quantization_status == QuantizationStatus.FROZEN: - if hasattr(model, "dtype"): - scale_dtype = model.dtype - - # add quantization parameters, wrap forward - setattr(module, "quantization_scheme", scheme) - initialize_module_for_quantization( - module, - force_zero_point=force_zero_point_init, - scale_dtype=scale_dtype, - ) + # build mapping of targets to schemes for easier matching + # use ordered dict to preserve target ordering in config + target_to_scheme = { + target: scheme + for scheme in config.config_groups.values() + for target in scheme.targets + } + names_to_scheme = dict() - # hopefully we can remove this soon - # avoid circular dep - from compressed_tensors.linear.compressed_linear import CompressedLinear + # preprocessing for kv cache, TODO: fix + config = process_quantization_config(config) - if run_compressed: + if run_compressed: + from compressed_tensors.linear.compressed_linear import CompressedLinear + + # mark appropriate layers for quantization by setting their quantization schemes + for name, submodule in match_named_modules( + model, target_to_scheme, config.ignore or [], warn_on_fail=True + ): + # mark modules to be quantized by adding + # quant scheme to the matching layers + matched_targets = match_targets(name, submodule, target_to_scheme) + scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) + + if isinstance(submodule, torch.nn.Linear): + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + # TODO: expand to more module types compressed_linear = CompressedLinear.from_linear( - module, + submodule, quantization_scheme=scheme, quantization_format=format, ) replace_module(model, name, compressed_linear) - elif name.endswith("self_attn"): - # avoid circular dep - from compressed_tensors.modeling.attention import ( - initialize_hooked_attention, - ) + # target matched - add layer and scheme to target list + setattr(submodule, "quantization_scheme", scheme) + names_to_scheme[name] = submodule.quantization_scheme - initialize_hooked_attention(model, module, quantize=True) + elif is_attention_module(submodule) and is_narrow_match( + model, matched_targets, name + ): + # unlike linear, do initialization here + from compressed_tensors.modeling.attention import initialize_hooked_attention + initialize_hooked_attention(model, submodule, quantize=True) - else: - raise ValueError(f"Cannot quantize unknown module type {type(module)}") + # target matched - add layer and scheme to target list + setattr(submodule, "quantization_scheme", scheme) + names_to_scheme[name] = submodule.quantization_scheme - return {} # hopefully can remove soon + # apply current quantization status across all targeted layers + apply_quantization_status(model, config.quantization_status) + return names_to_scheme def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: @@ -349,14 +355,6 @@ def _find_matches( return matches -def _infer_status(model: Module) -> Optional[QuantizationStatus]: - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def _load_quant_args_from_mapping( base_name: str, module_name: str, module: Module, mapping: Dict ): 
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index b0c32439f..e907b02c7 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -17,7 +17,7 @@ import math import warnings from enum import Enum -from typing import List, Optional +from typing import Optional import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -82,54 +82,51 @@ def initialize_module_for_quantization( # no scheme passed and layer not targeted for quantization - skip return - if is_attention_module(module): - # quantized actions based on calltime status - _initialize_attn_scales(module) + if not isinstance(module, torch.nn.Linear): + return - else: + if scheme.input_activations is not None: + _initialize_scale_zero_point( + module, + "input", + scheme.input_activations, + force_zero_point=force_zero_point, + scale_dtype=scale_dtype, + ) - if scheme.input_activations is not None: + if scheme.weights is not None: + if hasattr(module, "weight"): + weight_shape = None + if isinstance(module, torch.nn.Linear): + weight_shape = module.weight.shape _initialize_scale_zero_point( module, - "input", - scheme.input_activations, + "weight", + scheme.weights, + weight_shape=weight_shape, force_zero_point=force_zero_point, scale_dtype=scale_dtype, ) + else: + _LOGGER.warning( + f"module type {type(module)} targeted for weight quantization but " + "has no attribute weight, skipping weight quantization " + f"for {type(module)}" + ) - if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - scale_dtype=scale_dtype, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) - - if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations, scale_dtype=scale_dtype - ) + if scheme.output_activations is not None: + if not is_kv_cache_quant_scheme(scheme): + _initialize_scale_zero_point( + module, "output", scheme.output_activations, scale_dtype=scale_dtype + ) - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED - with disable_hf_hook(module): - # wrap forward call of module to perform - # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) + with disable_hf_hook(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index e10454c6a..5488a09e0 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -27,8 +27,10 @@ __all__ = [ "match_named_modules", "match_named_parameters", + "match_targets", "match_modules_set", "is_match", + "is_narrow_match", ] @@ -62,6 +64,7 @@ def match_named_modules( if not is_match(name, module, ignore, 
fused=fused): yield name, module + break if warn_on_fail: for target in unmatched_targets: @@ -110,6 +113,42 @@ def match_named_parameters( ) +def match_targets( + name: str, module: torch.nn.Module, targets: Iterable[str] | None = None +) -> List[str]: + """ + Returns the targets that match the given name and module. + :param name: the name of the module + :param module: the module to match + :param targets: the target strings, potentially containing "re:" prefixes + :return: the targets that match the given name and module + Outputs are ordered by type: exact name match, regex name match, class name match + """ + targets = targets or [] + + if isinstance(module, InternalModule): + return [] + + # The order of the output `matches` list matters, the are arranged from most + # specific to least specific, and this order will be used when merging configs. + # The entries are sorted in the following order: + # 1. matches on exact strings + # 2. matches on regex patterns + # 3. matches on module names + + targets = sorted(targets, key=lambda x: ("re:" in x, x)) + matched_targets = [] + for target in targets: + if _match_name(name, target): + matched_targets.append(target) + + for target in targets: + if _match_class(module, target) and target not in matched_targets: + matched_targets.append(target) + + return matched_targets + + def match_modules_set( model: torch.nn.Module, targets: Iterable[str], From 8941779bcb240fc21e1481bbd405b39e6f93925e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 23:00:03 -0400 Subject: [PATCH 09/42] attention and kv quantization Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 37 ++-- src/compressed_tensors/modeling/kvcache.py | 60 ++++-- .../quantization/lifecycle/apply.py | 194 ++++++++---------- .../quantization/lifecycle/initialize.py | 4 +- .../quantization/quant_args.py | 3 + .../quantization/quant_scheme.py | 2 + .../quantization/utils/helpers.py | 15 +- src/compressed_tensors/utils/match.py | 2 +- 8 files changed, 172 insertions(+), 145 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index f393a8adb..fc3c7d91c 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -17,11 +17,16 @@ import torch from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache -from compressed_tensors.quantization import QuantizationScheme, forward_quantize +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + forward_quantize, +) from compressed_tensors.quantization.lifecycle.initialize import ( _initialize_scale_zero_point, ) from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule from torch.utils.hooks import RemovableHandle from transformers import AttentionInterface, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -34,7 +39,7 @@ _original_impl = "eager" # mutable -class QuantizedAttentionImpl(torch.nn.Module): +class QuantizedAttentionImpl(InternalModule): def __init__(self, attn_module: torch.nn.Module): super().__init__() self.attn_module_container = [attn_module] # avoid circular reference @@ -71,17 +76,19 @@ def initialize_qparams_once(self, module: torch.nn.Module): scheme: Optional[QuantizationScheme] = getattr( module, "quantization_scheme", None ) + quant_args: Optional[QuantizationArgs] = getattr( + scheme, "input_activations", None + ) - if not 
self._qparams_initialized and scheme.input_activations is not None: - _initialize_scale_zero_point(module, "q", scheme.input_activations) + if ( + not self._qparams_initialized + and quant_args is not None + and not scheme.kv_cache_only + ): + _initialize_scale_zero_point(module, "q", quant_args) + print("attn init") self._qparams_initialized = True - if scheme.weights is not None: - raise ValueError("") - - if scheme.output_activations is not None: - raise NotImplementedError("") - # ----- initialize ----- # @@ -116,17 +123,17 @@ def initialize_hooked_attention( # ----- hooks ----- # -def register_query_hook(module: torch.nn.Module, func: Callable) -> RemovableHandle: +def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: """ Registers a forward pre-hook on `module.impl` that replaces the `query` argument - with `func(mod, query)` (handles both positional and keyword forms). + with `hook(mod, query)` (handles both positional and keyword forms). """ impl = getattr(module, IMPL_ATTR) def _hook(mod: torch.nn.Module, args, kwargs): # Keyword case if "query" in kwargs: - kwargs["query"] = func(mod, kwargs["query"]) + kwargs["query"] = hook(mod, kwargs["query"]) return args, kwargs # Positional case: find the index of `query` in impl.forward @@ -140,7 +147,9 @@ def _hook(mod: torch.nn.Module, args, kwargs): if idx < len(args): args = list(args) - args[idx] = func(mod, args[idx]) + ret = hook(module, args[idx]) + if ret is not None: + args[idx] = ret return tuple(args), kwargs # Not present explicitly (maybe defaulted) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index dfb75462e..5c37f66c4 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -21,6 +21,7 @@ _initialize_scale_zero_point, ) from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.utils.hooks import RemovableHandle from transformers import Cache @@ -32,7 +33,7 @@ KV_CACHE_ATTR = "kv_cache" -class QuantizedKVCache(torch.nn.Module): +class QuantizedKVCache(InternalModule): def __init__(self, attn_module: torch.nn.Module): super().__init__() self.attn_module_container = [attn_module] # avoid nn.Module circular reference @@ -69,21 +70,16 @@ def forward( def initialize_qparams_once(self, module: torch.nn.Module): assert module is self.attn_module_container[0] - scheme: Optional[QuantizationScheme] = getattr( - module, "quantization_scheme", None - ) - - if not self._qparams_initialized and scheme.input_activations is not None: - _initialize_scale_zero_point(module, "k", scheme.input_activations) - _initialize_scale_zero_point(module, "v", scheme.input_activations) + scheme = getattr(module, "quantization_scheme", None) + quant_args = getattr(scheme, "input_activations", None) + + print((type(module), self._qparams_initialized, quant_args)) + if not self._qparams_initialized and quant_args is not None: + _initialize_scale_zero_point(module, "k", quant_args) + _initialize_scale_zero_point(module, "v", quant_args) + print("kv init") self._qparams_initialized = True - if scheme.weights is not None: - raise ValueError("") - - if scheme.output_activations is not None: - raise NotImplementedError("") - # ----- initialize ----- # @@ -130,7 +126,41 @@ def _hook(mod: torch.nn.Module, args, kwargs): # If the position exists in args, replace it. 
if idx < len(args): args = list(args) - args[idx] = hook(mod, args[idx]) + ret = hook(module, args[idx]) + if ret is not None: + args[idx] = ret + return tuple(args), kwargs + + # Not present positionally and not in kwargs (maybe defaulted) — do nothing. + return args, kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(mod: torch.nn.Module, args, kwargs): + # If passed as keyword, this is easy. + if "value_states" in kwargs: + kwargs["value_states"] = hook(mod, kwargs["value_states"]) + return args, kwargs + + # Otherwise, find where value_states would be in positional args. + sig = inspect.signature(mod.forward) + param_names = tuple(sig.parameters.keys()) + try: + idx = param_names.index("value_states") + except ValueError: + # forward has no value_states parameter; do nothing + return args, kwargs + + # If the position exists in args, replace it. + if idx < len(args): + args = list(args) + ret = hook(module, args[idx]) + if ret is not None: + args[idx] = ret return tuple(args), kwargs # Not present positionally and not in kwargs (maybe defaulted) — do nothing. diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index e3e611dee..5dcb3013c 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -15,7 +15,7 @@ import logging from collections import OrderedDict from copy import deepcopy -from typing import Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType from typing import Union @@ -35,18 +35,18 @@ ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( - KV_CACHE_TARGETS, + ATTN_TARGETS, infer_quantization_status, is_kv_cache_quant_scheme, ) from compressed_tensors.utils import ( + deprecated, get_safetensors_folder, is_narrow_match, match_named_modules, match_targets, replace_module, update_parameter_data, - deprecated, ) from safetensors import safe_open from torch.nn import Module @@ -155,49 +155,64 @@ def apply_quantization_config( from compressed_tensors.linear.compressed_linear import CompressedLinear # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in match_named_modules( - model, target_to_scheme, config.ignore or [], warn_on_fail=True - ): - # mark modules to be quantized by adding - # quant scheme to the matching layers - matched_targets = match_targets(name, submodule, target_to_scheme) - scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) - - if isinstance(submodule, (torch.nn.Linear, torch.nn.Embedding)): - if run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # target matched - add layer and scheme to target list - setattr(submodule, "quantization_scheme", scheme) - names_to_scheme[name] = submodule.quantization_scheme - - elif is_attention_module(submodule) and is_narrow_match( - model, matched_targets, name + for scheme in config.config_groups.values(): + 
for name, submodule in match_named_modules( + model, scheme.targets, config.ignore or [], warn_on_fail=True ): - # unlike linear, do initialization here - from compressed_tensors.modeling.attention import ( - initialize_hooked_attention, - ) + if isinstance(submodule, (torch.nn.Linear, torch.nn.Embedding)): + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + # TODO: expand to more module types + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) + + # attach scheme to module for later steps + _scheme = attach_scheme(submodule, scheme) + + elif is_attention_module(submodule) and is_narrow_match( + model, scheme.targets, name + ): + from compressed_tensors.modeling.attention import ( + initialize_hooked_attention, + ) - initialize_hooked_attention(model, submodule, quantize=True) + # silently throw away weight and output quantization for attention + _scheme = QuantizationScheme( + targets=scheme.targets, + input_activations=scheme.input_activations, + format=scheme.format, + kv_cache_only=scheme.kv_cache_only, + ) - # target matched - add layer and scheme to target list - setattr(submodule, "quantization_scheme", scheme) - names_to_scheme[name] = submodule.quantization_scheme + # attach scheme to module for later steps + _scheme = attach_scheme(submodule, _scheme) + + # unlike linear, do initialization here + initialize_hooked_attention(model, submodule, quantize=True) + + else: + continue + + names_to_scheme[name] = _scheme # apply current quantization status across all targeted layers apply_quantization_status(model, config.quantization_status) return names_to_scheme +def attach_scheme(module: Module, scheme: QuantizationScheme) -> QuantizationScheme: + if existing_scheme := getattr(module, "quantization_scheme", None): + scheme = merge_schemes(existing_scheme, scheme) + + setattr(module, "quantization_scheme", scheme) + return scheme + + def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: """ Preprocess the raw QuantizationConfig @@ -211,9 +226,7 @@ def process_quantization_config(config: QuantizationConfig) -> QuantizationConfi return config -def process_kv_cache_config( - config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS -) -> QuantizationConfig: +def process_kv_cache_config(config: QuantizationConfig) -> QuantizationConfig: """ Reformulate the `config.kv_cache` as a `config_group` and add it to the set of existing `config.groups` @@ -221,16 +234,14 @@ def process_kv_cache_config( :param config: the QuantizationConfig :return: the QuantizationConfig with additional "kv_cache" group """ - if targets == KV_CACHE_TARGETS: - _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") + _LOGGER.info(f"KV cache targets set to default value of: {ATTN_TARGETS}") - kv_cache_dict = config.kv_cache_scheme.model_dump() - kv_cache_scheme = QuantizationScheme( - output_activations=QuantizationArgs(**kv_cache_dict), - targets=targets, + scheme = QuantizationScheme( + targets=ATTN_TARGETS, + input_activations=config.kv_cache_scheme, + kv_cache_only=True, ) - kv_cache_group = dict(kv_cache=kv_cache_scheme) - config.config_groups.update(kv_cache_group) + config.config_groups.update({"kv_cache": scheme}) return config @@ -334,65 +345,36 @@ def _load_quant_args_from_mapping( update_parameter_data(module, state_dict_zp, zp_name) -def _scheme_from_targets( - 
target_to_scheme: OrderedDictType[str, QuantizationScheme], - targets: List[str], - name: str, +def merge_schemes( + scheme_a: QuantizationScheme, scheme_b: QuantizationScheme ) -> QuantizationScheme: - if len(targets) == 1: - # if `targets` iterable contains a single element - # use it as the key - return target_to_scheme[targets[0]] - - # otherwise, we need to merge QuantizationSchemes corresponding - # to multiple targets. This is most likely because `name` module - # is being target both as an ordinary quantization target, as well - # as kv cache quantization target - schemes_to_merge = [target_to_scheme[target] for target in targets] - return _merge_schemes(schemes_to_merge, name) - + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "targets": + return value_a + value_b + + if field_name == "kv_cache_only": + # nones defer to other value + if value_a is None: + return value_b + if value_b is None: + return value_a + + # kv_cache_only=True overrides + return not ((not value_a) or (not value_b)) + + raise ValueError( + "The following fields have overlapping targets and conflicting values for" + f"{field_name}. Please modify your config to resolve this ambiguity.\n" + f"{scheme_a}\n" + f"{scheme_b}" + ) -def _merge_schemes( - schemes_to_merge: List[QuantizationScheme], name: str -) -> QuantizationScheme: - kv_cache_quantization_scheme = [ - scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) - ] - if not kv_cache_quantization_scheme: - # if the schemes_to_merge do not contain any - # kv cache QuantizationScheme - # return the first scheme (the prioritized one, - # since the order of schemes_to_merge matters) - return schemes_to_merge[0] - else: - # fetch the kv cache QuantizationScheme and the highest - # priority non-kv cache QuantizationScheme and merge them - kv_cache_quantization_scheme = kv_cache_quantization_scheme[0] - quantization_scheme = [ - scheme - for scheme in schemes_to_merge - if not is_kv_cache_quant_scheme(scheme) - ][0] - schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme] - merged_scheme = {} - for scheme in schemes_to_merge: - scheme_dict = { - k: v for k, v in scheme.model_dump().items() if v is not None - } - # when merging multiple schemes, the final target will be - # the `name` argument - hence erase the original targets - del scheme_dict["targets"] - # make sure that schemes do not "clash" with each other - overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys()) - if overlapping_keys: - raise ValueError( - f"The module: {name} is being modified by two clashing " - f"quantization schemes, that jointly try to override " - f"properties: {overlapping_keys}. Fix the quantization config " - "so that it is not ambiguous." - ) - merged_scheme.update(scheme_dict) + dict_a = scheme_a.model_dump() + dict_b = scheme_b.model_dump() - merged_scheme.update(targets=[name]) + assert dict_a.keys() == dict_b.keys() + merged_dump = { + key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys() + } - return QuantizationScheme(**merged_scheme) + return QuantizationScheme.model_validate(merged_dump) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index d4f331d49..aebecffe5 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -209,7 +209,9 @@ def _initialize_scale_zero_point( expected_shape = 1 # 3. 
Identify quantization scale and zp dtype - scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype + scale_dtype = ( + scale_dtype if scale_dtype is not None else next(module.parameters()).dtype + ) if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 6c8984294..de8e9ffbb 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -354,6 +354,9 @@ def pytorch_dtype(self) -> torch.dtype: else: raise ValueError(f"Invalid quantization type {self.type}") + def is_online(self) -> bool: + return self.dynamic is True + @deprecated("QuantizationArgs.observer") def get_observer(self) -> str: return self.observer diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a6a3de6d4..01bc97c0d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -44,6 +44,7 @@ class QuantizationScheme(BaseModel): :param input_activations: quantization config for layer inputs :param output_activations: quantization config for layer outputs :param format: CompressionFormat for the layer + TODO """ targets: List[str] @@ -51,6 +52,7 @@ class QuantizationScheme(BaseModel): input_activations: Optional[QuantizationArgs] = None output_activations: Optional[QuantizationArgs] = None format: Optional[str] = None + kv_cache_only: Optional[bool] = None @model_validator(mode="after") def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 5d28cac2c..59a894b65 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -40,7 +40,7 @@ "get_torch_bit_depth", "can_quantize", "parse_out_kv_cache_args", - "KV_CACHE_TARGETS", + "ATTN_TARGETS", "is_kv_cache_quant_scheme", "iter_named_leaf_modules", "iter_named_quantizable_modules", @@ -51,9 +51,8 @@ "is_fp4", ] -# target the self_attn layer -# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale -KV_CACHE_TARGETS = ["re:.*self_attn$"] +# note that this is a "narrow match", see quantization/apply.py +ATTN_TARGETS = ["re:.*self_attn$"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -410,17 +409,17 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: """ Check whether the QuantizationScheme targets the kv cache. 
It does if all the following criteria are met: - - the scheme targets either exactly match the KV_CACHE_TARGETS - or the match KV_CACHE_TARGETS regex pattern + - the scheme targets either exactly match the ATTN_TARGETS + or the match ATTN_TARGETS regex pattern - the scheme quantizes output_activations (we want to quantize the - outputs from the KV_CACHE_TARGETS, as their correspond to the + outputs from the ATTN_TARGETS, as their correspond to the keys and values that are to be saved in the cache) :param scheme: The QuantizationScheme to investigate :return: boolean flag """ for target in scheme.targets: - if target in KV_CACHE_TARGETS: + if target in ATTN_TARGETS: return True return False diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index f20eafc44..7b504e5bc 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -120,7 +120,7 @@ def match_named_parameters( def match_targets( - name: str, module: torch.nn.Module, targets: Iterable[str] | None + name: str, module: torch.nn.Module, targets: Optional[Iterable[str]] ) -> List[str]: """ Returns the targets that match the given name and module. From c4af50815616a10125c3a0d2ed81f68d04cdd48d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 23:01:14 -0400 Subject: [PATCH 10/42] remove debug prints Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 1 - src/compressed_tensors/modeling/kvcache.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index fc3c7d91c..f6729978d 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -86,7 +86,6 @@ def initialize_qparams_once(self, module: torch.nn.Module): and not scheme.kv_cache_only ): _initialize_scale_zero_point(module, "q", quant_args) - print("attn init") self._qparams_initialized = True diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 5c37f66c4..93d788d41 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -73,11 +73,9 @@ def initialize_qparams_once(self, module: torch.nn.Module): scheme = getattr(module, "quantization_scheme", None) quant_args = getattr(scheme, "input_activations", None) - print((type(module), self._qparams_initialized, quant_args)) if not self._qparams_initialized and quant_args is not None: _initialize_scale_zero_point(module, "k", quant_args) _initialize_scale_zero_point(module, "v", quant_args) - print("kv init") self._qparams_initialized = True From 71463c7ec254cbf548df8d351540b96a9cd0094d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 20 Aug 2025 23:36:28 -0400 Subject: [PATCH 11/42] add todo for other strategies (block/group, channel, head) Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 9 ++++++--- src/compressed_tensors/modeling/kvcache.py | 14 +++++++++----- src/compressed_tensors/transform/factory/base.py | 2 +- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index f6729978d..15e24a1ae 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -20,6 +20,7 @@ from compressed_tensors.quantization import ( QuantizationArgs, QuantizationScheme, + QuantizationStrategy, forward_quantize, ) from 
compressed_tensors.quantization.lifecycle.initialize import ( @@ -71,7 +72,7 @@ def forward( **kwargs, ) - def initialize_qparams_once(self, module: torch.nn.Module): + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): assert module is self.attn_module_container[0] scheme: Optional[QuantizationScheme] = getattr( module, "quantization_scheme", None @@ -85,6 +86,8 @@ def initialize_qparams_once(self, module: torch.nn.Module): and quant_args is not None and not scheme.kv_cache_only ): + # TODO: use model.config.num_attention_heads to find query_size + assert quant_args.strategy == QuantizationStrategy.TENSOR _initialize_scale_zero_point(module, "q", quant_args) self._qparams_initialized = True @@ -114,9 +117,9 @@ def initialize_hooked_attention( impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR) if quantize: - impl.initialize_qparams_once(module) + impl.initialize_qparams_once(model, module) - initialize_hooked_kv_cache(module, quantize=quantize) + initialize_hooked_kv_cache(model, module, quantize=quantize) # ----- hooks ----- # diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 93d788d41..008e38266 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Tuple import torch -from compressed_tensors.quantization import QuantizationScheme, forward_quantize +from compressed_tensors.quantization import QuantizationStrategy, forward_quantize from compressed_tensors.quantization.lifecycle.initialize import ( _initialize_scale_zero_point, ) @@ -24,7 +24,7 @@ from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.utils.hooks import RemovableHandle -from transformers import Cache +from transformers import Cache, PreTrainedModel __all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"] @@ -68,12 +68,14 @@ def forward( self.past_key_value = None return ret - def initialize_qparams_once(self, module: torch.nn.Module): + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): assert module is self.attn_module_container[0] scheme = getattr(module, "quantization_scheme", None) quant_args = getattr(scheme, "input_activations", None) if not self._qparams_initialized and quant_args is not None: + # TODO: use model.config.num_key_value_heads to find key_size, value_size + assert quant_args.strategy == QuantizationStrategy.TENSOR _initialize_scale_zero_point(module, "k", quant_args) _initialize_scale_zero_point(module, "v", quant_args) self._qparams_initialized = True @@ -82,14 +84,16 @@ def initialize_qparams_once(self, module: torch.nn.Module): # ----- initialize ----- # -def initialize_hooked_kv_cache(module: torch.nn.Module, quantize: bool = False): +def initialize_hooked_kv_cache( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False +): if not hasattr(module, KV_CACHE_ATTR): module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module)) module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) if quantize: - kv_cache.initialize_qparams_once(module) + kv_cache.initialize_qparams_once(model, module) def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 5f9b7d082..9843a2549 100644 --- 
a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -175,7 +175,7 @@ def query_hook(_, query_states): # other locations such as q_attn and k_attn have not been implemented elif args.location == TransformLocation.K_CACHE: - initialize_hooked_kv_cache(module, quantize=False) + initialize_hooked_kv_cache(model, module, quantize=False) def key_hook(_, key_states): return transform(key_states) From 2cfff73dbec458093ffda677cb7b6f6280905179 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 15:06:24 -0400 Subject: [PATCH 12/42] support registering to offloaded attention Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/kvcache.py | 19 +++++++++++++------ .../transform/factory/base.py | 1 + .../transform/factory/hadamard.py | 14 ++++++++++++-- .../transform/factory/matrix_multiply.py | 9 ++++++++- src/compressed_tensors/utils/offload.py | 16 +++++++++++----- 5 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 008e38266..965ffa208 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -16,12 +16,14 @@ from typing import Callable, Optional, Tuple import torch +import transformers from compressed_tensors.quantization import QuantizationStrategy, forward_quantize from compressed_tensors.quantization.lifecycle.initialize import ( _initialize_scale_zero_point, ) from compressed_tensors.utils import getattr_chain from compressed_tensors.utils.internal import InternalModule +from packaging import version from torch import Tensor from torch.utils.hooks import RemovableHandle from transformers import Cache, PreTrainedModel @@ -37,7 +39,7 @@ class QuantizedKVCache(InternalModule): def __init__(self, attn_module: torch.nn.Module): super().__init__() self.attn_module_container = [attn_module] # avoid nn.Module circular reference - self.past_key_value: Optional[Cache] = None + self.past_key_values: Optional[Cache] = None self._qparams_initialized = False def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: @@ -60,12 +62,12 @@ def forward( value_states = forward_quantize(module, value_states, "v", quant_args) # original cache - if self.past_key_value is not None: - ret = self.past_key_value.update(key_states, value_states, *args, **kwargs) + if self.past_key_values is not None: + ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) else: ret = (key_states, value_states) - self.past_key_value = None + self.past_key_values = None return ret def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): @@ -98,8 +100,13 @@ def initialize_hooked_kv_cache( def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) - kv_cache.past_key_value = kwargs.get("past_key_value", None) - kwargs["past_key_value"] = kv_cache + _past_kv_name = ( + "past_key_value" + if version.parse(transformers.__version__) <= version.parse("4.55.4") + else "past_key_values" # transformers#39956 + ) + kv_cache.past_key_values = kwargs.get(_past_kv_name, None) + kwargs[_past_kv_name] = kv_cache return args, kwargs diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 9843a2549..f612ded6a 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -125,6 +125,7 @@ 
def _apply_to_module( transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) self.transforms.append(transform) + register_offload_module(module, transform_name, transform) # register input transformation hook diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index e777c96fa..1e0baa7f2 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -22,8 +22,12 @@ apply_transform_weight, get_transform_size, ) -from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict +from compressed_tensors.utils.offload import ( + get_execution_device, + get_offloaded_device, + has_offloaded_params, +) from torch import Tensor, device, dtype from torch.nn import Module, Parameter @@ -53,9 +57,15 @@ def create_transform(self, module: Module, args: TransformArgs): """ size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = self.scheme.precision - device = get_offloaded_device(module) exec_device = get_execution_device(module) + # if the parent is offloaded, then weight will be placed in the weights_map + # if the parent is not offloaded, then the weight will stay on the exec device + if has_offloaded_params(module): + device = get_offloaded_device(module) + else: + device = exec_device + factory_kwargs = {"construct_device": exec_device} weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs) perm = self.perms[weight] if self.scheme.randomize else None diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index a7112e769..54bf7dcc1 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -53,7 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs): assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = self.scheme.precision - device = get_offloaded_device(module) + exec_device = get_execution_device(module) + + # if the parent is offloaded, then weight will be placed in the weights_map + # if the parent is not offloaded, then the weight will stay on the exec device + if has_offloaded_params(module): + device = get_offloaded_device(module) + else: + device = exec_device weight = self.weights[size, dtype, device] if args.inverse: diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index eb794915c..a6b79d9b6 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -127,11 +127,17 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device: :param module: module to check :return: device module is offloaded to onto after forward pass """ - if has_offloaded_params(module): - first_key = list(module._hf_hook.weights_map.keys())[0] - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device - return next(module.parameters()).device + for submodule in module.modules(): + name, param = next(submodule.named_parameters(recurse=False), (None, None)) + if has_offloaded_params(submodule) and name is not None: + return cast_to_device(submodule._hf_hook.weights_map[name].device) + + if param is not None: + assert param.device != 
torch.device("meta") + return param.device + + warnings.warn(f"Unable to get offload device of {module}, falling back to CPU") + return torch.device("cpu") @check_accelerate(fallback=None) From 53aa503253a208db5dbb16f81316fe08beb2ca1a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 18:30:55 -0400 Subject: [PATCH 13/42] better merging and serialization logic Signed-off-by: Kyle Sayers --- .../model_compressors/model_compressor.py | 9 +- .../quantization/lifecycle/apply.py | 112 +++--------- .../quantization/quant_config.py | 170 ++++++++---------- .../quantization/quant_scheme.py | 40 ++++- tests/test_quantization/test_quant_config.py | 4 +- 5 files changed, 151 insertions(+), 184 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 3d94348a0..be3842384 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -703,9 +703,12 @@ def decompress(self, model_path: str, model: Module): with override_quantization_status( self.quantization_config, QuantizationStatus.FROZEN ): - names_to_scheme = apply_quantization_config( - model, self.quantization_config - ) + apply_quantization_config(model, self.quantization_config) + names_to_scheme: Set[QuantizationScheme] = { + name: getattr(module, "quantization_scheme") + for name, module in model.named_modules() + if getattr(module, "quantization_scheme", None) is not None + } # Load activation scales/zp or any other quantization parameters # Conditionally load the weight quantization parameters if we have a dense compressor # Or if a sparsity compressor has already been applied diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 5dcb3013c..c8c5a4e1d 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -13,11 +13,8 @@ # limitations under the License. import logging -from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Iterable, List, Optional -from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch from compressed_tensors.config import CompressionFormat @@ -28,7 +25,6 @@ initialize_module_for_quantization, is_attention_module, ) -from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import ( QuantizationConfig, QuantizationStatus, @@ -37,7 +33,6 @@ from compressed_tensors.quantization.utils import ( ATTN_TARGETS, infer_quantization_status, - is_kv_cache_quant_scheme, ) from compressed_tensors.utils import ( deprecated, @@ -124,7 +119,7 @@ def apply_quantization_config( model: PreTrainedModel, config: Union[QuantizationConfig, None], run_compressed: bool = False, -) -> Dict[str, QuantizationScheme]: +): """ Initializes the model for quantization in-place based on the given config. 
Optionally coverts quantizable modules to compressed_linear modules @@ -134,32 +129,27 @@ def apply_quantization_config( :param run_compressed: Whether the model will be run in compressed mode or decompressed fully on load """ - # Workaround for when HF Quantizer passes None, see PR #180 - if config is None: - return dict() - config = deepcopy(config) + from compressed_tensors.linear.compressed_linear import CompressedLinear + from compressed_tensors.modeling.attention import initialize_hooked_attention - # build mapping of targets to schemes for easier matching - # use ordered dict to preserve target ordering in config - target_to_scheme = { - target: scheme - for scheme in config.config_groups.values() - for target in scheme.targets - } - names_to_scheme = dict() + config = deepcopy(config) + if config is None: # see PR #180 + return dict() - # preprocessing for kv cache, TODO: fix + # preprocess to support kv cache scheme config = process_quantization_config(config) - if run_compressed: - from compressed_tensors.linear.compressed_linear import CompressedLinear - # mark appropriate layers for quantization by setting their quantization schemes for scheme in config.config_groups.values(): for name, submodule in match_named_modules( model, scheme.targets, config.ignore or [], warn_on_fail=True ): - if isinstance(submodule, (torch.nn.Linear, torch.nn.Embedding)): + # attach scheme to module (with merging) + attach_scheme(submodule, scheme) + + # replace with run compressed if applicable + # FUTURE: move this to model compressor + if isinstance(submodule, torch.nn.Linear): if run_compressed: format = config.format if format != CompressionFormat.dense.value: @@ -171,43 +161,26 @@ def apply_quantization_config( ) replace_module(model, name, compressed_linear) - # attach scheme to module for later steps - _scheme = attach_scheme(submodule, scheme) - - elif is_attention_module(submodule) and is_narrow_match( + # replace attention implementation and kvcache with hookable modules + if is_attention_module(submodule) and is_narrow_match( model, scheme.targets, name ): - from compressed_tensors.modeling.attention import ( - initialize_hooked_attention, - ) - - # silently throw away weight and output quantization for attention - _scheme = QuantizationScheme( - targets=scheme.targets, - input_activations=scheme.input_activations, - format=scheme.format, - kv_cache_only=scheme.kv_cache_only, - ) - - # attach scheme to module for later steps - _scheme = attach_scheme(submodule, _scheme) - - # unlike linear, do initialization here + # unlike linear, do qparam initialization here initialize_hooked_attention(model, submodule, quantize=True) - else: - continue - - names_to_scheme[name] = _scheme - - # apply current quantization status across all targeted layers + # apply current quantization status across all targeted linear/embedding layers apply_quantization_status(model, config.quantization_status) - return names_to_scheme + + # attach for serialization + # do merginging using from_pretrained + if existing_config := getattr(model, "quantization_config", None): + config = config.merge(existing_config) + setattr(model, "quantization_config", config) def attach_scheme(module: Module, scheme: QuantizationScheme) -> QuantizationScheme: if existing_scheme := getattr(module, "quantization_scheme", None): - scheme = merge_schemes(existing_scheme, scheme) + scheme = scheme.merge(existing_scheme) setattr(module, "quantization_scheme", scheme) return scheme @@ -343,38 +316,3 @@ def _load_quant_args_from_mapping( 
state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}") update_parameter_data(module, state_dict_zp, zp_name) - - -def merge_schemes( - scheme_a: QuantizationScheme, scheme_b: QuantizationScheme -) -> QuantizationScheme: - def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: - if field_name == "targets": - return value_a + value_b - - if field_name == "kv_cache_only": - # nones defer to other value - if value_a is None: - return value_b - if value_b is None: - return value_a - - # kv_cache_only=True overrides - return not ((not value_a) or (not value_b)) - - raise ValueError( - "The following fields have overlapping targets and conflicting values for" - f"{field_name}. Please modify your config to resolve this ambiguity.\n" - f"{scheme_a}\n" - f"{scheme_b}" - ) - - dict_a = scheme_a.model_dump() - dict_b = scheme_b.model_dump() - - assert dict_a.keys() == dict_b.keys() - merged_dump = { - key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys() - } - - return QuantizationScheme.model_validate(merged_dump) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 42df3a337..6f74109fb 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -26,7 +26,7 @@ module_type, parse_out_kv_cache_args, ) -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from torch.nn import Module @@ -35,7 +35,6 @@ "QuantizationConfig", "LIFECYCLE_ORDER", "DEFAULT_QUANTIZATION_METHOD", - "DEFAULT_QUANTIZATION_FORMAT", ] @@ -58,16 +57,11 @@ class QuantizationStatus(str, Enum): FROZEN = "frozen" COMPRESSED = "compressed" - @classmethod - def lifecycle_order(cls) -> List["QuantizationStatus"]: - """ - :return: list of correct quantization lifecycle order - """ - return - def __ge__(self, other): if other is None: return True + if isinstance(other, str): + other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other) @@ -75,6 +69,8 @@ def __ge__(self, other): def __gt__(self, other): if other is None: return True + if isinstance(other, str): + other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other) @@ -82,6 +78,8 @@ def __gt__(self, other): def __lt__(self, other): if other is None: return False + if isinstance(other, str): + other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other) @@ -89,6 +87,8 @@ def __lt__(self, other): def __le__(self, other): if other is None: return False + if isinstance(other, str): + other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other) @@ -102,7 +102,6 @@ def __le__(self, other): ] DEFAULT_QUANTIZATION_METHOD = "compressed-tensors" -DEFAULT_QUANTIZATION_FORMAT = "fakequant" class QuantizationConfig(BaseModel): @@ -138,7 +137,7 @@ class QuantizationConfig(BaseModel): config_groups: Dict[str, Union[QuantizationScheme, List[str]]] quant_method: str = DEFAULT_QUANTIZATION_METHOD kv_cache_scheme: Optional[QuantizationArgs] = None - format: str = DEFAULT_QUANTIZATION_FORMAT + format: str = CompressionFormat.dense.value quantization_status: 
QuantizationStatus = QuantizationStatus.INITIALIZED global_compression_ratio: Optional[float] = None ignore: Optional[List[str]] = Field(default_factory=list) @@ -159,97 +158,86 @@ def model_post_init(self, __context): targets=targets_or_scheme, ) + @field_validator("format", mode="before") + def validate_format(cls, value: Any) -> str: + if value is None: + return CompressionFormat.dense.value + + if isinstance(value, list): + if len(value) == 0: + return CompressionFormat.dense.value + + if len(value) == 1: + assert isinstance(value[0], str) + return CompressionFormat(value[0]).value + + else: + return CompressionFormat.mixed_precision.value + + if isinstance(value, str): + return CompressionFormat(value).value + + return str(value) + def to_dict(self): # for compatibility with HFQuantizer return self.model_dump() - @staticmethod - def from_pretrained( - model: Module, format: Optional[str] = None - ) -> Optional["QuantizationConfig"]: - """ - Converts a model into its associated QuantizationConfig based on the - QuantizationScheme attached to each quantized module + def merge(self, other: "QuantizationConfig") -> "QuantizationConfig": + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "config_groups": + return value_a + value_b - :param model: model to calculate quantization scheme of - :return: filled out QuantizationScheme for the input model - """ - quant_scheme_to_layers = [] - quantization_status = None - ignore = {} - quantization_type_names = set() - for name, submodule in model.named_modules(): - layer_type = module_type(submodule) - if not is_module_quantized(submodule): - if layer_type not in ignore: - ignore[layer_type] = [] - ignore[layer_type].append(name) - else: - quantization_status = submodule.quantization_status - scheme = submodule.quantization_scheme - quantization_type_names.add(layer_type) - - match_found = False - for existing_scheme in quant_scheme_to_layers: - if scheme == existing_scheme: - match_found = True - break - if not match_found: - quant_scheme_to_layers.append(scheme) - - if len(quant_scheme_to_layers) == 0: # No quantized layers - return None - - # kv-cache only, no weight/activation quantization - if ( - len(quantization_type_names) == 1 - and "attention" in list(quantization_type_names)[0].lower() - ): - quantization_type_names.add("Linear") - - # clean up ignore list, we can leave out layers types if none of the - # instances are quantized - consolidated_ignore = [] - for layer_type, ignore_names in ignore.items(): - if layer_type in quantization_type_names: - # specific layers of a quantized type are ignored - consolidated_ignore += ignore_names - # else we leave it off the ignore list, doesn't fall under any of the - # existing quantization schemes so it won't be quantized - - kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( - quant_scheme_to_layers - ) - kv_cache_scheme = ( - kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args - ) + if field_name == "ignore": + if value_a is not None and value_b is None: + return value_a - config_groups = {} - for idx, scheme in enumerate(quant_scheme_to_layers): - group_name = "group_" + str(idx) - config_groups[group_name] = scheme + if value_a is None and value_b is not None: + return value_b - if format is None: - if quantization_status == QuantizationStatus.COMPRESSED: - format = CompressionFormat.int_quantized.value - else: - format = CompressionFormat.dense.value - elif isinstance(format, list): - format = ( - 
CompressionFormat.mixed_precision.value - if len(format) > 1 - else format[0] + if set(value_a) == set(value_b): + return value_a + + raise NotImplementedError( + "Cannot merge quantization configs with non-identical ignore lists " + "Please modify your config to resolve this ambiguity." + f"\n{self}\n{other}" + ) + + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if value_a == value_b: + return value_a + + raise ValueError( + "The following fields have overlapping targets and conflicting values " + f"for {field_name}. Please modify your config to resolve this " + f"ambiguity.\n{self}\n{other}" ) - return QuantizationConfig( - config_groups=config_groups, - quantization_status=quantization_status, - kv_cache_scheme=kv_cache_scheme, - global_compression_ratio=None, - format=format, - ignore=consolidated_ignore, + dict_a = self.model_dump() + dict_b = other.model_dump() + + assert dict_a.keys() == dict_b.keys() + return self.model_validate( + {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} ) + @classmethod + def from_pretrained( + cls, model: Module, format: Optional[str] = None + ) -> "QuantizationConfig": + default_config = QuantizationConfig(config_groups={}) + config = getattr(model, "quantization_config", default_config) + + # silently override format + config.format = cls.validate_format(format) + return config + def requires_calibration_data(self): if self.kv_cache_scheme is not None: return True diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 01bc97c0d..65d545e9e 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import List, Optional +from typing import Any, List, Optional from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import ( @@ -92,6 +92,44 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": return model + def merge(self, other: "QuantizationScheme") -> "QuantizationScheme": + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "targets": + return value_a + value_b + + if field_name == "kv_cache_only": + # nones defer to other value + if value_a is None: + return value_b + if value_b is None: + return value_a + + # kv_cache_only=True overrides + return not ((not value_a) or (not value_b)) + + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if value_a == value_b: + return value_a + + raise ValueError( + "The following fields have overlapping targets and conflicting values " + f"for {field_name}. Please modify your config to resolve this " + f"ambiguity.\n{self}\n{other}" + ) + + dict_a = self.model_dump() + dict_b = other.model_dump() + + assert dict_a.keys() == dict_b.keys() + return self.model_validate( + {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} + ) + model_config = ConfigDict(extra="forbid") diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index c3830a02d..4ced6ee60 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -13,8 +13,8 @@ # limitations under the License. 
import pytest +from compressed_tensors import CompressionFormat from compressed_tensors.quantization import ( - DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, QuantizationScheme, @@ -29,7 +29,7 @@ def test_basic_config(): assert config.config_groups == config_groups assert config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert config.format == DEFAULT_QUANTIZATION_FORMAT + assert config.format == CompressionFormat.dense assert config.quantization_status == QuantizationStatus.INITIALIZED assert config.global_compression_ratio is None assert isinstance(config.ignore, list) and len(config.ignore) == 0 From b47cda043680af5eab09008d3d259dcf84f6ac17 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 19:08:29 -0400 Subject: [PATCH 14/42] comment Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index c8c5a4e1d..a1c64b38e 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -165,7 +165,7 @@ def apply_quantization_config( if is_attention_module(submodule) and is_narrow_match( model, scheme.targets, name ): - # unlike linear, do qparam initialization here + # unlike linear, do qparam initialization (idempotent to reapplication) initialize_hooked_attention(model, submodule, quantize=True) # apply current quantization status across all targeted linear/embedding layers From c019e646e29709bc7b29ae980c4c7ce633173af3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 23:28:41 -0400 Subject: [PATCH 15/42] simplify hook replacement code Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 29 ++-------- src/compressed_tensors/modeling/kvcache.py | 60 ++++---------------- 2 files changed, 15 insertions(+), 74 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index 15e24a1ae..711f3952e 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -132,29 +132,10 @@ def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHan """ impl = getattr(module, IMPL_ATTR) - def _hook(mod: torch.nn.Module, args, kwargs): - # Keyword case - if "query" in kwargs: - kwargs["query"] = hook(mod, kwargs["query"]) - return args, kwargs - - # Positional case: find the index of `query` in impl.forward - sig = inspect.signature(mod.forward) - param_names = tuple(sig.parameters.keys()) - try: - idx = param_names.index("query") - except ValueError: - # No `query` parameter; nothing to do - return args, kwargs - - if idx < len(args): - args = list(args) - ret = hook(module, args[idx]) - if ret is not None: - args[idx] = ret - return tuple(args), kwargs - - # Not present explicitly (maybe defaulted) - return args, kwargs + def _hook(cache: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + bound.arguments["query"] = hook(cache, bound.arguments["query"]) + + return bound.args, bound return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 965ffa208..85539b715 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ 
b/src/compressed_tensors/modeling/kvcache.py @@ -117,31 +117,11 @@ def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) - def _hook(mod: torch.nn.Module, args, kwargs): - # If passed as keyword, this is easy. - if "key_states" in kwargs: - kwargs["key_states"] = hook(mod, kwargs["key_states"]) - return args, kwargs - - # Otherwise, find where key_states would be in positional args. - sig = inspect.signature(mod.forward) - param_names = tuple(sig.parameters.keys()) - try: - idx = param_names.index("key_states") - except ValueError: - # forward has no key_states parameter; do nothing - return args, kwargs - - # If the position exists in args, replace it. - if idx < len(args): - args = list(args) - ret = hook(module, args[idx]) - if ret is not None: - args[idx] = ret - return tuple(args), kwargs - - # Not present positionally and not in kwargs (maybe defaulted) — do nothing. - return args, kwargs + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + bound.arguments["key_states"] = hook(cache, bound.arguments["key_states"]) + + return bound.args, bound return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) @@ -149,30 +129,10 @@ def _hook(mod: torch.nn.Module, args, kwargs): def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) - def _hook(mod: torch.nn.Module, args, kwargs): - # If passed as keyword, this is easy. - if "value_states" in kwargs: - kwargs["value_states"] = hook(mod, kwargs["value_states"]) - return args, kwargs - - # Otherwise, find where value_states would be in positional args. - sig = inspect.signature(mod.forward) - param_names = tuple(sig.parameters.keys()) - try: - idx = param_names.index("value_states") - except ValueError: - # forward has no value_states parameter; do nothing - return args, kwargs - - # If the position exists in args, replace it. - if idx < len(args): - args = list(args) - ret = hook(module, args[idx]) - if ret is not None: - args[idx] = ret - return tuple(args), kwargs - - # Not present positionally and not in kwargs (maybe defaulted) — do nothing. 
- return args, kwargs + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + bound.arguments["value_states"] = hook(cache, bound.arguments["value_states"]) + + return bound.args, bound return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) From ceaf67708a8e54dfeac3d6d7fd86f23bff055d0b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 25 Aug 2025 23:29:11 -0400 Subject: [PATCH 16/42] fix typo Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index 711f3952e..c81aef1dc 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -132,9 +132,9 @@ def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHan """ impl = getattr(module, IMPL_ATTR) - def _hook(cache: QuantizedAttentionImpl, args, kwargs): - bound = inspect.signature(cache.forward).bind(*args, **kwargs) - bound.arguments["query"] = hook(cache, bound.arguments["query"]) + def _hook(impl: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(impl.forward).bind(*args, **kwargs) + bound.arguments["query"] = hook(impl, bound.arguments["query"]) return bound.args, bound From 29de3ec1505192a2195309cab63721b94d2bd940 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Aug 2025 02:00:08 -0400 Subject: [PATCH 17/42] do not attach scheme if not targeted Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index a1c64b38e..ff47348be 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -161,12 +161,14 @@ def apply_quantization_config( ) replace_module(model, name, compressed_linear) - # replace attention implementation and kvcache with hookable modules - if is_attention_module(submodule) and is_narrow_match( - model, scheme.targets, name - ): - # unlike linear, do qparam initialization (idempotent to reapplication) - initialize_hooked_attention(model, submodule, quantize=True) + # attention quantization and/or kv cache quantization + if is_attention_module(submodule): + if is_narrow_match(model, scheme.targets, name): + # unlike linear, do qparam initialization here (once) + initialize_hooked_attention(model, submodule, quantize=True) + else: + # do not quantize attention unless specifically targeted + delattr(submodule, "quantization_scheme") # apply current quantization status across all targeted linear/embedding layers apply_quantization_status(model, config.quantization_status) From e904883e8ef1bf94c20819d1a23c09784dc0ef40 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Aug 2025 13:30:34 -0400 Subject: [PATCH 18/42] revert format changes Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 13 +++-- .../quantization/quant_config.py | 45 +++++---------- .../transform/factory/base.py | 1 - src/compressed_tensors/utils/offload.py | 1 - .../test_model_compressor.py | 8 ++- .../test_quantization/lifecycle/test_apply.py | 56 ++----------------- tests/test_quantization/test_quant_config.py | 3 +- 7 files changed, 34 insertions(+), 93 deletions(-) diff --git 
a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index ff47348be..663347af3 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -173,11 +173,8 @@ def apply_quantization_config( # apply current quantization status across all targeted linear/embedding layers apply_quantization_status(model, config.quantization_status) - # attach for serialization - # do merginging using from_pretrained - if existing_config := getattr(model, "quantization_config", None): - config = config.merge(existing_config) - setattr(model, "quantization_config", config) + # attach config for serialization + attach_config(model, config) def attach_scheme(module: Module, scheme: QuantizationScheme) -> QuantizationScheme: @@ -188,6 +185,12 @@ def attach_scheme(module: Module, scheme: QuantizationScheme) -> QuantizationSch return scheme +def attach_config(model: PreTrainedModel, config: QuantizationConfig): + if existing_config := getattr(model, "quantization_config", None): + config = config.merge(existing_config) + setattr(model, "quantization_config", config) + + def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: """ Preprocess the raw QuantizationConfig diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 6f74109fb..ecdc8ae21 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -35,6 +35,7 @@ "QuantizationConfig", "LIFECYCLE_ORDER", "DEFAULT_QUANTIZATION_METHOD", + "DEFAULT_QUANTIZATION_FORMAT", ] @@ -60,8 +61,6 @@ class QuantizationStatus(str, Enum): def __ge__(self, other): if other is None: return True - if isinstance(other, str): - other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other) @@ -69,8 +68,6 @@ def __ge__(self, other): def __gt__(self, other): if other is None: return True - if isinstance(other, str): - other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other) @@ -78,8 +75,6 @@ def __gt__(self, other): def __lt__(self, other): if other is None: return False - if isinstance(other, str): - other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other) @@ -87,8 +82,6 @@ def __lt__(self, other): def __le__(self, other): if other is None: return False - if isinstance(other, str): - other = self.__class__(other) if not isinstance(other, self.__class__): raise NotImplementedError return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other) @@ -102,6 +95,7 @@ def __le__(self, other): ] DEFAULT_QUANTIZATION_METHOD = "compressed-tensors" +DEFAULT_QUANTIZATION_FORMAT = "fakequant" # TODO: remove class QuantizationConfig(BaseModel): @@ -137,7 +131,7 @@ class QuantizationConfig(BaseModel): config_groups: Dict[str, Union[QuantizationScheme, List[str]]] quant_method: str = DEFAULT_QUANTIZATION_METHOD kv_cache_scheme: Optional[QuantizationArgs] = None - format: str = CompressionFormat.dense.value + format: str = DEFAULT_QUANTIZATION_FORMAT quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED global_compression_ratio: Optional[float] = None ignore: Optional[List[str]] = 
Field(default_factory=list) @@ -158,27 +152,6 @@ def model_post_init(self, __context): targets=targets_or_scheme, ) - @field_validator("format", mode="before") - def validate_format(cls, value: Any) -> str: - if value is None: - return CompressionFormat.dense.value - - if isinstance(value, list): - if len(value) == 0: - return CompressionFormat.dense.value - - if len(value) == 1: - assert isinstance(value[0], str) - return CompressionFormat(value[0]).value - - else: - return CompressionFormat.mixed_precision.value - - if isinstance(value, str): - return CompressionFormat(value).value - - return str(value) - def to_dict(self): # for compatibility with HFQuantizer return self.model_dump() @@ -186,7 +159,7 @@ def to_dict(self): def merge(self, other: "QuantizationConfig") -> "QuantizationConfig": def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: if field_name == "config_groups": - return value_a + value_b + return value_a | value_b if field_name == "ignore": if value_a is not None and value_b is None: @@ -235,7 +208,15 @@ def from_pretrained( config = getattr(model, "quantization_config", default_config) # silently override format - config.format = cls.validate_format(format) + if isinstance(format, list): + format = ( + CompressionFormat.mixed_precision.value + if len(format) > 1 + else format[0] + ) + if format is None: + format = CompressionFormat.dense.value + config.format = format return config def requires_calibration_data(self): diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index f612ded6a..9843a2549 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -125,7 +125,6 @@ def _apply_to_module( transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) self.transforms.append(transform) - register_offload_module(module, transform_name, transform) # register input transformation hook diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index a6b79d9b6..561705c89 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -133,7 +133,6 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device: return cast_to_device(submodule._hf_hook.weights_map[name].device) if param is not None: - assert param.device != torch.device("meta") return param.device warnings.warn(f"Unable to get offload device of {module}, falling back to CPU") diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index dc48870b3..9680d40ee 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -399,7 +399,8 @@ def _get_combined_config(s_config, q_config): ) def test_compress_model(model_stub, q_format, s_config, tmpdir): model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32) - compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format]) + qformats = None if q_format is None else [q_format] # FUTURE: remove nullability + compressor = ModelCompressor.from_pretrained_model(model, s_config, qformats) # compress model by eagerly compressing state dict true_compressed = dict(compressor.compress(model)) @@ -446,8 +447,9 @@ def test_compress_model_meta(model_stub, q_format, s_config): cpu_model = 
AutoModelForCausalLM.from_pretrained( model_stub, torch_dtype=torch.float32 ) + qformats = None if q_format is None else [q_format] # FUTURE: remove nullability reference_compressor = ModelCompressor.from_pretrained_model( - cpu_model, s_config, [q_format] + cpu_model, s_config, qformats ) # Only stores dtype because meta model does not store values expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()} @@ -463,7 +465,7 @@ def test_compress_model_meta(model_stub, q_format, s_config): module.to_empty(device="meta") # Compress in-place on meta model - compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format]) + compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, qformats) compressor.compress_model(meta_model) # Compare keys and dtypes diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 09e12cc5a..ecd219404 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -57,53 +57,6 @@ def llama_stories_model(): ) -def test_target_prioritization(mock_frozen): - # tests that the config_groups are applied in the correct order - # of priority, where exact layer name > regex > module name - config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "config_groups": { - "group_1": { - "weights": { - "num_bits": 8, - }, - "targets": ["Linear"], - }, - "group_2": { - "weights": { - "num_bits": 4, - }, - "targets": ["re:.*down_proj"], - }, - "group_3": { - "weights": { - "num_bits": 2, - }, - "targets": ["model.layers.0.mlp.down_proj"], - }, - }, - } - - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto" - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - mock_frozen(model) - - for name, module in model.named_modules(): - if name == "model.layers.0.mlp.down_proj": - assert module.quantization_scheme.weights.num_bits == 2 - elif re.match(".*down_proj", name): - assert module.quantization_scheme.weights.num_bits == 4 - elif isinstance(module, torch.nn.Linear): - assert module.quantization_scheme.weights.num_bits == 8 - - def test_apply_quantization_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config(status="calibration") model = get_tinyllama_model() @@ -174,13 +127,16 @@ def test_serialize_config_tinyllama(): serialized_config = QuantizationConfig.from_pretrained(model) assert len(serialized_config.config_groups) == 2 - assert serialized_config.config_groups["group_0"].targets == ["Embedding"] - assert serialized_config.config_groups["group_0"].input_activations is None assert serialized_config.config_groups["group_1"].targets == ["Linear"] assert serialized_config.config_groups["group_1"].input_activations is not None + assert serialized_config.config_groups["group_2"].targets == ["Embedding"] + assert serialized_config.config_groups["group_2"].input_activations is None assert serialized_config.format == CompressionFormat.dense.value assert serialized_config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] + assert serialized_config.ignore == [ + "LlamaRotaryEmbedding", + "model.layers.1.mlp.down_proj", + ] if serialized_config.global_compression_ratio is not None: assert serialized_config.global_compression_ratio > 1.0 assert 
serialized_config.global_compression_ratio < 8.0 diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index 4ced6ee60..cbe07183a 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -15,6 +15,7 @@ import pytest from compressed_tensors import CompressionFormat from compressed_tensors.quantization import ( + DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, QuantizationScheme, @@ -29,7 +30,7 @@ def test_basic_config(): assert config.config_groups == config_groups assert config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert config.format == CompressionFormat.dense + assert config.format == DEFAULT_QUANTIZATION_FORMAT assert config.quantization_status == QuantizationStatus.INITIALIZED assert config.global_compression_ratio is None assert isinstance(config.ignore, list) and len(config.ignore) == 0 From 6f91dd6285b5d8c2a3792929ed33efa043a1c0de Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 26 Aug 2025 14:02:21 -0400 Subject: [PATCH 19/42] fix typo Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 8 +++++--- src/compressed_tensors/modeling/kvcache.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index c81aef1dc..1d86e78b4 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -133,9 +133,11 @@ def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHan impl = getattr(module, IMPL_ATTR) def _hook(impl: QuantizedAttentionImpl, args, kwargs): - bound = inspect.signature(impl.forward).bind(*args, **kwargs) - bound.arguments["query"] = hook(impl, bound.arguments["query"]) + bound = inspect.signature(module.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["query"]) + if value is not None: + bound.arguments["query"] = value - return bound.args, bound + return bound.args, bound.kwargs return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 85539b715..f26fca084 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -119,9 +119,11 @@ def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandl def _hook(cache: QuantizedKVCache, args, kwargs): bound = inspect.signature(cache.forward).bind(*args, **kwargs) - bound.arguments["key_states"] = hook(cache, bound.arguments["key_states"]) + value = hook(module, bound.arguments["key_states"]) + if value is not None: + bound.arguments["key_states"] = value - return bound.args, bound + return bound.args, bound.kwargs return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) @@ -131,8 +133,10 @@ def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHan def _hook(cache: QuantizedKVCache, args, kwargs): bound = inspect.signature(cache.forward).bind(*args, **kwargs) - bound.arguments["value_states"] = hook(cache, bound.arguments["value_states"]) + value = hook(module, bound.arguments["value_states"]) + if value is not None: + bound.arguments["value_states"] = value - return bound.args, bound + return bound.args, bound.kwargs return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) From 154d2e485f4b647e2e444c670ba209f944e57e54 Mon Sep 17 00:00:00 
2001 From: Kyle Sayers Date: Thu, 11 Sep 2025 15:02:49 -0400 Subject: [PATCH 20/42] deprecate safe permute Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/forward.py | 6 +-- src/compressed_tensors/utils/permute.py | 43 ++---------------- .../lifecycle/test_helpers.py | 44 ++++++++++--------- 3 files changed, 30 insertions(+), 63 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 58a16dfba..2e539b070 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -29,7 +29,6 @@ calculate_range, compute_dynamic_scales_and_zp, ) -from compressed_tensors.utils import safe_permute from torch.nn import Module @@ -294,7 +293,7 @@ def _process_quantization( group_sizes = group_sizes[torch.argsort(group_indices)] perm = torch.argsort(g_idx) - x = safe_permute(x, perm, dim=1) + x = x.index_select(-1, perm) # Maintain all dimensions except the last dim, which is divided by group_size reshaped_dims = ( @@ -328,7 +327,8 @@ def _process_quantization( output = output.to(output_dtype) if not is_column_order: - output = safe_permute(output, torch.argsort(perm), dim=1) + inv_perm = torch.argsort(perm) + output = output.index_select(-1, inv_perm) else: # covers channel, token and tensor strategies if do_quantize: diff --git a/src/compressed_tensors/utils/permute.py b/src/compressed_tensors/utils/permute.py index e31d4862b..86a0ee805 100644 --- a/src/compressed_tensors/utils/permute.py +++ b/src/compressed_tensors/utils/permute.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Set, Tuple - import torch +from compressed_tensors.utils.helpers import deprecated __all__ = ["safe_permute"] -# these datatypes are missing implementations required for standard permutation -_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set() - - +@deprecated("Tensor.index_select") def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor: """ Perform out-of-place permutation without using torch.Tensor.index_put_, @@ -34,37 +30,4 @@ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch :param dim: dimension along which to apply permutation :return: permuted value """ - dtype_tuple = (value.dtype, value.device) - - if dtype_tuple in _EXPERIMENTAL_DTYPES: - return _fallback_permute(value, perm, dim) - - try: - return value[tuple([slice(None)] * dim + [perm])] - except RuntimeError: - # Mark dtype as experimental if advanced indexing fails - _EXPERIMENTAL_DTYPES.add(dtype_tuple) - return _fallback_permute(value, perm, dim) - - -def _fallback_permute( - value: torch.Tensor, perm: torch.Tensor, dim: int -) -> torch.Tensor: - """ - Fallback permutation method for experimental dtypes. - - :param value: tensor to permute - :param perm: permutation map - :param dim: dimension along which to apply permutation - :return: permuted value - """ - value_ret = value.clone() # cannot use zeros_like b/c of missing impl. 
- orig_slices = [slice(None)] * (dim + 1) - perm_slices = [slice(None)] * (dim + 1) - - for index, perm_index in enumerate(perm): - orig_slices[dim] = index - perm_slices[dim] = perm_index - value_ret[tuple(orig_slices)] = value[tuple(perm_slices)] - - return value_ret + return value.index_select(dim, perm) diff --git a/tests/test_quantization/lifecycle/test_helpers.py b/tests/test_quantization/lifecycle/test_helpers.py index 08d916544..20fd39da4 100644 --- a/tests/test_quantization/lifecycle/test_helpers.py +++ b/tests/test_quantization/lifecycle/test_helpers.py @@ -15,31 +15,35 @@ import pytest import torch -from compressed_tensors.utils import safe_permute -from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES +from compressed_tensors.utils.permute import safe_permute +from tests.testing_utils import requires_gpu +@requires_gpu +@pytest.mark.unit +@pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.parametrize( - "dtype,device,exp_experimental", + "dtype", [ - (torch.int8, torch.device("cpu"), False), - (torch.int16, torch.device("cpu"), False), - (torch.int32, torch.device("cpu"), False), - (torch.int64, torch.device("cpu"), False), - (torch.float16, torch.device("cpu"), False), - (torch.float32, torch.device("cpu"), False), - (torch.float64, torch.device("cpu"), False), - (torch.float8_e4m3fn, torch.device("cpu"), True), + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.bfloat16, + torch.float16, + torch.float32, + torch.float64, + torch.float8_e4m3fn, ], ) -def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool): - # some dtypes do not support arange initialization - tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device) - perm = torch.tensor([3, 1, 0, 2]) - expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device) +@pytest.mark.parametrize( + "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")] +) +def test_safe_permute(dtype: torch.dtype, device: torch.device): + value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device) + perm = torch.tensor([3, 1, 0, 2], device=device) - result = safe_permute(tensor, perm, dim=0) + result = safe_permute(value, perm, dim=-1) - if exp_experimental: - assert (dtype, device) in _EXPERIMENTAL_DTYPES - assert all(result == expected) + if device.type != "meta": + assert torch.equal(result.squeeze(0), perm.to(result.dtype)) From ed8f5dcbde96bc510047315bfb49ca218cd8a548 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 15:49:02 -0400 Subject: [PATCH 21/42] meta hadamards Signed-off-by: Kyle Sayers --- .../transform/factory/base.py | 36 ++----------------- .../transform/utils/hadamard.py | 7 ++-- 2 files changed, 7 insertions(+), 36 deletions(-) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 94e6b4a42..e0a6978fe 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,8 +13,7 @@ # limitations under the License. 
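For reference only (not part of the patch): the index_select substitution above behaves the same as the deprecated safe_permute on the g_idx path, reordering columns by perm and undoing it with torch.argsort(perm). A short round trip:

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
perm = torch.tensor([2, 0, 3, 1])

permuted = x.index_select(-1, perm)                         # reorder columns by perm
restored = permuted.index_select(-1, torch.argsort(perm))   # apply the inverse permutation

assert torch.equal(restored, x)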
from abc import ABC, abstractmethod -from collections import defaultdict -from typing import List, Optional, Set, Tuple +from typing import List, Optional, Set import torch import torch.nn.utils.parametrize as P @@ -101,8 +100,6 @@ def apply_to_model(self, model: Module, use_tqdm=True): for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): self._apply_to_module(module, arg) - self._update_tied_weights() - def _apply_to_module(self, module: Module, args: TransformArgs): """ Create transforms and apply them to the module @@ -165,31 +162,6 @@ def output_hook(_, _input, output): else: raise NotImplementedError() - def _update_tied_weights(self): - """ - Populate the `_dynamic_tied_weights_keys` attribute of transforms, - which is used by transformers to detect and remove shared pointers - during saving - """ - # map from data_ptrs to keys - ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list) - for transform in self.transforms: - for name, param in transform.named_parameters(recurse=False): - # NOTE: previously asserted that parent._hf_hook.place_submodules=False - if has_offloaded_params(transform): - param = transform._hf_hook.weights_map[name] - ptr_to_keys[param.data_ptr()].append((transform, name)) - - # populate `_dynamic_tied_weights_keys` if there is more than one key - # and ensure that they share tensors - for shared_keys in ptr_to_keys.values(): - if len(shared_keys) > 1: - tensor = getattr(shared_keys[0][0], shared_keys[0][1]) - - for transform, name in shared_keys: - transform._dynamic_tied_weights_keys.add(name) - setattr(transform, name, tensor) - class TransformBase(InternalModule, ABC): """ @@ -198,11 +170,7 @@ class TransformBase(InternalModule, ABC): args: TransformArgs weight: Parameter - _dynamic_tied_weights_keys: Set[str] - - def __init__(self): - super().__init__() - self._dynamic_tied_weights_keys = set() + _dynamic_tied_weights_keys: List[str] = ["weight"] @abstractmethod def forward(self, value: Tensor) -> Tensor: diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 7d361e59d..c8144ae26 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -115,13 +115,16 @@ def _fetch_hadamard_divisor( than forcing callers to manage the file open context :param n: size of known hadamard matrix + :param dtype: data type to move fetched hadamard to + :param device: device to move fetched hadamard to :return: a known hadamard matrix of size `n` if one exists, else None """ - with safe_open(file_path, framework="pt", device=str(device)) as file: + open_device = torch.device("cpu") if device.type == "meta" else device + with safe_open(file_path, framework="pt", device=str(open_device)) as file: divisors = sorted((int(key) for key in file.keys()), reverse=True) for divisor in divisors: if n % divisor == 0 and is_pow2(n // divisor): - return file.get_tensor(str(divisor)).to(dtype=dtype) + return file.get_tensor(str(divisor)).to(dtype=dtype, device=device) return None From dfdbd3f0a719b40efc159a8834865078cea7cc8e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 16:02:11 -0400 Subject: [PATCH 22/42] fix dynamic weights keys Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/hadamard.py | 4 +++- tests/test_transform/factory/test_serialization.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/transform/factory/hadamard.py 
b/src/compressed_tensors/transform/factory/hadamard.py index de6e284bb..a843e2728 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import List, Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -84,6 +84,8 @@ def _create_permutation(self, weight: Parameter) -> Parameter: class HadamardTransform(TransformBase): + _dynamic_tied_weights_keys: List[str] = ["weight", "perm"] + def __init__( self, weight: Parameter, diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py index a688c2cf1..7a9ba645b 100644 --- a/tests/test_transform/factory/test_serialization.py +++ b/tests/test_transform/factory/test_serialization.py @@ -44,8 +44,6 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False): @pytest.mark.skip(reason="Requires changes in upstream transformers") -# https://github.com/huggingface/transformers/pull/39280 -# https://github.com/huggingface/transformers/pull/39263 @requires_gpu @requires_accelerate() @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) From 7aec12c7d67b088df74f14cdacfd1e45de32ac71 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 17:59:29 -0400 Subject: [PATCH 23/42] break out _tie_offloaded_tensors, add test Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/apply.py | 36 +++++++++++++ .../transform/factory/base.py | 4 +- tests/test_transform/conftest.py | 2 + .../factory/test_serialization.py | 52 +++++++++++++++++-- 4 files changed, 88 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index e247e7029..39c323b4e 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import defaultdict +from typing import List, Tuple + import torch from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -34,3 +37,36 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): # attach config to model for compression/serialization setattr(model, TRANSFORM_CONFIG_NAME, config) + + # ensure that tied weight transforms can be serialized without aliases + # In the future, this could be done by transformers or model compressor + # which would make this more robust to changing dispatches after transforms + _tie_offloaded_tensors(model) + + +def _tie_offloaded_tensors(model: torch.nn.Module): + """ + Populate the `_dynamic_tied_weights_keys` attribute of transforms, + which is used by transformers to detect and remove shared pointers + during saving + """ + from compressed_tensors.utils import has_offloaded_params + + # map from to keys + offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list) + for module in model.modules(): + # NOTE: previously asserted that parent._hf_hook.place_submodules=False + if has_offloaded_params(module): + for key, _ in module.named_parameters(recurse=False): + param = module._hf_hook.weights_map[key] + offloaded_ptrs[id(param)].append((module, key)) + + # populate `_dynamic_tied_weights_keys` if there is more than one key + # and ensure that they share tensors. In the case of offloading, this + for shared_keys in offloaded_ptrs.values(): + if len(shared_keys) > 1: + first_tensor = getattr(shared_keys[0][0], shared_keys[0][1]) + assert first_tensor.device.type == "meta" + for module, key in shared_keys: + assert getattr(module, key).device.type == "meta" + setattr(module, key, first_tensor) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index e0a6978fe..34d609e74 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,7 +13,7 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import List, Optional, Set +from typing import List, Optional import torch import torch.nn.utils.parametrize as P @@ -56,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non self.name = name self.scheme = scheme self.generator = torch.Generator() - self.transforms = list() if seed is not None: self.generator.manual_seed(seed) @@ -117,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) - self.transforms.append(transform) register_offload_module(module, transform_name, transform) # register input transformation hook diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index a0188c429..824c06bd3 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -19,6 +19,8 @@ class TransformableModel(PreTrainedModel): + config_class = PretrainedConfig + def __init__(self, *sizes): super().__init__(config=PretrainedConfig()) self.fcs = torch.nn.ModuleList( diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py index 7a9ba645b..15fa240ba 100644 --- a/tests/test_transform/factory/test_serialization.py +++ b/tests/test_transform/factory/test_serialization.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch from compressed_tensors.transform import ( @@ -20,7 +22,9 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch +from safetensors import safe_open from tests.testing_utils import requires_accelerate, requires_gpu +from transformers import AutoModelForCausalLM, AutoTokenizer @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @@ -38,15 +42,57 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False): apply_transform_config(model, config) # save model - model.save_pretrained(tmp_path) + model_path = os.path.join(tmp_path, "test_model_path") + model.save_pretrained(model_path) + + # check that saved values match model values + # note that shared weights are only serialized once + safetensors_path = os.path.join(model_path, "model.safetensors") + with safe_open(safetensors_path, framework="pt", device="cpu") as file: + saved_keys = set(file.keys()) + assert { + "fcs.0.weight", + "fcs.1.weight", + "fcs.2.weight", + "fcs.3.weight", + "fcs.4.weight", + } <= saved_keys + for key in saved_keys: + param = model.get_parameter(key) + saved_param = file.get_tensor(key) - # TODO: reload model + if param.device.type != "meta": # skip testing values in offload case + assert torch.equal(param, saved_param) -@pytest.mark.skip(reason="Requires changes in upstream transformers") @requires_gpu @requires_accelerate() @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @pytest.mark.parametrize("randomize", (True, False)) def test_serialization_offload(type, randomize, model_apply, tmp_path): test_serialization(type, randomize, model_apply, tmp_path, offload=True) + + +@pytest.mark.skip("Requires transformers#40673") +@requires_gpu +@pytest.mark.parametrize( + "model_stub,exp_perplexity", + [ + ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0), + ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0), + ], +) +def test_load_perplexity(model_stub, exp_perplexity): + 
model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(model_stub) + + prompt = "The capital of France is Paris, the capital of Germany is Berlin" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {key: value.to(model.device) for key, value in inputs.items()} + labels = inputs["input_ids"] + + with torch.no_grad(): + outputs = model(**inputs, labels=labels) + + perplexity = torch.exp(outputs.loss) + assert perplexity <= exp_perplexity From 2bf33c8e62b75f279af6de273f54dc079250f08f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 18:11:18 -0400 Subject: [PATCH 24/42] better comments Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/apply.py | 33 ++++++++++++++--------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 39c323b4e..0ebe96ec4 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -16,6 +16,7 @@ from typing import List, Tuple import torch +from accelerate.utils import has_offloaded_params from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -46,13 +47,19 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): def _tie_offloaded_tensors(model: torch.nn.Module): """ - Populate the `_dynamic_tied_weights_keys` attribute of transforms, - which is used by transformers to detect and remove shared pointers - during saving + When accelerate replaces tensors with meta tensors during offloading, the meta + tensors may not be identical, even if the offloaded values are identical. + + However, transformers can only serialize correctly if meta tensors are identical + (see transformers#39263). + + This function collects all meta tensors which have shared offloaded values and sets + those tensors to be identical so that they can be removed during serialization + + :param model: model potentially containing offloaded meta tensors to fix """ - from compressed_tensors.utils import has_offloaded_params - # map from to keys + # map from offloaded tensor pointers to module-key locations offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list) for module in model.modules(): # NOTE: previously asserted that parent._hf_hook.place_submodules=False @@ -61,12 +68,12 @@ def _tie_offloaded_tensors(model: torch.nn.Module): param = module._hf_hook.weights_map[key] offloaded_ptrs[id(param)].append((module, key)) - # populate `_dynamic_tied_weights_keys` if there is more than one key - # and ensure that they share tensors. 
In the case of offloading, this + # ensure that if a location shares an offloaded tensor pointers, that the + # meta tensor is also identical (assigned to the first element of the set) for shared_keys in offloaded_ptrs.values(): - if len(shared_keys) > 1: - first_tensor = getattr(shared_keys[0][0], shared_keys[0][1]) - assert first_tensor.device.type == "meta" - for module, key in shared_keys: - assert getattr(module, key).device.type == "meta" - setattr(module, key, first_tensor) + assert len(shared_keys) >= 1 + first_tensor = getattr(shared_keys[0][0], shared_keys[0][1]) + assert first_tensor.device.type == "meta" + for module, key in shared_keys: + assert getattr(module, key).device.type == "meta" + setattr(module, key, first_tensor) From 33b71b30b475a858108f2a80b1da4c1ad53a7e33 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 18:12:11 -0400 Subject: [PATCH 25/42] better comments Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/apply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 0ebe96ec4..99bcaf760 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -62,7 +62,6 @@ def _tie_offloaded_tensors(model: torch.nn.Module): # map from offloaded tensor pointers to module-key locations offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list) for module in model.modules(): - # NOTE: previously asserted that parent._hf_hook.place_submodules=False if has_offloaded_params(module): for key, _ in module.named_parameters(recurse=False): param = module._hf_hook.weights_map[key] From 2ef1ab2731e0c312e2365b65c5746e9e4ab608a3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 11 Sep 2025 14:41:22 -0400 Subject: [PATCH 26/42] simplify function Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/apply.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 99bcaf760..035ca25dd 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -13,7 +13,7 @@ # limitations under the License. 
from collections import defaultdict -from typing import List, Tuple +from typing import Dict, List, Tuple import torch from accelerate.utils import has_offloaded_params @@ -59,20 +59,14 @@ def _tie_offloaded_tensors(model: torch.nn.Module): :param model: model potentially containing offloaded meta tensors to fix """ - # map from offloaded tensor pointers to module-key locations - offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list) + # ensure that if a location shares an offloaded tensor pointers, that the + # meta tensor is also identical (assigned to the first instance of parameter) + ptr_to_meta: Dict[int, torch.nn.Parameter] = dict() for module in model.modules(): if has_offloaded_params(module): for key, _ in module.named_parameters(recurse=False): - param = module._hf_hook.weights_map[key] - offloaded_ptrs[id(param)].append((module, key)) + offloaded_ptr = module._hf_hook.weights_map[key].data_ptr() - # ensure that if a location shares an offloaded tensor pointers, that the - # meta tensor is also identical (assigned to the first element of the set) - for shared_keys in offloaded_ptrs.values(): - assert len(shared_keys) >= 1 - first_tensor = getattr(shared_keys[0][0], shared_keys[0][1]) - assert first_tensor.device.type == "meta" - for module, key in shared_keys: - assert getattr(module, key).device.type == "meta" - setattr(module, key, first_tensor) + if offloaded_ptr not in ptr_to_meta: + ptr_to_meta[offloaded_ptr] = getattr(module, key) + setattr(module, key, ptr_to_meta[offloaded_ptr]) From a11770a9478eeaa18622cbb1a56648cd092a5441 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 11 Sep 2025 14:42:33 -0400 Subject: [PATCH 27/42] style Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/apply.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 035ca25dd..28d5e94fc 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
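Aside, not part of the patch: the net effect of _tie_offloaded_tensors is easier to see on plain dictionaries. Whenever two parameters resolve to the same offloaded storage, their meta placeholders are pointed at a single tensor object so the alias can be dropped at save time. A rough standalone sketch of that dedup-by-pointer idea, with plain dicts standing in for accelerate's weights_map:

from typing import Dict

import torch

shared_cpu = torch.ones(4)                                    # one offloaded value...
offloaded = {"a.weight": shared_cpu, "b.weight": shared_cpu}  # ...referenced by two keys
meta = {key: torch.empty_like(val, device="meta") for key, val in offloaded.items()}

ptr_to_meta: Dict[int, torch.Tensor] = {}
for key, value in offloaded.items():
    # first occurrence wins; later keys reuse the same meta tensor object
    ptr_to_meta.setdefault(value.data_ptr(), meta[key])
    meta[key] = ptr_to_meta[value.data_ptr()]

assert meta["a.weight"] is meta["b.weight"]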
-from collections import defaultdict -from typing import Dict, List, Tuple +from typing import Dict import torch from accelerate.utils import has_offloaded_params From 5438a81faba3154c8e05d7d578d4a5e7a5261975 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 21:01:44 -0400 Subject: [PATCH 28/42] better type hints, warn once Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index ff2c6fc27..b371eacfa 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -20,6 +20,7 @@ import numpy import torch from frozendict import frozendict +from loguru import logger from transformers import AutoConfig @@ -195,7 +196,7 @@ def decorator(func: T) -> T: @wraps(func) def wrapped(*args, **kwargs): - warnings.warn(message, DeprecationWarning, stacklevel=2) + logger.bind(log_once=True).warning(message) return func(*args, **kwargs) return wrapped From 5875644f9053a87ea6fcfe75f84504c81c4c7b22 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 21:06:15 -0400 Subject: [PATCH 29/42] remove unneeded import Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index b371eacfa..c2c9b292b 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -13,7 +13,6 @@ # limitations under the License. import contextlib -import warnings from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar From f179a91a2e59ecb27709e87b53810a40ad556f75 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 19:02:10 -0400 Subject: [PATCH 30/42] allow-group-dynamic-quantization Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index b11e3c0c0..a7c3d590e 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -66,15 +66,6 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): - if ( - inputs.strategy == QuantizationStrategy.GROUP - and inputs.dynamic is True - ): - raise NotImplementedError( - "Static and local group-wise activation " - "quantization is not supported" - ) - raise NotImplementedError( f"Using {inputs.strategy} strategy is not supported for " "activation quantization" From f3d0e58c96b3f0ccfa644e9ecac4d5e5c0fff9fd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 19:08:11 -0400 Subject: [PATCH 31/42] satisfy quality checker Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a7c3d590e..17dab844e 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -66,6 +66,17 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.GROUP, 
QuantizationStrategy.TENSOR_GROUP, ): +<<<<<<< HEAD +======= + if ( + inputs.strategy == QuantizationStrategy.GROUP + and inputs.dynamic is True + ): + raise NotImplementedError( + "Static and local group-wise quantization is not supported" + ) + +>>>>>>> db46c84 (satisfy quality checker) raise NotImplementedError( f"Using {inputs.strategy} strategy is not supported for " "activation quantization" From 8d357945dcf909f41f8cb3cedc18afd0ef6360d5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 19:08:57 -0400 Subject: [PATCH 32/42] more clear Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 17dab844e..505bdb2a9 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -73,7 +73,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": and inputs.dynamic is True ): raise NotImplementedError( - "Static and local group-wise quantization is not supported" + "Static and local group-wise activation " + "quantization is not supported" ) >>>>>>> db46c84 (satisfy quality checker) From 82ee671cf429cba6991afbfad8311645d21dd187 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 20:55:27 -0400 Subject: [PATCH 33/42] basic support Signed-off-by: Kyle Sayers --- src/compressed_tensors/__init__.py | 1 + .../quantization/lifecycle/forward.py | 5 +-- .../lifecycle/test_forward.py | 39 +++++++++++++++++-- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/__init__.py b/src/compressed_tensors/__init__.py index c892e81a9..08b0dfb71 100644 --- a/src/compressed_tensors/__init__.py +++ b/src/compressed_tensors/__init__.py @@ -19,6 +19,7 @@ from .compressors import * from .config import * +from .logger import LoggerConfig, configure_logger, logger from .quantization import QuantizationConfig, QuantizationStatus from .utils import * from .version import * diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 2e539b070..e973f39b7 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -264,8 +264,7 @@ def _process_quantization( ): output_dtype = dtype if dtype is not None else x.dtype - output = torch.zeros_like(x).to(output_dtype) - columns = output.shape[-1] + columns = x.size(-1) # TODO: make validation step for inputs @@ -323,7 +322,7 @@ def _process_quantization( global_scale=global_scale, ) - output = output.flatten(start_dim=-2) + output = output.flatten(-2, -1) output = output.to(output_dtype) if not is_column_order: diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index f3321cd40..09010af06 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -95,7 +95,7 @@ def test_forward_quantize( @pytest.mark.parametrize( - "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale", + "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale,batch_size", [ ( 4, @@ -106,6 +106,7 @@ def test_forward_quantize( torch.zeros((1,)), None, None, + None, ), ( 4, @@ -116,6 +117,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, 
), ( 4, @@ -126,6 +128,7 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, ), ( 8, @@ -136,6 +139,7 @@ def test_forward_quantize( torch.zeros((1,)), None, None, + None, ), ( 8, @@ -146,6 +150,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 8, @@ -156,6 +161,7 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, ), ( 8, @@ -166,6 +172,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 8, @@ -176,17 +183,41 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, + ), + ( + 8, + "int", + QuantizationStrategy.GROUP, + 128, + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), + make_dummy_g_idx(1024, 128), + None, + 5, ), ], ) -def test_fake_quantize_2d( - num_bits, type, strategy, group_size, scale, zero_point, g_idx, global_scale +def test_fake_quantize( + num_bits, + type, + strategy, + group_size, + scale, + zero_point, + g_idx, + global_scale, + batch_size, ): args = QuantizationArgs( num_bits=num_bits, type=type, strategy=strategy, group_size=group_size ) - x = torch.rand((512, 1024)) + if batch_size is None: + x = torch.rand((512, 1024)) + else: + x = torch.rand((batch_size, 512, 1024)) + fake_quantize( x=x, scale=scale, From c123637054b0458d7d4e7ac156df5bd18933d170 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 21:10:36 -0400 Subject: [PATCH 34/42] fix merge Signed-off-by: Kyle Sayers --- src/compressed_tensors/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/__init__.py b/src/compressed_tensors/__init__.py index 08b0dfb71..c892e81a9 100644 --- a/src/compressed_tensors/__init__.py +++ b/src/compressed_tensors/__init__.py @@ -19,7 +19,6 @@ from .compressors import * from .config import * -from .logger import LoggerConfig, configure_logger, logger from .quantization import QuantizationConfig, QuantizationStatus from .utils import * from .version import * From 1016a752bf48751d802f0b17d2f9f4ff70609a23 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 9 Sep 2025 09:04:16 -0400 Subject: [PATCH 35/42] ungate group activation quant Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 505bdb2a9..a7c3d590e 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -66,18 +66,6 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): -<<<<<<< HEAD -======= - if ( - inputs.strategy == QuantizationStrategy.GROUP - and inputs.dynamic is True - ): - raise NotImplementedError( - "Static and local group-wise activation " - "quantization is not supported" - ) - ->>>>>>> db46c84 (satisfy quality checker) raise NotImplementedError( f"Using {inputs.strategy} strategy is not supported for " "activation quantization" From 1c217e4d774a2db0cec00adeaf483e3c33484184 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 9 Sep 2025 13:02:09 -0400 Subject: [PATCH 36/42] refactor Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 3 +- .../quantization/lifecycle/initialize.py | 216 +++++++++--------- .../quantization/quant_args.py | 40 ++-- 3 files changed, 139 insertions(+), 120 deletions(-) diff 
--git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index faa48df20..63ddfc9d6 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -221,7 +221,8 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply( lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init + module, + force_zero_point=force_zero_point_init, ) ) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5350b4a2c..57009f93c 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -26,6 +26,7 @@ from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, + DynamicType, QuantizationArgs, QuantizationStrategy, ) @@ -73,10 +74,8 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ - # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: - # no scheme passed and layer not targeted for quantization - skip return if is_attention_module(module): @@ -84,38 +83,49 @@ def initialize_module_for_quantization( _initialize_attn_scales(module) else: - if scheme.input_activations is not None: - _initialize_scale_zero_point( - module, - "input", - scheme.input_activations, - force_zero_point=force_zero_point, + if not isinstance(module, torch.nn.Linear): + _LOGGER.warning(f"Attempting to quantize module of type {type(module)}") + + # use weight to determine observed shapes and dtype + if hasattr(module, "weight"): + weight = module.weight + assert isinstance(weight, torch.Tensor) + else: + # Note that a weight is required for both weight and activation + # quantization in order to know the dtype of activation scales + _LOGGER.warning( + f"module type {type(module)} targeted for quantization but " + f"has no attribute weight, skipping quantization for {type(module)}" ) + return + + if scheme.input_activations is not None: + base_name = "input" + args = scheme.input_activations + observed_shape = weight.shape[-1:] + observed_dtype = weight.dtype if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) + base_name = "weight" + args = scheme.weights + observed_shape = weight.shape + observed_dtype = weight.dtype if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations - ) + base_name = "output" + args = scheme.output_activations + observed_shape = weight.shape[:-1] + observed_dtype = weight.dtype + + if not is_kv_cache_quant_scheme(scheme): + _initialize_scale_zero_point( + module, + base_name, + args, + observed_shape=observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) 
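A worked example of the observed-shape bookkeeping above, with illustrative numbers only (not part of the patch): for a Linear layer whose weight has shape (512, 1024) and group_size=128, input activations observe the in-features (1024,), weights observe the full (512, 1024), and outputs observe the out-features (512,); group and block sizes must divide the observed dimension exactly, which the simplified stand-in for _strict_divide below enforces.

import torch

weight = torch.empty(512, 1024)              # (out_features, in_features)
observed = {
    "input": weight.shape[-1:],              # torch.Size([1024])
    "weight": weight.shape,                  # torch.Size([512, 1024])
    "output": weight.shape[:-1],             # torch.Size([512])
}


def strict_divide(observed_size: int, divisor: int) -> int:
    # group/block sizes must divide the observed dimension exactly
    out, rem = divmod(observed_size, divisor)
    if rem != 0:
        raise ValueError(f"{observed_size} is not divisible by {divisor}")
    return out


num_groups = strict_divide(observed["weight"][-1], 128)   # 1024 // 128 == 8
assert num_groups == 8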
module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED @@ -138,18 +148,21 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, + observed_shape: torch.Size, + observed_dtype: torch.dtype, force_zero_point: bool = True, ): - if quantization_args.dynamic is True: - return + strategy = quantization_args.strategy + dynamic = quantization_args.dynamic + actorder = quantization_args.actorder + device = get_execution_device(module) # avoid performing intialization ops on cpu - # initialize on execution device to avoid performing quantized ops on cpu - device = get_execution_device(module) + # Skip all intialization for fully dynamic quantization + if dynamic is True: + return - # 1. Create global_scales for tensor_group - generates - # a per tensor scale - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: + # 0. Create global scale for tensor-group quantization + if strategy == QuantizationStrategy.TENSOR_GROUP: init_global_scale = Parameter( torch.empty(1, dtype=torch.float32, device=device), requires_grad=False, @@ -158,56 +171,54 @@ def _initialize_scale_zero_point( module, f"{base_name}_global_scale", init_global_scale ) - # 2. Infer expected scale/zero point shape - if quantization_args.strategy == QuantizationStrategy.TOKEN: - expected_shape = (1, 1) - else: - expected_shape = 1 - - if base_name == "weight" and weight_shape is not None: - if quantization_args.strategy == QuantizationStrategy.CHANNEL: - # (output_channels, 1) - only for weights - expected_shape = (weight_shape[0], 1) - elif quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - # GROUP/TENSOR_GROUP for both weights and activations - num_groups = math.ceil(weight_shape[1] / quantization_args.group_size) - expected_shape = (weight_shape[0], max(num_groups, 1)) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only - # for weights - if quantization_args.block_structure is None: - raise ValueError( - "Block quantization requires block_structure to be specified" - ) - block_height, block_width = quantization_args.block_structure - rows, cols = weight_shape[-2], weight_shape[-1] - num_rows_blocks = math.ceil(rows / block_height) - num_cols_blocks = math.ceil(cols / block_width) - - # Warn if dimensions don't divide evenly - if rows % block_height != 0 or cols % block_width != 0: - warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide" - f"evenly by block structure {quantization_args.block_structure}. " - f"Some blocks will be incomplete which may affect quantization" - "quality.", - UserWarning, - ) - - expected_shape = (num_rows_blocks, num_cols_blocks) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - warnings.warn( - f"BLOCK quantization not supported for {base_name} activations. " - f"Falling back to tensor-level quantization.", - UserWarning, - ) - expected_shape = 1 + # Skip scale/zp initialization for locally dynamic quantization + if dynamic == DynamicType.LOCAL: + return + + # 1. 
Infer expected scale/zp shape + if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN): + expected_shape = (1,) + + elif strategy == QuantizationStrategy.CHANNEL: + if len(observed_shape) < 1: + raise ValueError("Channel quant requires at least 1 observed dimension") + expected_shape = (observed_shape[-1], 1) + +<<<<<<< HEAD # 3. Identify quantization scale and zp dtype scale_dtype = module.weight.dtype +======= + elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + assert quantization_args.group_size is not None + if len(observed_shape) < 1: + raise ValueError("Group quant requires at least 1 observed dimension") + + group_size = quantization_args.group_size + num_groups = _strict_divide(observed_shape[-1], group_size, strategy) + expected_shape = (num_groups, group_size) + + # initialize activation ordering if applicable + if actorder == ActivationOrdering.GROUP: + init_g_idx = Parameter( + torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) + + elif strategy == QuantizationStrategy.BLOCK: + assert quantization_args.block_structure is not None + if len(observed_shape) < 2: + raise ValueError("Block quant requires at least 2 observed dimensions") + + block_structure = quantization_args.block_structure + num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy) + num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy) + expected_shape = (num_rows, num_cols) + + # 2. Identify quantization scale and zp dtype + scale_dtype = observed_dtype +>>>>>>> fde779c (refactor) if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype @@ -223,14 +234,12 @@ def _initialize_scale_zero_point( scale_dtype = torch.bfloat16 zp_dtype = quantization_args.pytorch_dtype() - # 4. Initializes empty scale, zero point, and g_idx parameters for the module - # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_scale", init_scale) + # 3. 
Initializes scale/zp for the module + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: init_zero_point = Parameter( @@ -239,16 +248,6 @@ def _initialize_scale_zero_point( ) register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) - # only grouped activation ordering has g_idx - if quantization_args.actorder == ActivationOrdering.GROUP: - g_idx_shape = (weight_shape[1],) - g_idx_dtype = torch.int - init_g_idx = Parameter( - torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) - def _initialize_attn_scales(module: Module) -> None: """Initlaize k_scale, v_scale for self_attn""" @@ -270,3 +269,16 @@ def _initialize_attn_scales(module: Module) -> None: requires_grad=False, ) register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) + + +def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int: + out = observed // divisor + if out * divisor != observed: + raise ValueError( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {observed} and group/block size {divisor}. " + "consider reducing the group/block size or ignoring modules with weights " + f"not divisible by {divisor}" + ) + + return out diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9e88353b..c55ee5efc 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -262,6 +262,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": actorder = model.actorder dynamic = model.dynamic observer = model.observer + block_structure = model.block_structure # infer strategy if strategy is None: @@ -277,23 +278,28 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": "strategy='group' and group_size = -1 for 'channel'" ) - # validate strategy and group - if strategy == QuantizationStrategy.GROUP: - if group_size is None or group_size <= 0: - raise ValueError( - f"strategy {strategy} requires group_size to be " - "set to a positive value" - ) - if ( - group_size is not None - and group_size > 0 - and strategy - not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP) - ): - raise ValueError("group_size requires strategy to be set to 'group'") - - # validate activation ordering and strategy - if actorder is not None and strategy != QuantizationStrategy.GROUP: + # validate block strategy and structure + has_block_strategy = strategy == QuantizationStrategy.BLOCK + has_block_structure = block_structure is not None + if has_block_strategy != has_block_structure: + raise ValueError( + "Block strategy requires `block_structure`, and vice versa. " + f"Instead got ({strategy}, {block_structure})" + ) + + # validate group strategy + has_group_strategy = strategy in ( + QuantizationStrategy.GROUP, + QuantizationStrategy.TENSOR_GROUP, + ) + has_group_size = group_size is not None and group_size > 0 + has_actorder = actorder is not None + if has_group_strategy != has_group_size: + raise ValueError( + "Group strategies require `group_size`, and vice versa. 
" + f"Instead got ({strategy}, {group_size})" + ) + if has_actorder and not has_group_strategy: raise ValueError( "Must use group quantization strategy in order to apply " "activation ordering" From e5447f33046ae51a2a28979cfc02bd4446f2336a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 11 Sep 2025 15:14:29 -0400 Subject: [PATCH 37/42] fix merge Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 57009f93c..37c6f19a1 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -185,10 +185,6 @@ def _initialize_scale_zero_point( expected_shape = (observed_shape[-1], 1) -<<<<<<< HEAD - # 3. Identify quantization scale and zp dtype - scale_dtype = module.weight.dtype -======= elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): assert quantization_args.group_size is not None if len(observed_shape) < 1: @@ -218,7 +214,6 @@ def _initialize_scale_zero_point( # 2. Identify quantization scale and zp dtype scale_dtype = observed_dtype ->>>>>>> fde779c (refactor) if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype From 6672617d535221c1087031be93840a3d5c04a1bd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 11 Sep 2025 15:25:36 -0400 Subject: [PATCH 38/42] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 3 +- src/compressed_tensors/transform/apply.py | 35 ------------ .../transform/factory/base.py | 38 ++++++++++++- .../transform/factory/hadamard.py | 4 +- .../transform/utils/hadamard.py | 7 +-- src/compressed_tensors/utils/helpers.py | 4 +- tests/test_transform/conftest.py | 2 - .../factory/test_serialization.py | 54 ++----------------- 8 files changed, 47 insertions(+), 100 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 63ddfc9d6..faa48df20 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -221,8 +221,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply( lambda module: initialize_module_for_quantization( - module, - force_zero_point=force_zero_point_init, + module, force_zero_point=force_zero_point_init ) ) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 28d5e94fc..e247e7029 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict - import torch -from accelerate.utils import has_offloaded_params from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -37,35 +34,3 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): # attach config to model for compression/serialization setattr(model, TRANSFORM_CONFIG_NAME, config) - - # ensure that tied weight transforms can be serialized without aliases - # In the future, this could be done by transformers or model compressor - # which would make this more robust to changing dispatches after transforms - _tie_offloaded_tensors(model) - - -def _tie_offloaded_tensors(model: torch.nn.Module): - """ - When accelerate replaces tensors with meta tensors during offloading, the meta - tensors may not be identical, even if the offloaded values are identical. - - However, transformers can only serialize correctly if meta tensors are identical - (see transformers#39263). - - This function collects all meta tensors which have shared offloaded values and sets - those tensors to be identical so that they can be removed during serialization - - :param model: model potentially containing offloaded meta tensors to fix - """ - - # ensure that if a location shares an offloaded tensor pointers, that the - # meta tensor is also identical (assigned to the first instance of parameter) - ptr_to_meta: Dict[int, torch.nn.Parameter] = dict() - for module in model.modules(): - if has_offloaded_params(module): - for key, _ in module.named_parameters(recurse=False): - offloaded_ptr = module._hf_hook.weights_map[key].data_ptr() - - if offloaded_ptr not in ptr_to_meta: - ptr_to_meta[offloaded_ptr] = getattr(module, key) - setattr(module, key, ptr_to_meta[offloaded_ptr]) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 34d609e74..94e6b4a42 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,7 +13,8 @@ # limitations under the License. 
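# The tying logic removed above and `_update_tied_weights` added below share one
# core idea: parameters backed by the same storage can be grouped by
# `Tensor.data_ptr()`. A self-contained sketch of that grouping (not the patch's
# actual code path, and ignoring the accelerate offload case):
from collections import defaultdict
import torch

def find_shared_parameters(model: torch.nn.Module):
    ptr_to_names = defaultdict(list)
    for name, param in model.named_parameters(remove_duplicate=False):
        ptr_to_names[param.data_ptr()].append(name)
    return [names for names in ptr_to_names.values() if len(names) > 1]

embed = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = embed.weight  # tie weights
model = torch.nn.Sequential(embed, head)
print(find_shared_parameters(model))  # [['0.weight', '1.weight']]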
from abc import ABC, abstractmethod -from typing import List, Optional +from collections import defaultdict +from typing import List, Optional, Set, Tuple import torch import torch.nn.utils.parametrize as P @@ -56,6 +57,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non self.name = name self.scheme = scheme self.generator = torch.Generator() + self.transforms = list() if seed is not None: self.generator.manual_seed(seed) @@ -99,6 +101,8 @@ def apply_to_model(self, model: Module, use_tqdm=True): for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): self._apply_to_module(module, arg) + self._update_tied_weights() + def _apply_to_module(self, module: Module, args: TransformArgs): """ Create transforms and apply them to the module @@ -116,6 +120,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) + self.transforms.append(transform) register_offload_module(module, transform_name, transform) # register input transformation hook @@ -160,6 +165,31 @@ def output_hook(_, _input, output): else: raise NotImplementedError() + def _update_tied_weights(self): + """ + Populate the `_dynamic_tied_weights_keys` attribute of transforms, + which is used by transformers to detect and remove shared pointers + during saving + """ + # map from data_ptrs to keys + ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list) + for transform in self.transforms: + for name, param in transform.named_parameters(recurse=False): + # NOTE: previously asserted that parent._hf_hook.place_submodules=False + if has_offloaded_params(transform): + param = transform._hf_hook.weights_map[name] + ptr_to_keys[param.data_ptr()].append((transform, name)) + + # populate `_dynamic_tied_weights_keys` if there is more than one key + # and ensure that they share tensors + for shared_keys in ptr_to_keys.values(): + if len(shared_keys) > 1: + tensor = getattr(shared_keys[0][0], shared_keys[0][1]) + + for transform, name in shared_keys: + transform._dynamic_tied_weights_keys.add(name) + setattr(transform, name, tensor) + class TransformBase(InternalModule, ABC): """ @@ -168,7 +198,11 @@ class TransformBase(InternalModule, ABC): args: TransformArgs weight: Parameter - _dynamic_tied_weights_keys: List[str] = ["weight"] + _dynamic_tied_weights_keys: Set[str] + + def __init__(self): + super().__init__() + self._dynamic_tied_weights_keys = set() @abstractmethod def forward(self, value: Tensor) -> Tensor: diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index a843e2728..de6e284bb 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
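# Why `_dynamic_tied_weights_keys` above moved from a class-level list to a set
# created in `__init__`: a mutable class attribute is shared by every instance,
# so keys added for one transform would leak into all transforms. A minimal,
# library-independent illustration of the difference:
class SharedKeys:
    keys = []                # class attribute: one list shared by all instances

class PerInstanceKeys:
    def __init__(self):
        self.keys = set()    # instance attribute: one container per instance

a, b = SharedKeys(), SharedKeys()
a.keys.append("weight")
print(b.keys)   # ['weight']  <- leaked across instances

c, d = PerInstanceKeys(), PerInstanceKeys()
c.keys.add("weight")
print(d.keys)   # set()       <- isolated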
-from typing import List, Optional +from typing import Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -84,8 +84,6 @@ def _create_permutation(self, weight: Parameter) -> Parameter: class HadamardTransform(TransformBase): - _dynamic_tied_weights_keys: List[str] = ["weight", "perm"] - def __init__( self, weight: Parameter, diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index c8144ae26..7d361e59d 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -115,16 +115,13 @@ def _fetch_hadamard_divisor( than forcing callers to manage the file open context :param n: size of known hadamard matrix - :param dtype: data type to move fetched hadamard to - :param device: device to move fetched hadamard to :return: a known hadamard matrix of size `n` if one exists, else None """ - open_device = torch.device("cpu") if device.type == "meta" else device - with safe_open(file_path, framework="pt", device=str(open_device)) as file: + with safe_open(file_path, framework="pt", device=str(device)) as file: divisors = sorted((int(key) for key in file.keys()), reverse=True) for divisor in divisors: if n % divisor == 0 and is_pow2(n // divisor): - return file.get_tensor(str(divisor)).to(dtype=dtype, device=device) + return file.get_tensor(str(divisor)).to(dtype=dtype) return None diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index c2c9b292b..ff2c6fc27 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -13,13 +13,13 @@ # limitations under the License. import contextlib +import warnings from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar import numpy import torch from frozendict import frozendict -from loguru import logger from transformers import AutoConfig @@ -195,7 +195,7 @@ def decorator(func: T) -> T: @wraps(func) def wrapped(*args, **kwargs): - logger.bind(log_once=True).warning(message) + warnings.warn(message, DeprecationWarning, stacklevel=2) return func(*args, **kwargs) return wrapped diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index 824c06bd3..a0188c429 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -19,8 +19,6 @@ class TransformableModel(PreTrainedModel): - config_class = PretrainedConfig - def __init__(self, *sizes): super().__init__(config=PretrainedConfig()) self.fcs = torch.nn.ModuleList( diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py index 15fa240ba..a688c2cf1 100644 --- a/tests/test_transform/factory/test_serialization.py +++ b/tests/test_transform/factory/test_serialization.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
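# Sketch of the divisor search inside `_fetch_hadamard_divisor` above: choose the
# largest stored Hadamard size `d` that divides `n` with a power-of-two quotient.
# `known_sizes` stands in for the keys of the safetensors file and `pick_divisor`
# is illustrative only.
from typing import Iterable, Optional

def is_pow2(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0

def pick_divisor(n: int, known_sizes: Iterable[int]) -> Optional[int]:
    for d in sorted(known_sizes, reverse=True):
        if n % d == 0 and is_pow2(n // d):
            return d
    return None

print(pick_divisor(5120, [12, 20, 40]))  # 40, since 5120 = 40 * 128 and 128 is 2**7
print(pick_divisor(4096, [12, 20, 40]))  # None, no listed size divides 4096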
-import os - import pytest import torch from compressed_tensors.transform import ( @@ -22,9 +20,7 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch -from safetensors import safe_open from tests.testing_utils import requires_accelerate, requires_gpu -from transformers import AutoModelForCausalLM, AutoTokenizer @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @@ -42,57 +38,17 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False): apply_transform_config(model, config) # save model - model_path = os.path.join(tmp_path, "test_model_path") - model.save_pretrained(model_path) - - # check that saved values match model values - # note that shared weights are only serialized once - safetensors_path = os.path.join(model_path, "model.safetensors") - with safe_open(safetensors_path, framework="pt", device="cpu") as file: - saved_keys = set(file.keys()) - assert { - "fcs.0.weight", - "fcs.1.weight", - "fcs.2.weight", - "fcs.3.weight", - "fcs.4.weight", - } <= saved_keys - for key in saved_keys: - param = model.get_parameter(key) - saved_param = file.get_tensor(key) + model.save_pretrained(tmp_path) - if param.device.type != "meta": # skip testing values in offload case - assert torch.equal(param, saved_param) + # TODO: reload model +@pytest.mark.skip(reason="Requires changes in upstream transformers") +# https://github.com/huggingface/transformers/pull/39280 +# https://github.com/huggingface/transformers/pull/39263 @requires_gpu @requires_accelerate() @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @pytest.mark.parametrize("randomize", (True, False)) def test_serialization_offload(type, randomize, model_apply, tmp_path): test_serialization(type, randomize, model_apply, tmp_path, offload=True) - - -@pytest.mark.skip("Requires transformers#40673") -@requires_gpu -@pytest.mark.parametrize( - "model_stub,exp_perplexity", - [ - ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0), - ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0), - ], -) -def test_load_perplexity(model_stub, exp_perplexity): - model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda") - tokenizer = AutoTokenizer.from_pretrained(model_stub) - - prompt = "The capital of France is Paris, the capital of Germany is Berlin" - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {key: value.to(model.device) for key, value in inputs.items()} - labels = inputs["input_ids"] - - with torch.no_grad(): - outputs = model(**inputs, labels=labels) - - perplexity = torch.exp(outputs.loss) - assert perplexity <= exp_perplexity From 199f274a43951d7189dae38205aeaf676968c84f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 12 Sep 2025 12:10:03 -0400 Subject: [PATCH 39/42] activations have one row Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/forward.py | 21 ++++------- .../quantization/lifecycle/initialize.py | 35 +++++++------------ .../quantization/quant_args.py | 9 ++--- .../quantization/utils/helpers.py | 17 +++++++++ 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index e973f39b7..850d8f1e5 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -280,17 +280,8 @@ def _process_quantization( f"by the given group_size {group_size}" ) - # support column-order (default) quantization as well as other 
orderings - # such as activation ordering. Below checks if g_idx has been initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - num_groups = int(ceil(columns / group_size)) - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - + # permute groups + if g_idx is not None: perm = torch.argsort(g_idx) x = x.index_select(-1, perm) @@ -299,6 +290,8 @@ def _process_quantization( ceil(x.shape[-1] / group_size), group_size, ) + # we should potentially be folding reshaped_dims[0] into x.shape[-2] + # in order to allow for multi-headed activations x = x.unflatten(-1, reshaped_dims) if do_quantize: @@ -325,9 +318,9 @@ def _process_quantization( output = output.flatten(-2, -1) output = output.to(output_dtype) - if not is_column_order: - inv_perm = torch.argsort(perm) - output = output.index_select(-1, inv_perm) + # unpermute groups + if g_idx is not None: + x = x.index_select(-1, g_idx) else: # covers channel, token and tensor strategies if do_quantize: diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 37c6f19a1..cfbb42ced 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,10 +14,8 @@ import logging -import math -import warnings from enum import Enum -from typing import Optional +from typing import Optional, Tuple import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -32,7 +30,11 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme +from compressed_tensors.quantization.utils import ( + is_fp4, + is_kv_cache_quant_scheme, + strict_divide, +) from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, @@ -102,7 +104,7 @@ def initialize_module_for_quantization( if scheme.input_activations is not None: base_name = "input" args = scheme.input_activations - observed_shape = weight.shape[-1:] + observed_shape = (1, weight.size(-1)) observed_dtype = weight.dtype if scheme.weights is not None: @@ -148,7 +150,7 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - observed_shape: torch.Size, + observed_shape: Tuple[int], observed_dtype: torch.dtype, force_zero_point: bool = True, ): @@ -191,8 +193,8 @@ def _initialize_scale_zero_point( raise ValueError("Group quant requires at least 1 observed dimension") group_size = quantization_args.group_size - num_groups = _strict_divide(observed_shape[-1], group_size, strategy) - expected_shape = (num_groups, group_size) + num_groups = strict_divide(observed_shape[-1], group_size, strategy) + expected_shape = (*observed_shape[:-1], num_groups) # initialize activation ordering if applicable if actorder == ActivationOrdering.GROUP: @@ -208,8 +210,8 @@ def _initialize_scale_zero_point( raise ValueError("Block quant requires at least 2 observed dimensions") block_structure = quantization_args.block_structure - num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy) - num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy) + num_rows = strict_divide(observed_shape[-2], block_structure[-2], 
strategy) + num_cols = strict_divide(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) # 2. Identify quantization scale and zp dtype @@ -264,16 +266,3 @@ def _initialize_attn_scales(module: Module) -> None: requires_grad=False, ) register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) - - -def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int: - out = observed // divisor - if out * divisor != observed: - raise ValueError( - f"{strategy} quantization strategy requires strict division of " - f"weight/activation size {observed} and group/block size {divisor}. " - "consider reducing the group/block size or ignoring modules with weights " - f"not divisible by {divisor}" - ) - - return out diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index c55ee5efc..1ee294870 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -283,8 +283,9 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": has_block_structure = block_structure is not None if has_block_strategy != has_block_structure: raise ValueError( - "Block strategy requires `block_structure`, and vice versa. " - f"Instead got ({strategy}, {block_structure})" + "`strategy = block` requires `block_structure != None`, and vice versa." + f" Instead got `strategy={strategy}` and " + f"`block_structure={block_structure}`" ) # validate group strategy @@ -296,8 +297,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": has_actorder = actorder is not None if has_group_strategy != has_group_size: raise ValueError( - "Group strategies require `group_size`, and vice versa. " - f"Instead got ({strategy}, {group_size})" + "`strategy = group` requires `group_size != None`, and vice versa. " + f"Instead got `strategy={strategy}` and `group_size={group_size}`" ) if has_actorder and not has_group_strategy: raise ValueError( diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 73d545193..8acf69a87 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -48,6 +48,7 @@ "calculate_qparams", "generate_gparam", "is_fp4", + "strict_divide", ] # target the self_attn layer @@ -477,3 +478,19 @@ def generate_gparam( max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) global_scale = scale_data.max * quant_data.max / max_val_pos return global_scale.to(dtype).reshape([1]) + + +def strict_divide( + observed: int, divisor: int, strategy: Optional[QuantizationStrategy] = None +) -> int: + out = observed // divisor + if out * divisor != observed: + if strategy is not None: + raise ValueError( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {observed} and group/block size {divisor}. 
" + "consider reducing the group/block size or ignoring modules with " + f"weights not divisible by {divisor}" + ) + + return out From d53ba36039cc941e5f862b440bb81323844a1812 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 08:23:45 -0400 Subject: [PATCH 40/42] cleanup, logging Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 4 ++-- .../quantization/quant_scheme.py | 23 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index cfbb42ced..37ccb7e86 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -104,7 +104,7 @@ def initialize_module_for_quantization( if scheme.input_activations is not None: base_name = "input" args = scheme.input_activations - observed_shape = (1, weight.size(-1)) + observed_shape = (1, weight.shape[-1]) observed_dtype = weight.dtype if scheme.weights is not None: @@ -185,7 +185,7 @@ def _initialize_scale_zero_point( if len(observed_shape) < 1: raise ValueError("Channel quant requires at least 1 observed dimension") - expected_shape = (observed_shape[-1], 1) + expected_shape = (observed_shape[-2], 1) elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): assert quantization_args.group_size is not None diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a7c3d590e..1a036e1cf 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -24,6 +24,7 @@ QuantizationType, ) from pydantic import BaseModel, ConfigDict, model_validator +from loguru import logger __all__ = [ @@ -60,15 +61,19 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": format = model.format if inputs is not None: - if inputs.strategy not in ( - QuantizationStrategy.TOKEN, - QuantizationStrategy.TENSOR, - QuantizationStrategy.GROUP, - QuantizationStrategy.TENSOR_GROUP, - ): - raise NotImplementedError( - f"Using {inputs.strategy} strategy is not supported for " - "activation quantization" + if inputs.strategy == QuantizationStrategy.CHANNEL: + raise ValueError( + "Channel-wise activation quantization is equivalent to " + "tensor/token-wise activation quantization, please use one of " + "those. If you mean to quantize each activation value " + "individually, please use group quantization with `group_size = 1`" + ) + + if inputs.strategy == QuantizationStrategy.BLOCK: + raise ValueError( + "Block-wise activation quantization is not supported. 
If you mean " + "to quantize each activation value individually, please use group " + "quantization with `group_size = 1`" ) if inputs.actorder is not None: From 0d860cd7e0089758f979fc741b3d06b685718e2c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 11:02:09 -0400 Subject: [PATCH 41/42] remove scheme merge Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 10 +---- .../quantization/quant_scheme.py | 38 ------------------- 2 files changed, 1 insertion(+), 47 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index a27020881..4a6895429 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -143,7 +143,7 @@ def apply_quantization_config( model, scheme.targets, config.ignore or [], warn_on_fail=True ): # attach scheme to module (with merging) - attach_scheme(submodule, scheme) + setattr(submodule, "quantization_scheme", scheme) # replace with run compressed if applicable # FUTURE: move this to model compressor @@ -175,14 +175,6 @@ def apply_quantization_config( attach_config(model, config) -def attach_scheme(module: Module, scheme: QuantizationScheme) -> QuantizationScheme: - if existing_scheme := getattr(module, "quantization_scheme", None): - scheme = scheme.merge(existing_scheme) - - setattr(module, "quantization_scheme", scheme) - return scheme - - def attach_config(model: PreTrainedModel, config: QuantizationConfig): if existing_config := getattr(model, "quantization_config", None): config = config.merge(existing_config) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 71538f156..dbc966f16 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -109,44 +109,6 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": return model - def merge(self, other: "QuantizationScheme") -> "QuantizationScheme": - def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: - if field_name == "targets": - return value_a + value_b - - if field_name == "kv_cache_only": - # nones defer to other value - if value_a is None: - return value_b - if value_b is None: - return value_a - - # kv_cache_only=True overrides - return not ((not value_a) or (not value_b)) - - if value_a is not None and value_b is None: - return value_a - - if value_a is None and value_b is not None: - return value_b - - if value_a == value_b: - return value_a - - raise ValueError( - "The following fields have overlapping targets and conflicting values " - f"for {field_name}. 
Please modify your config to resolve this " - f"ambiguity.\n{self}\n{other}" - ) - - dict_a = self.model_dump() - dict_b = other.model_dump() - - assert dict_a.keys() == dict_b.keys() - return self.model_validate( - {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} - ) - model_config = ConfigDict(extra="forbid") From 20744ebfc7d00b43351f562e36030af302c42261 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 13:56:47 -0400 Subject: [PATCH 42/42] use attention head quant Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 19 ++- src/compressed_tensors/modeling/kvcache.py | 27 +++- .../quantization/lifecycle/apply.py | 28 ++-- .../quantization/lifecycle/initialize.py | 128 +++++++++------- .../quantization/quant_args.py | 1 + .../quantization/quant_config.py | 141 +++++++++++------- .../quantization/quant_scheme.py | 2 +- .../transform/factory/base.py | 2 +- 8 files changed, 212 insertions(+), 136 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index 1d86e78b4..13c9827ad 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -87,8 +87,23 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Modul and not scheme.kv_cache_only ): # TODO: use model.config.num_attention_heads to find query_size - assert quant_args.strategy == QuantizationStrategy.TENSOR - _initialize_scale_zero_point(module, "q", quant_args) + assert quant_args.strategy in ( + QuantizationStrategy.TENSOR, + QuantizationStrategy.TOKEN, + QuantizationStrategy.ATTN_HEAD, + ) + + num_heads = model.config.num_attention_heads + hidden_size = model.config.hidden_size + observed_dtype = next(module.parameters()).dtype + _initialize_scale_zero_point( + module, + "q", + quant_args, + observed_shape=(num_heads, hidden_size), + observed_dtype=observed_dtype, + force_zero_point=True, + ) self._qparams_initialized = True diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index f26fca084..084da9085 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -77,9 +77,30 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Modul if not self._qparams_initialized and quant_args is not None: # TODO: use model.config.num_key_value_heads to find key_size, value_size - assert quant_args.strategy == QuantizationStrategy.TENSOR - _initialize_scale_zero_point(module, "k", quant_args) - _initialize_scale_zero_point(module, "v", quant_args) + assert quant_args.strategy in ( + QuantizationStrategy.TENSOR, + QuantizationStrategy.TOKEN, + QuantizationStrategy.ATTN_HEAD, + ) + num_heads = model.config.num_key_value_heads + hidden_size = model.config.hidden_size + observed_dtype = next(module.parameters()).dtype + _initialize_scale_zero_point( + module, + "k", + quant_args, + observed_shape=(num_heads, hidden_size), + observed_dtype=observed_dtype, + force_zero_point=True, + ) + _initialize_scale_zero_point( + module, + "v", + quant_args, + observed_shape=(num_heads, hidden_size), + observed_dtype=observed_dtype, + force_zero_point=True, + ) self._qparams_initialized = True diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 4a6895429..3091d4648 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ 
b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -128,7 +128,6 @@ def apply_quantization_config( decompressed fully on load """ from compressed_tensors.linear.compressed_linear import CompressedLinear - from compressed_tensors.modeling.attention import initialize_hooked_attention config = deepcopy(config) if config is None: # see PR #180 @@ -142,7 +141,7 @@ def apply_quantization_config( for name, submodule in match_named_modules( model, scheme.targets, config.ignore or [], warn_on_fail=True ): - # attach scheme to module (with merging) + # attach scheme (overwrite any existing) setattr(submodule, "quantization_scheme", scheme) # replace with run compressed if applicable @@ -160,26 +159,15 @@ def apply_quantization_config( replace_module(model, name, compressed_linear) # attention quantization and/or kv cache quantization - if is_attention_module(submodule): - if is_narrow_match(model, scheme.targets, name): - # unlike linear, do qparam initialization here (once) - initialize_hooked_attention(model, submodule, quantize=True) - else: - # do not quantize attention unless specifically targeted - delattr(submodule, "quantization_scheme") + if is_attention_module(submodule) and not is_narrow_match( + model, scheme.targets, name + ): + # do not quantize attention unless specifically targeted + delattr(submodule, "quantization_scheme") # apply current quantization status across all targeted linear/embedding layers apply_quantization_status(model, config.quantization_status) - # attach config for serialization - attach_config(model, config) - - -def attach_config(model: PreTrainedModel, config: QuantizationConfig): - if existing_config := getattr(model, "quantization_config", None): - config = config.merge(existing_config) - setattr(model, "quantization_config", config) - def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: """ @@ -228,7 +216,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply( lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init + module, + force_zero_point=force_zero_point_init, + model=model, ) ) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5bf887986..04b066744 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -30,17 +30,14 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import ( - is_fp4, - is_kv_cache_quant_scheme, - strict_divide, -) +from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, register_offload_parameter, ) from torch.nn import Module, Parameter +from transformers import PreTrainedModel __all__ = [ @@ -62,6 +59,7 @@ def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, + model: Optional[PreTrainedModel] = None, ): """ attaches appropriate scales, zero points, and observers to a layer @@ -80,57 +78,68 @@ def initialize_module_for_quantization( if scheme is None: return - if not isinstance(module, (torch.nn.Linear, torch.nn.Embedding)): - return + if is_attention_module(module): + from compressed_tensors.modeling.attention 
import initialize_hooked_attention + + if not isinstance(model, PreTrainedModel): + raise ValueError("Must pass model in order to initialize attention") + initialize_hooked_attention(model, module, quantize=True) - # use weight to determine observed shapes and dtype - if hasattr(module, "weight"): - weight = module.weight - assert isinstance(weight, torch.Tensor) else: - # Note that a weight is required for both weight and activation - # quantization in order to know the dtype of activation scales - _LOGGER.warning( - f"module type {type(module)} targeted for quantization but " - f"has no attribute weight, skipping quantization for {type(module)}" - ) - return + if not isinstance(module, torch.nn.Linear): + _LOGGER.warning(f"Attempting to quantize module of type {type(module)}") + + # use weight to determine observed shapes and dtype + if hasattr(module, "weight"): + weight = module.weight + assert isinstance(weight, torch.Tensor) + else: + # Note that a weight is required for both weight and activation + # quantization in order to know the dtype of activation scales + _LOGGER.warning( + f"module type {type(module)} targeted for quantization but " + f"has no attribute weight, skipping quantization for {type(module)}" + ) + return + + if scheme.input_activations is not None: + _initialize_scale_zero_point( + module, + "input", + scheme.input_activations, + observed_shape=(1, weight.shape[-1]), + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) - if scheme.input_activations is not None: - base_name = "input" - args = scheme.input_activations - observed_shape = (1, weight.shape[-1]) - observed_dtype = weight.dtype - - if scheme.weights is not None: - base_name = "weight" - args = scheme.weights - observed_shape = weight.shape - observed_dtype = weight.dtype - - if scheme.output_activations is not None: - base_name = "output" - args = scheme.output_activations - observed_shape = weight.shape[:-1] - observed_dtype = weight.dtype - - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, - base_name, - args, - observed_shape=observed_shape, - observed_dtype=observed_dtype, - force_zero_point=force_zero_point, - ) + if scheme.weights is not None: + _initialize_scale_zero_point( + module, + "weight", + scheme.weights, + observed_shape=weight.shape, + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) + + output_is_kv_cache = is_kv_cache_quant_scheme(scheme) + if scheme.output_activations is not None and not output_is_kv_cache: + _initialize_scale_zero_point( + module, + "output", + scheme.output_activations, + observed_shape=weight.shape[:-1], + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED - with disable_hf_hook(module): - # wrap forward call of module to perform - # quantized actions based on calltime status - wrap_module_forward_quantized(module, scheme) + with disable_hf_hook(module): + # wrap forward call of module to perform + # quantized actions based on calltime status + wrap_module_forward_quantized(module, scheme) def is_attention_module(module: Module): @@ -173,9 +182,12 @@ def _initialize_scale_zero_point( return # 1. 
Infer expected scale/zp shape - if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN): + if strategy == QuantizationStrategy.TENSOR: expected_shape = (1,) + elif strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + elif strategy == QuantizationStrategy.CHANNEL: if len(observed_shape) < 1: raise ValueError("Channel quant requires at least 1 observed dimension") @@ -188,7 +200,7 @@ def _initialize_scale_zero_point( raise ValueError("Group quant requires at least 1 observed dimension") group_size = quantization_args.group_size - num_groups = strict_divide(observed_shape[-1], group_size, strategy) + num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy) expected_shape = (*observed_shape[:-1], num_groups) # initialize activation ordering if applicable @@ -205,10 +217,16 @@ def _initialize_scale_zero_point( raise ValueError("Block quant requires at least 2 observed dimensions") block_structure = quantization_args.block_structure - num_rows = strict_divide(observed_shape[-2], block_structure[-2], strategy) - num_cols = strict_divide(observed_shape[-1], block_structure[-1], strategy) + num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy) + num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) + elif strategy == QuantizationStrategy.ATTN_HEAD: + expected_shape = (observed_shape[-2], 1) + + else: + assert False, f"Unknown strategy {strategy}" + # 2. Identify quantization scale and zp dtype scale_dtype = observed_dtype diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index cf721f139..1e4e1bf41 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum): BLOCK = "block" TOKEN = "token" TENSOR_GROUP = "tensor_group" + ATTN_HEAD = "attn_head" class DynamicType(str, Enum): diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index ecdc8ae21..42df3a337 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -26,7 +26,7 @@ module_type, parse_out_kv_cache_args, ) -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field from torch.nn import Module @@ -58,6 +58,13 @@ class QuantizationStatus(str, Enum): FROZEN = "frozen" COMPRESSED = "compressed" + @classmethod + def lifecycle_order(cls) -> List["QuantizationStatus"]: + """ + :return: list of correct quantization lifecycle order + """ + return + def __ge__(self, other): if other is None: return True @@ -95,7 +102,7 @@ def __le__(self, other): ] DEFAULT_QUANTIZATION_METHOD = "compressed-tensors" -DEFAULT_QUANTIZATION_FORMAT = "fakequant" # TODO: remove +DEFAULT_QUANTIZATION_FORMAT = "fakequant" class QuantizationConfig(BaseModel): @@ -156,68 +163,92 @@ def to_dict(self): # for compatibility with HFQuantizer return self.model_dump() - def merge(self, other: "QuantizationConfig") -> "QuantizationConfig": - def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: - if field_name == "config_groups": - return value_a | value_b - - if field_name == "ignore": - if value_a is not None and value_b is None: - return value_a - - if value_a is None and value_b is not None: - return value_b - - if set(value_a) == set(value_b): - return value_a - - raise 
NotImplementedError( - "Cannot merge quantization configs with non-identical ignore lists " - "Please modify your config to resolve this ambiguity." - f"\n{self}\n{other}" - ) - - if value_a is not None and value_b is None: - return value_a - - if value_a is None and value_b is not None: - return value_b - - if value_a == value_b: - return value_a - - raise ValueError( - "The following fields have overlapping targets and conflicting values " - f"for {field_name}. Please modify your config to resolve this " - f"ambiguity.\n{self}\n{other}" - ) - - dict_a = self.model_dump() - dict_b = other.model_dump() + @staticmethod + def from_pretrained( + model: Module, format: Optional[str] = None + ) -> Optional["QuantizationConfig"]: + """ + Converts a model into its associated QuantizationConfig based on the + QuantizationScheme attached to each quantized module - assert dict_a.keys() == dict_b.keys() - return self.model_validate( - {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} + :param model: model to calculate quantization scheme of + :return: filled out QuantizationScheme for the input model + """ + quant_scheme_to_layers = [] + quantization_status = None + ignore = {} + quantization_type_names = set() + for name, submodule in model.named_modules(): + layer_type = module_type(submodule) + if not is_module_quantized(submodule): + if layer_type not in ignore: + ignore[layer_type] = [] + ignore[layer_type].append(name) + else: + quantization_status = submodule.quantization_status + scheme = submodule.quantization_scheme + quantization_type_names.add(layer_type) + + match_found = False + for existing_scheme in quant_scheme_to_layers: + if scheme == existing_scheme: + match_found = True + break + if not match_found: + quant_scheme_to_layers.append(scheme) + + if len(quant_scheme_to_layers) == 0: # No quantized layers + return None + + # kv-cache only, no weight/activation quantization + if ( + len(quantization_type_names) == 1 + and "attention" in list(quantization_type_names)[0].lower() + ): + quantization_type_names.add("Linear") + + # clean up ignore list, we can leave out layers types if none of the + # instances are quantized + consolidated_ignore = [] + for layer_type, ignore_names in ignore.items(): + if layer_type in quantization_type_names: + # specific layers of a quantized type are ignored + consolidated_ignore += ignore_names + # else we leave it off the ignore list, doesn't fall under any of the + # existing quantization schemes so it won't be quantized + + kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( + quant_scheme_to_layers + ) + kv_cache_scheme = ( + kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args ) - @classmethod - def from_pretrained( - cls, model: Module, format: Optional[str] = None - ) -> "QuantizationConfig": - default_config = QuantizationConfig(config_groups={}) - config = getattr(model, "quantization_config", default_config) + config_groups = {} + for idx, scheme in enumerate(quant_scheme_to_layers): + group_name = "group_" + str(idx) + config_groups[group_name] = scheme - # silently override format - if isinstance(format, list): + if format is None: + if quantization_status == QuantizationStatus.COMPRESSED: + format = CompressionFormat.int_quantized.value + else: + format = CompressionFormat.dense.value + elif isinstance(format, list): format = ( CompressionFormat.mixed_precision.value if len(format) > 1 else format[0] ) - if format is None: - format = CompressionFormat.dense.value - config.format = 
format - return config + + return QuantizationConfig( + config_groups=config_groups, + quantization_status=quantization_status, + kv_cache_scheme=kv_cache_scheme, + global_compression_ratio=None, + format=format, + ignore=consolidated_ignore, + ) def requires_calibration_data(self): if self.kv_cache_scheme is not None: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index dbc966f16..e00bb899a 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -23,8 +23,8 @@ QuantizationStrategy, QuantizationType, ) -from pydantic import BaseModel, ConfigDict, model_validator from loguru import logger +from pydantic import BaseModel, ConfigDict, model_validator __all__ = [ diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index b3f478a41..8c56ca14a 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -18,6 +18,7 @@ import torch import torch.nn.utils.parametrize as P +import tqdm from compressed_tensors.modeling.attention import ( initialize_hooked_attention, register_query_hook, @@ -26,7 +27,6 @@ initialize_hooked_kv_cache, register_key_hook, ) -import tqdm from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs,