From e0db8db9bdc8ed366cecbcf69f8a7f073fc5cbd0 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 30 Sep 2025 09:56:25 -0400
Subject: [PATCH 1/7] refactor

Signed-off-by: Kyle Sayers
---
 .../quantization/lifecycle/initialize.py   | 29 ++++++++++---------
 .../quantization/utils/helpers.py          | 16 ++++++++++
 2 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 4b896d37..8c3c4867 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,23 +14,22 @@
 
 import logging
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
-from compressed_tensors.quantization import (
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
+from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
-    QuantizationMetadata,
-    QuantizationScheme,
-    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     is_fp4,
     is_kv_cache_quant_scheme,
@@ -54,17 +53,21 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    Attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme.
+    attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme
 
-    Previously initialized scales and zero points will be removed from
-    module if they no longer apply to the scheme
+    apply to full model with `model.apply(initialize_module_for_quantization)`
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -77,8 +80,6 @@ def initialize_module_for_quantization(
     if scheme is None:
         return
 
-    QuantizationMetadata.clear_all_qparams(module)
-
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index fccd677c..d4428438 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -33,6 +33,7 @@
 
 
 __all__ = [
+    "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
@@ -235,6 +236,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
     return q_min, q_max
 
 
+def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
+    """
+    Checks the quantization status of a model. Assumes all modules in the model have
+    the same status, so only the first quantized model is checked.
+
+    :param model: model to check quantization status for
+    :return: quantization status if the model is quantized, otherwise None
+    """
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
+
 def is_module_quantized(module: Module) -> bool:
     """
     Check if a module is quantized, based on the existence of a non-empty quantization

From 0e58290bbd7ce87f4f735db9067cc3cfc6094c00 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 30 Sep 2025 09:59:55 -0400
Subject: [PATCH 2/7] reduce diff

Signed-off-by: Kyle Sayers
---
 .../quantization/lifecycle/initialize.py   | 29 +++++++++----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 8c3c4867..4b896d37 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,22 +14,23 @@
 
 import logging
-from enum import Enum
 from typing import Optional, Tuple
 
 import torch
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
-from compressed_tensors.quantization.quant_args import (
+from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
+    KVCacheScaleType,
     QuantizationArgs,
+    QuantizationMetadata,
+    QuantizationScheme,
+    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.utils import (
     is_fp4,
     is_kv_cache_quant_scheme,
@@ -53,21 +54,17 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme
+    Attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme.
 
-    apply to full model with `model.apply(initialize_module_for_quantization)`
+    Previously initialized scales and zero points will be removed from
+    module if they no longer apply to the scheme
 
     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -80,6 +77,8 @@ def initialize_module_for_quantization(
     if scheme is None:
         return
 
+    QuantizationMetadata.clear_all_qparams(module)
+
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)

From b6d056001c421bfff23bbd1ad0637da5ef42b747 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 6 Oct 2025 18:20:01 -0400
Subject: [PATCH 3/7] increase num of required observed dims

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/quantization/lifecycle/initialize.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 4b896d37..390b174a 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -234,6 +234,12 @@ def initialize_qparams(
         num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)
 
+    elif strategy == QuantizationStrategy.ATTN_HEAD:
+        if len(observed_shape) < 2:
+            raise ValueError("Attention quant requires at least 2 observed dimensions")
+
+        expected_shape = (observed_shape[-2], 1)
+
     else:
         assert False, f"Unknown strategy {strategy}"

From 36bf6a5cb550d5307792a82f8f84913cfeb7e485 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 6 Oct 2025 18:03:36 -0400
Subject: [PATCH 4/7] add tests

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/quantization/quant_args.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index d9e88353..5b6e23ee 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
     BLOCK = "block"
     TOKEN = "token"
     TENSOR_GROUP = "tensor_group"
+    ATTN_HEAD = "attn_head"
 
 
 class DynamicType(str, Enum):

From e6d92db4d8d0b6e982ccc41fc96d77f84f97237f Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 6 Oct 2025 18:18:45 -0400
Subject: [PATCH 5/7] add tests for attn head

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/quantization/quant_scheme.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index b11e3c0c..1e3e089d 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             QuantizationStrategy.TENSOR,
             QuantizationStrategy.GROUP,
             QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.ATTN_HEAD,
         ):
             if (
                 inputs.strategy == QuantizationStrategy.GROUP

From 51602c2a62947ce5a42ddcc858019a943a50c419 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 6 Oct 2025 18:28:05 -0400
Subject: [PATCH 6/7] add tests

Signed-off-by: Kyle Sayers
---
 tests/observer.py                                  | 10 ++++++++++
 .../lifecycle/test_static_lifecycle.py             | 19 +++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/tests/observer.py b/tests/observer.py
index 290153c0..b30d19fa 100644
--- a/tests/observer.py
+++ b/tests/observer.py
@@ -158,6 +158,9 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -182,6 +185,9 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
@@ -203,4 +209,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, head_dim)
+        return value.flatten(0, 1).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py
index 4adcba98..efc17aec 100644
--- a/tests/test_quantization/lifecycle/test_static_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py
@@ -302,6 +302,25 @@ class MockAttention(torch.nn.Module):
         # group is not supported
         # tensor group is not supported
         # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[0], [3]]),
+            torch.tensor([[8], [11]]),
+            torch.tensor(
+                [
+                    [
+                        [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]],
+                        [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]],
+                    ]
+                ]
+            ),
+            0.16,
+        ),
     ],
 )
 def test_static_attention_quantization(

From 48875e22858287163e7c7623a77d3f990ae7500e Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 7 Oct 2025 18:10:31 -0400
Subject: [PATCH 7/7] reduce diff

Signed-off-by: Kyle Sayers
---
 .../quantization/utils/helpers.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index d4428438..fccd677c 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -33,7 +33,6 @@
 
 
 __all__ = [
-    "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
@@ -236,21 +235,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
     return q_min, q_max
 
 
-def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
-    """
-    Checks the quantization status of a model. Assumes all modules in the model have
-    the same status, so only the first quantized model is checked.
-
-    :param model: model to check quantization status for
-    :return: quantization status if the model is quantized, otherwise None
-    """
-    for module in model.modules():
-        status = getattr(module, "quantization_status", None)
-        if status is not None:
-            return status
-    return None
-
-
 def is_module_quantized(module: Module) -> bool:
     """
     Check if a module is quantized, based on the existence of a non-empty quantization