From 03432a91502dcc5337e2993a981a1b2973fcfeec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 21:01:44 -0400 Subject: [PATCH 01/20] better type hints, warn once Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 44adc7ef..c2bbaeda 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -195,7 +195,7 @@ def decorator(func: T) -> T: @wraps(func) def wrapped(*args, **kwargs): - warnings.warn(message, DeprecationWarning, stacklevel=2) + logger.bind(log_once=True).warning(message) return func(*args, **kwargs) return wrapped From 3c9e4990f132d6e89de3f2337dc85a32b25166d3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 21:06:15 -0400 Subject: [PATCH 02/20] remove unneeded import Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index c2bbaeda..5324cd28 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -13,7 +13,6 @@ # limitations under the License. import contextlib -import warnings from functools import wraps from types import MappingProxyType from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar From 390a4675e3472189fa31e7a6d85bbd6f3457d0d1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 19:02:10 -0400 Subject: [PATCH 03/20] allow-group-dynamic-quantization Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index b11e3c0c..a7c3d590 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -66,15 +66,6 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): - if ( - inputs.strategy == QuantizationStrategy.GROUP - and inputs.dynamic is True - ): - raise NotImplementedError( - "Static and local group-wise activation " - "quantization is not supported" - ) - raise NotImplementedError( f"Using {inputs.strategy} strategy is not supported for " "activation quantization" From 2a3feeda08ea07bff4d1bef3f1b12b54e4dbf507 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 19:08:11 -0400 Subject: [PATCH 04/20] satisfy quality checker Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a7c3d590..17dab844 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -66,6 +66,17 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): +<<<<<<< HEAD +======= + if ( + inputs.strategy == QuantizationStrategy.GROUP + and inputs.dynamic is True + ): + raise NotImplementedError( + "Static and local group-wise quantization is not supported" + ) + +>>>>>>> db46c84 (satisfy quality checker) raise 
NotImplementedError( f"Using {inputs.strategy} strategy is not supported for " "activation quantization" From b849d134e598efaf9a0932e813c353fc300aae6e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 19:08:57 -0400 Subject: [PATCH 05/20] more clear Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 17dab844..505bdb2a 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -73,7 +73,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": and inputs.dynamic is True ): raise NotImplementedError( - "Static and local group-wise quantization is not supported" + "Static and local group-wise activation " + "quantization is not supported" ) >>>>>>> db46c84 (satisfy quality checker) From 9208fdaf5ddf2ced5cc9caf5349a82dd6bf315a2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 20:55:27 -0400 Subject: [PATCH 06/20] basic support Signed-off-by: Kyle Sayers --- src/compressed_tensors/__init__.py | 1 + .../quantization/lifecycle/forward.py | 5 +-- .../lifecycle/test_forward.py | 39 +++++++++++++++++-- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/__init__.py b/src/compressed_tensors/__init__.py index c892e81a..08b0dfb7 100644 --- a/src/compressed_tensors/__init__.py +++ b/src/compressed_tensors/__init__.py @@ -19,6 +19,7 @@ from .compressors import * from .config import * +from .logger import LoggerConfig, configure_logger, logger from .quantization import QuantizationConfig, QuantizationStatus from .utils import * from .version import * diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 2e539b07..e973f39b 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -264,8 +264,7 @@ def _process_quantization( ): output_dtype = dtype if dtype is not None else x.dtype - output = torch.zeros_like(x).to(output_dtype) - columns = output.shape[-1] + columns = x.size(-1) # TODO: make validation step for inputs @@ -323,7 +322,7 @@ def _process_quantization( global_scale=global_scale, ) - output = output.flatten(start_dim=-2) + output = output.flatten(-2, -1) output = output.to(output_dtype) if not is_column_order: diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index f3321cd4..09010af0 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -95,7 +95,7 @@ def test_forward_quantize( @pytest.mark.parametrize( - "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale", + "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale,batch_size", [ ( 4, @@ -106,6 +106,7 @@ def test_forward_quantize( torch.zeros((1,)), None, None, + None, ), ( 4, @@ -116,6 +117,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 4, @@ -126,6 +128,7 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, ), ( 8, @@ -136,6 +139,7 @@ def test_forward_quantize( torch.zeros((1,)), None, None, + None, ), ( 8, @@ -146,6 +150,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, 
+ None, ), ( 8, @@ -156,6 +161,7 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, ), ( 8, @@ -166,6 +172,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 8, @@ -176,17 +183,41 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, + ), + ( + 8, + "int", + QuantizationStrategy.GROUP, + 128, + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), + make_dummy_g_idx(1024, 128), + None, + 5, ), ], ) -def test_fake_quantize_2d( - num_bits, type, strategy, group_size, scale, zero_point, g_idx, global_scale +def test_fake_quantize( + num_bits, + type, + strategy, + group_size, + scale, + zero_point, + g_idx, + global_scale, + batch_size, ): args = QuantizationArgs( num_bits=num_bits, type=type, strategy=strategy, group_size=group_size ) - x = torch.rand((512, 1024)) + if batch_size is None: + x = torch.rand((512, 1024)) + else: + x = torch.rand((batch_size, 512, 1024)) + fake_quantize( x=x, scale=scale, From 6e926b1501d738f75cf69669ddcbe763991fc031 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 8 Sep 2025 21:10:36 -0400 Subject: [PATCH 07/20] fix merge Signed-off-by: Kyle Sayers --- src/compressed_tensors/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/__init__.py b/src/compressed_tensors/__init__.py index 08b0dfb7..c892e81a 100644 --- a/src/compressed_tensors/__init__.py +++ b/src/compressed_tensors/__init__.py @@ -19,7 +19,6 @@ from .compressors import * from .config import * -from .logger import LoggerConfig, configure_logger, logger from .quantization import QuantizationConfig, QuantizationStatus from .utils import * from .version import * From c959bf9b6231c986d4184cccb9a13fbbeba13efb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 9 Sep 2025 09:04:16 -0400 Subject: [PATCH 08/20] ungate group activation quant Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 505bdb2a..a7c3d590 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -66,18 +66,6 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): -<<<<<<< HEAD -======= - if ( - inputs.strategy == QuantizationStrategy.GROUP - and inputs.dynamic is True - ): - raise NotImplementedError( - "Static and local group-wise activation " - "quantization is not supported" - ) - ->>>>>>> db46c84 (satisfy quality checker) raise NotImplementedError( f"Using {inputs.strategy} strategy is not supported for " "activation quantization" From 86d504cf4b0ae83f6714d51715495bd268e3c9e2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 9 Sep 2025 13:02:09 -0400 Subject: [PATCH 09/20] refactor Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 3 +- .../quantization/lifecycle/initialize.py | 216 +++++++++--------- .../quantization/quant_args.py | 40 ++-- 3 files changed, 139 insertions(+), 120 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index faa48df2..63ddfc9d 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -221,7 +221,8 @@ def 
apply_quantization_status(model: Module, status: QuantizationStatus): model.apply( lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init + module, + force_zero_point=force_zero_point_init, ) ) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5350b4a2..57009f93 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -26,6 +26,7 @@ from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, + DynamicType, QuantizationArgs, QuantizationStrategy, ) @@ -73,10 +74,8 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ - # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: - # no scheme passed and layer not targeted for quantization - skip return if is_attention_module(module): @@ -84,38 +83,49 @@ def initialize_module_for_quantization( _initialize_attn_scales(module) else: - if scheme.input_activations is not None: - _initialize_scale_zero_point( - module, - "input", - scheme.input_activations, - force_zero_point=force_zero_point, + if not isinstance(module, torch.nn.Linear): + _LOGGER.warning(f"Attempting to quantize module of type {type(module)}") + + # use weight to determine observed shapes and dtype + if hasattr(module, "weight"): + weight = module.weight + assert isinstance(weight, torch.Tensor) + else: + # Note that a weight is required for both weight and activation + # quantization in order to know the dtype of activation scales + _LOGGER.warning( + f"module type {type(module)} targeted for quantization but " + f"has no attribute weight, skipping quantization for {type(module)}" ) + return + + if scheme.input_activations is not None: + base_name = "input" + args = scheme.input_activations + observed_shape = weight.shape[-1:] + observed_dtype = weight.dtype if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) + base_name = "weight" + args = scheme.weights + observed_shape = weight.shape + observed_dtype = weight.dtype if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations - ) + base_name = "output" + args = scheme.output_activations + observed_shape = weight.shape[:-1] + observed_dtype = weight.dtype + + if not is_kv_cache_quant_scheme(scheme): + _initialize_scale_zero_point( + module, + base_name, + args, + observed_shape=observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED @@ -138,18 +148,21 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, + observed_shape: torch.Size, + observed_dtype: 
torch.dtype, force_zero_point: bool = True, ): - if quantization_args.dynamic is True: - return + strategy = quantization_args.strategy + dynamic = quantization_args.dynamic + actorder = quantization_args.actorder + device = get_execution_device(module) # avoid performing intialization ops on cpu - # initialize on execution device to avoid performing quantized ops on cpu - device = get_execution_device(module) + # Skip all intialization for fully dynamic quantization + if dynamic is True: + return - # 1. Create global_scales for tensor_group - generates - # a per tensor scale - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: + # 0. Create global scale for tensor-group quantization + if strategy == QuantizationStrategy.TENSOR_GROUP: init_global_scale = Parameter( torch.empty(1, dtype=torch.float32, device=device), requires_grad=False, @@ -158,56 +171,54 @@ def _initialize_scale_zero_point( module, f"{base_name}_global_scale", init_global_scale ) - # 2. Infer expected scale/zero point shape - if quantization_args.strategy == QuantizationStrategy.TOKEN: - expected_shape = (1, 1) - else: - expected_shape = 1 - - if base_name == "weight" and weight_shape is not None: - if quantization_args.strategy == QuantizationStrategy.CHANNEL: - # (output_channels, 1) - only for weights - expected_shape = (weight_shape[0], 1) - elif quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - # GROUP/TENSOR_GROUP for both weights and activations - num_groups = math.ceil(weight_shape[1] / quantization_args.group_size) - expected_shape = (weight_shape[0], max(num_groups, 1)) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only - # for weights - if quantization_args.block_structure is None: - raise ValueError( - "Block quantization requires block_structure to be specified" - ) - block_height, block_width = quantization_args.block_structure - rows, cols = weight_shape[-2], weight_shape[-1] - num_rows_blocks = math.ceil(rows / block_height) - num_cols_blocks = math.ceil(cols / block_width) - - # Warn if dimensions don't divide evenly - if rows % block_height != 0 or cols % block_width != 0: - warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide" - f"evenly by block structure {quantization_args.block_structure}. " - f"Some blocks will be incomplete which may affect quantization" - "quality.", - UserWarning, - ) - - expected_shape = (num_rows_blocks, num_cols_blocks) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - warnings.warn( - f"BLOCK quantization not supported for {base_name} activations. " - f"Falling back to tensor-level quantization.", - UserWarning, - ) - expected_shape = 1 + # Skip scale/zp initialization for locally dynamic quantization + if dynamic == DynamicType.LOCAL: + return + + # 1. Infer expected scale/zp shape + if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN): + expected_shape = (1,) + + elif strategy == QuantizationStrategy.CHANNEL: + if len(observed_shape) < 1: + raise ValueError("Channel quant requires at least 1 observed dimension") + expected_shape = (observed_shape[-1], 1) + +<<<<<<< HEAD # 3. 
Identify quantization scale and zp dtype scale_dtype = module.weight.dtype +======= + elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + assert quantization_args.group_size is not None + if len(observed_shape) < 1: + raise ValueError("Group quant requires at least 1 observed dimension") + + group_size = quantization_args.group_size + num_groups = _strict_divide(observed_shape[-1], group_size, strategy) + expected_shape = (num_groups, group_size) + + # initialize activation ordering if applicable + if actorder == ActivationOrdering.GROUP: + init_g_idx = Parameter( + torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) + + elif strategy == QuantizationStrategy.BLOCK: + assert quantization_args.block_structure is not None + if len(observed_shape) < 2: + raise ValueError("Block quant requires at least 2 observed dimensions") + + block_structure = quantization_args.block_structure + num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy) + num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy) + expected_shape = (num_rows, num_cols) + + # 2. Identify quantization scale and zp dtype + scale_dtype = observed_dtype +>>>>>>> fde779c (refactor) if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype @@ -223,14 +234,12 @@ def _initialize_scale_zero_point( scale_dtype = torch.bfloat16 zp_dtype = quantization_args.pytorch_dtype() - # 4. Initializes empty scale, zero point, and g_idx parameters for the module - # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_scale", init_scale) + # 3. Initializes scale/zp for the module + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: init_zero_point = Parameter( @@ -239,16 +248,6 @@ def _initialize_scale_zero_point( ) register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) - # only grouped activation ordering has g_idx - if quantization_args.actorder == ActivationOrdering.GROUP: - g_idx_shape = (weight_shape[1],) - g_idx_dtype = torch.int - init_g_idx = Parameter( - torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) - def _initialize_attn_scales(module: Module) -> None: """Initlaize k_scale, v_scale for self_attn""" @@ -270,3 +269,16 @@ def _initialize_attn_scales(module: Module) -> None: requires_grad=False, ) register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) + + +def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int: + out = observed // divisor + if out * divisor != observed: + raise ValueError( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {observed} and group/block size {divisor}. 
" + "consider reducing the group/block size or ignoring modules with weights " + f"not divisible by {divisor}" + ) + + return out diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9e88353..c55ee5ef 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -262,6 +262,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": actorder = model.actorder dynamic = model.dynamic observer = model.observer + block_structure = model.block_structure # infer strategy if strategy is None: @@ -277,23 +278,28 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": "strategy='group' and group_size = -1 for 'channel'" ) - # validate strategy and group - if strategy == QuantizationStrategy.GROUP: - if group_size is None or group_size <= 0: - raise ValueError( - f"strategy {strategy} requires group_size to be " - "set to a positive value" - ) - if ( - group_size is not None - and group_size > 0 - and strategy - not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP) - ): - raise ValueError("group_size requires strategy to be set to 'group'") - - # validate activation ordering and strategy - if actorder is not None and strategy != QuantizationStrategy.GROUP: + # validate block strategy and structure + has_block_strategy = strategy == QuantizationStrategy.BLOCK + has_block_structure = block_structure is not None + if has_block_strategy != has_block_structure: + raise ValueError( + "Block strategy requires `block_structure`, and vice versa. " + f"Instead got ({strategy}, {block_structure})" + ) + + # validate group strategy + has_group_strategy = strategy in ( + QuantizationStrategy.GROUP, + QuantizationStrategy.TENSOR_GROUP, + ) + has_group_size = group_size is not None and group_size > 0 + has_actorder = actorder is not None + if has_group_strategy != has_group_size: + raise ValueError( + "Group strategies require `group_size`, and vice versa. 
" + f"Instead got ({strategy}, {group_size})" + ) + if has_actorder and not has_group_strategy: raise ValueError( "Must use group quantization strategy in order to apply " "activation ordering" From 3c91e5b0d8ce565f6358a07a9a40c2cf633f10e7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 11 Sep 2025 15:25:36 -0400 Subject: [PATCH 10/20] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 3 +- .../quantization/lifecycle/initialize.py | 5 -- src/compressed_tensors/transform/apply.py | 35 ------------ .../transform/factory/base.py | 38 ++++++++++++- .../transform/factory/hadamard.py | 4 +- .../transform/utils/hadamard.py | 7 +-- src/compressed_tensors/utils/helpers.py | 7 ++- tests/test_transform/conftest.py | 2 - .../factory/test_serialization.py | 54 ++----------------- 9 files changed, 51 insertions(+), 104 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 63ddfc9d..faa48df2 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -221,8 +221,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply( lambda module: initialize_module_for_quantization( - module, - force_zero_point=force_zero_point_init, + module, force_zero_point=force_zero_point_init ) ) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 57009f93..37c6f19a 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -185,10 +185,6 @@ def _initialize_scale_zero_point( expected_shape = (observed_shape[-1], 1) -<<<<<<< HEAD - # 3. Identify quantization scale and zp dtype - scale_dtype = module.weight.dtype -======= elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): assert quantization_args.group_size is not None if len(observed_shape) < 1: @@ -218,7 +214,6 @@ def _initialize_scale_zero_point( # 2. Identify quantization scale and zp dtype scale_dtype = observed_dtype ->>>>>>> fde779c (refactor) if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index 28d5e94f..e247e702 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict - import torch -from accelerate.utils import has_offloaded_params from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -37,35 +34,3 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): # attach config to model for compression/serialization setattr(model, TRANSFORM_CONFIG_NAME, config) - - # ensure that tied weight transforms can be serialized without aliases - # In the future, this could be done by transformers or model compressor - # which would make this more robust to changing dispatches after transforms - _tie_offloaded_tensors(model) - - -def _tie_offloaded_tensors(model: torch.nn.Module): - """ - When accelerate replaces tensors with meta tensors during offloading, the meta - tensors may not be identical, even if the offloaded values are identical. - - However, transformers can only serialize correctly if meta tensors are identical - (see transformers#39263). - - This function collects all meta tensors which have shared offloaded values and sets - those tensors to be identical so that they can be removed during serialization - - :param model: model potentially containing offloaded meta tensors to fix - """ - - # ensure that if a location shares an offloaded tensor pointers, that the - # meta tensor is also identical (assigned to the first instance of parameter) - ptr_to_meta: Dict[int, torch.nn.Parameter] = dict() - for module in model.modules(): - if has_offloaded_params(module): - for key, _ in module.named_parameters(recurse=False): - offloaded_ptr = module._hf_hook.weights_map[key].data_ptr() - - if offloaded_ptr not in ptr_to_meta: - ptr_to_meta[offloaded_ptr] = getattr(module, key) - setattr(module, key, ptr_to_meta[offloaded_ptr]) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 34d609e7..94e6b4a4 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,7 +13,8 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import List, Optional +from collections import defaultdict +from typing import List, Optional, Set, Tuple import torch import torch.nn.utils.parametrize as P @@ -56,6 +57,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non self.name = name self.scheme = scheme self.generator = torch.Generator() + self.transforms = list() if seed is not None: self.generator.manual_seed(seed) @@ -99,6 +101,8 @@ def apply_to_model(self, model: Module, use_tqdm=True): for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): self._apply_to_module(module, arg) + self._update_tied_weights() + def _apply_to_module(self, module: Module, args: TransformArgs): """ Create transforms and apply them to the module @@ -116,6 +120,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) + self.transforms.append(transform) register_offload_module(module, transform_name, transform) # register input transformation hook @@ -160,6 +165,31 @@ def output_hook(_, _input, output): else: raise NotImplementedError() + def _update_tied_weights(self): + """ + Populate the `_dynamic_tied_weights_keys` attribute of transforms, + which is used by transformers to detect and remove shared pointers + during saving + """ + # map from data_ptrs to keys + ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list) + for transform in self.transforms: + for name, param in transform.named_parameters(recurse=False): + # NOTE: previously asserted that parent._hf_hook.place_submodules=False + if has_offloaded_params(transform): + param = transform._hf_hook.weights_map[name] + ptr_to_keys[param.data_ptr()].append((transform, name)) + + # populate `_dynamic_tied_weights_keys` if there is more than one key + # and ensure that they share tensors + for shared_keys in ptr_to_keys.values(): + if len(shared_keys) > 1: + tensor = getattr(shared_keys[0][0], shared_keys[0][1]) + + for transform, name in shared_keys: + transform._dynamic_tied_weights_keys.add(name) + setattr(transform, name, tensor) + class TransformBase(InternalModule, ABC): """ @@ -168,7 +198,11 @@ class TransformBase(InternalModule, ABC): args: TransformArgs weight: Parameter - _dynamic_tied_weights_keys: List[str] = ["weight"] + _dynamic_tied_weights_keys: Set[str] + + def __init__(self): + super().__init__() + self._dynamic_tied_weights_keys = set() @abstractmethod def forward(self, value: Tensor) -> Tensor: diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 24cb50fa..2cd7be0d 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +from typing import Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -84,8 +84,6 @@ def _create_permutation(self, weight: Parameter) -> Parameter: class HadamardTransform(TransformBase): - _dynamic_tied_weights_keys: List[str] = ["weight", "perm"] - def __init__( self, weight: Parameter, diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index c8144ae2..7d361e59 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -115,16 +115,13 @@ def _fetch_hadamard_divisor( than forcing callers to manage the file open context :param n: size of known hadamard matrix - :param dtype: data type to move fetched hadamard to - :param device: device to move fetched hadamard to :return: a known hadamard matrix of size `n` if one exists, else None """ - open_device = torch.device("cpu") if device.type == "meta" else device - with safe_open(file_path, framework="pt", device=str(open_device)) as file: + with safe_open(file_path, framework="pt", device=str(device)) as file: divisors = sorted((int(key) for key in file.keys()), reverse=True) for divisor in divisors: if n % divisor == 0 and is_pow2(n // divisor): - return file.get_tensor(str(divisor)).to(dtype=dtype, device=device) + return file.get_tensor(str(divisor)).to(dtype=dtype) return None diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 5324cd28..38a177f7 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -13,12 +13,17 @@ # limitations under the License. import contextlib +import warnings from functools import wraps from types import MappingProxyType from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar import numpy import torch +<<<<<<< HEAD +======= +from frozendict import frozendict +>>>>>>> 6672617 (reduce diff) from transformers import AutoConfig @@ -194,7 +199,7 @@ def decorator(func: T) -> T: @wraps(func) def wrapped(*args, **kwargs): - logger.bind(log_once=True).warning(message) + warnings.warn(message, DeprecationWarning, stacklevel=2) return func(*args, **kwargs) return wrapped diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index 824c06bd..a0188c42 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -19,8 +19,6 @@ class TransformableModel(PreTrainedModel): - config_class = PretrainedConfig - def __init__(self, *sizes): super().__init__(config=PretrainedConfig()) self.fcs = torch.nn.ModuleList( diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py index 15fa240b..a688c2cf 100644 --- a/tests/test_transform/factory/test_serialization.py +++ b/tests/test_transform/factory/test_serialization.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest import torch from compressed_tensors.transform import ( @@ -22,9 +20,7 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch -from safetensors import safe_open from tests.testing_utils import requires_accelerate, requires_gpu -from transformers import AutoModelForCausalLM, AutoTokenizer @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @@ -42,57 +38,17 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False): apply_transform_config(model, config) # save model - model_path = os.path.join(tmp_path, "test_model_path") - model.save_pretrained(model_path) - - # check that saved values match model values - # note that shared weights are only serialized once - safetensors_path = os.path.join(model_path, "model.safetensors") - with safe_open(safetensors_path, framework="pt", device="cpu") as file: - saved_keys = set(file.keys()) - assert { - "fcs.0.weight", - "fcs.1.weight", - "fcs.2.weight", - "fcs.3.weight", - "fcs.4.weight", - } <= saved_keys - for key in saved_keys: - param = model.get_parameter(key) - saved_param = file.get_tensor(key) + model.save_pretrained(tmp_path) - if param.device.type != "meta": # skip testing values in offload case - assert torch.equal(param, saved_param) + # TODO: reload model +@pytest.mark.skip(reason="Requires changes in upstream transformers") +# https://github.com/huggingface/transformers/pull/39280 +# https://github.com/huggingface/transformers/pull/39263 @requires_gpu @requires_accelerate() @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @pytest.mark.parametrize("randomize", (True, False)) def test_serialization_offload(type, randomize, model_apply, tmp_path): test_serialization(type, randomize, model_apply, tmp_path, offload=True) - - -@pytest.mark.skip("Requires transformers#40673") -@requires_gpu -@pytest.mark.parametrize( - "model_stub,exp_perplexity", - [ - ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0), - ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0), - ], -) -def test_load_perplexity(model_stub, exp_perplexity): - model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda") - tokenizer = AutoTokenizer.from_pretrained(model_stub) - - prompt = "The capital of France is Paris, the capital of Germany is Berlin" - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {key: value.to(model.device) for key, value in inputs.items()} - labels = inputs["input_ids"] - - with torch.no_grad(): - outputs = model(**inputs, labels=labels) - - perplexity = torch.exp(outputs.loss) - assert perplexity <= exp_perplexity From e75195bbc06938ff7a6d1663f6cc5510b3ffdedc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 12 Sep 2025 12:10:03 -0400 Subject: [PATCH 11/20] activations have one row Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/forward.py | 21 ++++------- .../quantization/lifecycle/initialize.py | 35 +++++++------------ .../quantization/quant_args.py | 9 ++--- .../quantization/utils/helpers.py | 17 +++++++++ 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index e973f39b..850d8f1e 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -280,17 +280,8 @@ def _process_quantization( f"by the given group_size {group_size}" ) - # support column-order (default) quantization as well as other 
orderings - # such as activation ordering. Below checks if g_idx has been initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - num_groups = int(ceil(columns / group_size)) - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - + # permute groups + if g_idx is not None: perm = torch.argsort(g_idx) x = x.index_select(-1, perm) @@ -299,6 +290,8 @@ def _process_quantization( ceil(x.shape[-1] / group_size), group_size, ) + # we should potentially be folding reshaped_dims[0] into x.shape[-2] + # in order to allow for multi-headed activations x = x.unflatten(-1, reshaped_dims) if do_quantize: @@ -325,9 +318,9 @@ def _process_quantization( output = output.flatten(-2, -1) output = output.to(output_dtype) - if not is_column_order: - inv_perm = torch.argsort(perm) - output = output.index_select(-1, inv_perm) + # unpermute groups + if g_idx is not None: + x = x.index_select(-1, g_idx) else: # covers channel, token and tensor strategies if do_quantize: diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 37c6f19a..cfbb42ce 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,10 +14,8 @@ import logging -import math -import warnings from enum import Enum -from typing import Optional +from typing import Optional, Tuple import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -32,7 +30,11 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme +from compressed_tensors.quantization.utils import ( + is_fp4, + is_kv_cache_quant_scheme, + strict_divide, +) from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, @@ -102,7 +104,7 @@ def initialize_module_for_quantization( if scheme.input_activations is not None: base_name = "input" args = scheme.input_activations - observed_shape = weight.shape[-1:] + observed_shape = (1, weight.size(-1)) observed_dtype = weight.dtype if scheme.weights is not None: @@ -148,7 +150,7 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - observed_shape: torch.Size, + observed_shape: Tuple[int], observed_dtype: torch.dtype, force_zero_point: bool = True, ): @@ -191,8 +193,8 @@ def _initialize_scale_zero_point( raise ValueError("Group quant requires at least 1 observed dimension") group_size = quantization_args.group_size - num_groups = _strict_divide(observed_shape[-1], group_size, strategy) - expected_shape = (num_groups, group_size) + num_groups = strict_divide(observed_shape[-1], group_size, strategy) + expected_shape = (*observed_shape[:-1], num_groups) # initialize activation ordering if applicable if actorder == ActivationOrdering.GROUP: @@ -208,8 +210,8 @@ def _initialize_scale_zero_point( raise ValueError("Block quant requires at least 2 observed dimensions") block_structure = quantization_args.block_structure - num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy) - num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy) + num_rows = strict_divide(observed_shape[-2], block_structure[-2], 
strategy) + num_cols = strict_divide(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) # 2. Identify quantization scale and zp dtype @@ -264,16 +266,3 @@ def _initialize_attn_scales(module: Module) -> None: requires_grad=False, ) register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) - - -def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int: - out = observed // divisor - if out * divisor != observed: - raise ValueError( - f"{strategy} quantization strategy requires strict division of " - f"weight/activation size {observed} and group/block size {divisor}. " - "consider reducing the group/block size or ignoring modules with weights " - f"not divisible by {divisor}" - ) - - return out diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index c55ee5ef..1ee29487 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -283,8 +283,9 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": has_block_structure = block_structure is not None if has_block_strategy != has_block_structure: raise ValueError( - "Block strategy requires `block_structure`, and vice versa. " - f"Instead got ({strategy}, {block_structure})" + "`strategy = block` requires `block_structure != None`, and vice versa." + f" Instead got `strategy={strategy}` and " + f"`block_structure={block_structure}`" ) # validate group strategy @@ -296,8 +297,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": has_actorder = actorder is not None if has_group_strategy != has_group_size: raise ValueError( - "Group strategies require `group_size`, and vice versa. " - f"Instead got ({strategy}, {group_size})" + "`strategy = group` requires `group_size != None`, and vice versa. " + f"Instead got `strategy={strategy}` and `group_size={group_size}`" ) if has_actorder and not has_group_strategy: raise ValueError( diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 73d54519..8acf69a8 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -48,6 +48,7 @@ "calculate_qparams", "generate_gparam", "is_fp4", + "strict_divide", ] # target the self_attn layer @@ -477,3 +478,19 @@ def generate_gparam( max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) global_scale = scale_data.max * quant_data.max / max_val_pos return global_scale.to(dtype).reshape([1]) + + +def strict_divide( + observed: int, divisor: int, strategy: Optional[QuantizationStrategy] = None +) -> int: + out = observed // divisor + if out * divisor != observed: + if strategy is not None: + raise ValueError( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {observed} and group/block size {divisor}. 
" + "consider reducing the group/block size or ignoring modules with " + f"weights not divisible by {divisor}" + ) + + return out From 3de48fe7efda16556270a1258ceead482a0dedf1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 08:23:45 -0400 Subject: [PATCH 12/20] cleanup, logging Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 4 ++-- .../quantization/quant_scheme.py | 23 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index cfbb42ce..37ccb7e8 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -104,7 +104,7 @@ def initialize_module_for_quantization( if scheme.input_activations is not None: base_name = "input" args = scheme.input_activations - observed_shape = (1, weight.size(-1)) + observed_shape = (1, weight.shape[-1]) observed_dtype = weight.dtype if scheme.weights is not None: @@ -185,7 +185,7 @@ def _initialize_scale_zero_point( if len(observed_shape) < 1: raise ValueError("Channel quant requires at least 1 observed dimension") - expected_shape = (observed_shape[-1], 1) + expected_shape = (observed_shape[-2], 1) elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): assert quantization_args.group_size is not None diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a7c3d590..1a036e1c 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -24,6 +24,7 @@ QuantizationType, ) from pydantic import BaseModel, ConfigDict, model_validator +from loguru import logger __all__ = [ @@ -60,15 +61,19 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": format = model.format if inputs is not None: - if inputs.strategy not in ( - QuantizationStrategy.TOKEN, - QuantizationStrategy.TENSOR, - QuantizationStrategy.GROUP, - QuantizationStrategy.TENSOR_GROUP, - ): - raise NotImplementedError( - f"Using {inputs.strategy} strategy is not supported for " - "activation quantization" + if inputs.strategy == QuantizationStrategy.CHANNEL: + raise ValueError( + "Channel-wise activation quantization is equivalent to " + "tensor/token-wise activation quantization, please use one of " + "those. If you mean to quantize each activation value " + "individually, please use group quantization with `group_size = 1`" + ) + + if inputs.strategy == QuantizationStrategy.BLOCK: + raise ValueError( + "Block-wise activation quantization is not supported. 
If you mean " + "to quantize each activation value individually, please use group " + "quantization with `group_size = 1`" ) if inputs.actorder is not None: From 3fe08b82e86e5fd87a216ca048c82e64a7ce0f45 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 11:42:19 -0400 Subject: [PATCH 13/20] fix merge Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/apply.py | 35 ++++++++++++ .../transform/factory/base.py | 38 +------------ .../transform/factory/hadamard.py | 4 +- .../transform/utils/hadamard.py | 7 ++- src/compressed_tensors/utils/helpers.py | 4 -- tests/test_transform/conftest.py | 2 + .../factory/test_serialization.py | 54 +++++++++++++++++-- 7 files changed, 96 insertions(+), 48 deletions(-) diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py index e247e702..28d5e94f 100644 --- a/src/compressed_tensors/transform/apply.py +++ b/src/compressed_tensors/transform/apply.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict + import torch +from accelerate.utils import has_offloaded_params from compressed_tensors import TRANSFORM_CONFIG_NAME from compressed_tensors.transform import TransformConfig, TransformFactory @@ -34,3 +37,35 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig): # attach config to model for compression/serialization setattr(model, TRANSFORM_CONFIG_NAME, config) + + # ensure that tied weight transforms can be serialized without aliases + # In the future, this could be done by transformers or model compressor + # which would make this more robust to changing dispatches after transforms + _tie_offloaded_tensors(model) + + +def _tie_offloaded_tensors(model: torch.nn.Module): + """ + When accelerate replaces tensors with meta tensors during offloading, the meta + tensors may not be identical, even if the offloaded values are identical. + + However, transformers can only serialize correctly if meta tensors are identical + (see transformers#39263). + + This function collects all meta tensors which have shared offloaded values and sets + those tensors to be identical so that they can be removed during serialization + + :param model: model potentially containing offloaded meta tensors to fix + """ + + # ensure that if a location shares an offloaded tensor pointers, that the + # meta tensor is also identical (assigned to the first instance of parameter) + ptr_to_meta: Dict[int, torch.nn.Parameter] = dict() + for module in model.modules(): + if has_offloaded_params(module): + for key, _ in module.named_parameters(recurse=False): + offloaded_ptr = module._hf_hook.weights_map[key].data_ptr() + + if offloaded_ptr not in ptr_to_meta: + ptr_to_meta[offloaded_ptr] = getattr(module, key) + setattr(module, key, ptr_to_meta[offloaded_ptr]) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 94e6b4a4..34d609e7 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,8 +13,7 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from collections import defaultdict -from typing import List, Optional, Set, Tuple +from typing import List, Optional import torch import torch.nn.utils.parametrize as P @@ -57,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non self.name = name self.scheme = scheme self.generator = torch.Generator() - self.transforms = list() if seed is not None: self.generator.manual_seed(seed) @@ -101,8 +99,6 @@ def apply_to_model(self, model: Module, use_tqdm=True): for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): self._apply_to_module(module, arg) - self._update_tied_weights() - def _apply_to_module(self, module: Module, args: TransformArgs): """ Create transforms and apply them to the module @@ -120,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs): # create transform as submodule transform_name = f"{self.name}_{args.location}" transform = self.create_transform(module, args) - self.transforms.append(transform) register_offload_module(module, transform_name, transform) # register input transformation hook @@ -165,31 +160,6 @@ def output_hook(_, _input, output): else: raise NotImplementedError() - def _update_tied_weights(self): - """ - Populate the `_dynamic_tied_weights_keys` attribute of transforms, - which is used by transformers to detect and remove shared pointers - during saving - """ - # map from data_ptrs to keys - ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list) - for transform in self.transforms: - for name, param in transform.named_parameters(recurse=False): - # NOTE: previously asserted that parent._hf_hook.place_submodules=False - if has_offloaded_params(transform): - param = transform._hf_hook.weights_map[name] - ptr_to_keys[param.data_ptr()].append((transform, name)) - - # populate `_dynamic_tied_weights_keys` if there is more than one key - # and ensure that they share tensors - for shared_keys in ptr_to_keys.values(): - if len(shared_keys) > 1: - tensor = getattr(shared_keys[0][0], shared_keys[0][1]) - - for transform, name in shared_keys: - transform._dynamic_tied_weights_keys.add(name) - setattr(transform, name, tensor) - class TransformBase(InternalModule, ABC): """ @@ -198,11 +168,7 @@ class TransformBase(InternalModule, ABC): args: TransformArgs weight: Parameter - _dynamic_tied_weights_keys: Set[str] - - def __init__(self): - super().__init__() - self._dynamic_tied_weights_keys = set() + _dynamic_tied_weights_keys: List[str] = ["weight"] @abstractmethod def forward(self, value: Tensor) -> Tensor: diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 2cd7be0d..24cb50fa 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import List, Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -84,6 +84,8 @@ def _create_permutation(self, weight: Parameter) -> Parameter: class HadamardTransform(TransformBase): + _dynamic_tied_weights_keys: List[str] = ["weight", "perm"] + def __init__( self, weight: Parameter, diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 7d361e59..c8144ae2 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -115,13 +115,16 @@ def _fetch_hadamard_divisor( than forcing callers to manage the file open context :param n: size of known hadamard matrix + :param dtype: data type to move fetched hadamard to + :param device: device to move fetched hadamard to :return: a known hadamard matrix of size `n` if one exists, else None """ - with safe_open(file_path, framework="pt", device=str(device)) as file: + open_device = torch.device("cpu") if device.type == "meta" else device + with safe_open(file_path, framework="pt", device=str(open_device)) as file: divisors = sorted((int(key) for key in file.keys()), reverse=True) for divisor in divisors: if n % divisor == 0 and is_pow2(n // divisor): - return file.get_tensor(str(divisor)).to(dtype=dtype) + return file.get_tensor(str(divisor)).to(dtype=dtype, device=device) return None diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 38a177f7..44adc7ef 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -20,10 +20,6 @@ import numpy import torch -<<<<<<< HEAD -======= -from frozendict import frozendict ->>>>>>> 6672617 (reduce diff) from transformers import AutoConfig diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index a0188c42..824c06bd 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -19,6 +19,8 @@ class TransformableModel(PreTrainedModel): + config_class = PretrainedConfig + def __init__(self, *sizes): super().__init__(config=PretrainedConfig()) self.fcs = torch.nn.ModuleList( diff --git a/tests/test_transform/factory/test_serialization.py b/tests/test_transform/factory/test_serialization.py index a688c2cf..15fa240b 100644 --- a/tests/test_transform/factory/test_serialization.py +++ b/tests/test_transform/factory/test_serialization.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os + import pytest import torch from compressed_tensors.transform import ( @@ -20,7 +22,9 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch +from safetensors import safe_open from tests.testing_utils import requires_accelerate, requires_gpu +from transformers import AutoModelForCausalLM, AutoTokenizer @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @@ -38,17 +42,57 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False): apply_transform_config(model, config) # save model - model.save_pretrained(tmp_path) + model_path = os.path.join(tmp_path, "test_model_path") + model.save_pretrained(model_path) + + # check that saved values match model values + # note that shared weights are only serialized once + safetensors_path = os.path.join(model_path, "model.safetensors") + with safe_open(safetensors_path, framework="pt", device="cpu") as file: + saved_keys = set(file.keys()) + assert { + "fcs.0.weight", + "fcs.1.weight", + "fcs.2.weight", + "fcs.3.weight", + "fcs.4.weight", + } <= saved_keys + for key in saved_keys: + param = model.get_parameter(key) + saved_param = file.get_tensor(key) - # TODO: reload model + if param.device.type != "meta": # skip testing values in offload case + assert torch.equal(param, saved_param) -@pytest.mark.skip(reason="Requires changes in upstream transformers") -# https://github.com/huggingface/transformers/pull/39280 -# https://github.com/huggingface/transformers/pull/39263 @requires_gpu @requires_accelerate() @pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) @pytest.mark.parametrize("randomize", (True, False)) def test_serialization_offload(type, randomize, model_apply, tmp_path): test_serialization(type, randomize, model_apply, tmp_path, offload=True) + + +@pytest.mark.skip("Requires transformers#40673") +@requires_gpu +@pytest.mark.parametrize( + "model_stub,exp_perplexity", + [ + ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0), + ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0), + ], +) +def test_load_perplexity(model_stub, exp_perplexity): + model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(model_stub) + + prompt = "The capital of France is Paris, the capital of Germany is Berlin" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {key: value.to(model.device) for key, value in inputs.items()} + labels = inputs["input_ids"] + + with torch.no_grad(): + outputs = model(**inputs, labels=labels) + + perplexity = torch.exp(outputs.loss) + assert perplexity <= exp_perplexity From 6e6524a8e8ce601411494183851002f29cff90b3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 05:54:30 -0400 Subject: [PATCH 14/20] fix typo Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 850d8f1e..83bc7fc2 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -320,7 +320,7 @@ def _process_quantization( # unpermute groups if g_idx is not None: - x = x.index_select(-1, g_idx) + output = output.index_select(-1, g_idx) else: # covers channel, token and tensor strategies if do_quantize: From 6c2c7adcc907027e33b15725322295866e7d413b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: 
Thu, 18 Sep 2025 06:08:33 -0400 Subject: [PATCH 15/20] fix initialize Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 43 ++++++++++--------- .../quantization/quant_scheme.py | 2 +- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 37ccb7e8..1a472a4d 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -102,30 +102,33 @@ def initialize_module_for_quantization( return if scheme.input_activations is not None: - base_name = "input" - args = scheme.input_activations - observed_shape = (1, weight.shape[-1]) - observed_dtype = weight.dtype + _initialize_scale_zero_point( + module, + "input", + scheme.input_activations, + observed_shape=(1, weight.shape[-1]), + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) if scheme.weights is not None: - base_name = "weight" - args = scheme.weights - observed_shape = weight.shape - observed_dtype = weight.dtype - - if scheme.output_activations is not None: - base_name = "output" - args = scheme.output_activations - observed_shape = weight.shape[:-1] - observed_dtype = weight.dtype - - if not is_kv_cache_quant_scheme(scheme): _initialize_scale_zero_point( module, - base_name, - args, - observed_shape=observed_shape, - observed_dtype=observed_dtype, + "weight", + scheme.weights, + observed_shape=weight.shape, + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) + + output_is_kv_cache = is_kv_cache_quant_scheme(scheme) + if scheme.output_activations is not None and not output_is_kv_cache: + _initialize_scale_zero_point( + module, + "output", + scheme.output_activations, + observed_shape=weight.shape[:-1], + observed_dtype=weight.dtype, force_zero_point=force_zero_point, ) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 1a036e1c..8f8fb81c 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -23,8 +23,8 @@ QuantizationStrategy, QuantizationType, ) -from pydantic import BaseModel, ConfigDict, model_validator from loguru import logger +from pydantic import BaseModel, ConfigDict, model_validator __all__ = [ From 42ec9eea405cd9abf378109f234189cde8d6478c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 07:59:54 -0400 Subject: [PATCH 16/20] rename strategy_cdiv Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 8 ++++---- .../quantization/utils/helpers.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 1a472a4d..d6a08648 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -33,7 +33,7 @@ from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, - strict_divide, + strategy_cdiv, ) from compressed_tensors.utils import ( disable_hf_hook, @@ -196,7 +196,7 @@ def _initialize_scale_zero_point( raise ValueError("Group quant requires at least 1 observed dimension") group_size = quantization_args.group_size - num_groups = strict_divide(observed_shape[-1], group_size, strategy) + num_groups = 
strategy_cdiv(observed_shape[-1], group_size, strategy) expected_shape = (*observed_shape[:-1], num_groups) # initialize activation ordering if applicable @@ -213,8 +213,8 @@ def _initialize_scale_zero_point( raise ValueError("Block quant requires at least 2 observed dimensions") block_structure = quantization_args.block_structure - num_rows = strict_divide(observed_shape[-2], block_structure[-2], strategy) - num_cols = strict_divide(observed_shape[-1], block_structure[-1], strategy) + num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy) + num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) # 2. Identify quantization scale and zp dtype diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 8acf69a8..91b2c95b 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -30,6 +30,8 @@ from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module +from loguru import logger + __all__ = [ "infer_quantization_status", @@ -48,7 +50,7 @@ "calculate_qparams", "generate_gparam", "is_fp4", - "strict_divide", + "strategy_cdiv", ] # target the self_attn layer @@ -480,17 +482,17 @@ def generate_gparam( return global_scale.to(dtype).reshape([1]) -def strict_divide( - observed: int, divisor: int, strategy: Optional[QuantizationStrategy] = None +def strategy_cdiv( + value: int, divisor: int, strategy: Optional[QuantizationStrategy] = None ) -> int: - out = observed // divisor - if out * divisor != observed: + dividend = math.ceil(value / divisor) + if dividend * divisor != value: if strategy is not None: - raise ValueError( + logger.bind(log_once=True).warning( f"{strategy} quantization strategy requires strict division of " - f"weight/activation size {observed} and group/block size {divisor}. " + f"weight/activation size {value} and group/block size {divisor}. 
" "consider reducing the group/block size or ignoring modules with " f"weights not divisible by {divisor}" ) - return out + return dividend From ed0cc9392a70e8fc6ac731830405ebd4251389c3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 08:12:57 -0400 Subject: [PATCH 17/20] fix test Signed-off-by: Kyle Sayers --- .../quantization/utils/helpers.py | 3 +- .../test_utils/test_helpers.py | 38 ++++++------------- 2 files changed, 12 insertions(+), 29 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 91b2c95b..cfa85319 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -27,11 +27,10 @@ ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.utils import deprecated +from loguru import logger from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module -from loguru import logger - __all__ = [ "infer_quantization_status", diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index b9f9754c..d97e98cb 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -28,43 +28,27 @@ @pytest.mark.parametrize( - "keepdims,strategy,exp_shape", + "keepdims,qargs,exp_shape", [ + (False, QuantizationArgs(strategy="tensor"), torch.Size([1])), + (True, QuantizationArgs(strategy="channel"), torch.Size([1, 1])), + (True, QuantizationArgs(strategy="group", group_size=2), torch.Size([1, 1])), ( False, - QuantizationStrategy.TENSOR, - torch.Size( - [ - 1, - ] - ), + QuantizationArgs(strategy="block", block_structure=[1, 1]), + torch.Size([1]), ), - (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), - (True, QuantizationStrategy.GROUP, torch.Size([1, 1])), - ( - False, - QuantizationStrategy.BLOCK, - torch.Size( - [ - 1, - ] - ), - ), - (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), + (True, QuantizationArgs(strategy="token"), torch.Size([1, 1])), ], ) -def test_calculate_qparams(keepdims, strategy, exp_shape): +def test_calculate_qparams(keepdims, qargs, exp_shape): value = torch.randn(14, 5) min_val = torch.amin(value, dim=tuple(), keepdims=keepdims) max_val = torch.amax(value, dim=tuple(), keepdims=keepdims) - if strategy == QuantizationStrategy.GROUP: - args = QuantizationArgs(strategy=strategy, group_size=2) - else: - args = QuantizationArgs(strategy=strategy) - scale, zp = calculate_qparams(min_val, max_val, args) - assert scale.shape == exp_shape - assert zp.shape == exp_shape + scale, zp = calculate_qparams(min_val, max_val, qargs) + assert scale.shape == exp_shape + assert zp.shape == exp_shape def test_fused_global_scales(): From c415f0f2e6917d323a6f0f3c4f4b35f3046dedce Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 08:16:15 -0400 Subject: [PATCH 18/20] fix token shape Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index d6a08648..2e250ce9 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -181,9 +181,12 @@ def _initialize_scale_zero_point( return # 1. 
Infer expected scale/zp shape - if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN): + if strategy == QuantizationStrategy.TENSOR: expected_shape = (1,) + elif strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + elif strategy == QuantizationStrategy.CHANNEL: if len(observed_shape) < 1: raise ValueError("Channel quant requires at least 1 observed dimension") @@ -217,6 +220,9 @@ def _initialize_scale_zero_point( num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) + else: + assert False, f"Unknown strategy {strategy}" + # 2. Identify quantization scale and zp dtype scale_dtype = observed_dtype From 954bb1cd7e34f460533ac040cfdb49746f755696 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 12:57:14 -0400 Subject: [PATCH 19/20] simplify Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/forward.py | 56 +++++++------------ .../quantization/utils/helpers.py | 23 +++++--- 2 files changed, 34 insertions(+), 45 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 83bc7fc2..0a031dd2 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -28,6 +28,7 @@ from compressed_tensors.quantization.utils import ( calculate_range, compute_dynamic_scales_and_zp, + strategy_cdiv, ) from torch.nn import Module @@ -257,45 +258,25 @@ def _process_quantization( global_scale=global_scale, ) # restore original shape - output = x_blocks.transpose(1, 2).reshape(original_shape) + x = x_blocks.transpose(1, 2).reshape(original_shape) elif args.strategy in ( QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): - output_dtype = dtype if dtype is not None else x.dtype - columns = x.size(-1) - - # TODO: make validation step for inputs - - while scale.ndim < 2: - # pad scale and zero point dims for slicing - scale = scale.unsqueeze(1) - zero_point = zero_point.unsqueeze(1) if zero_point is not None else None - if columns >= group_size: - if columns % group_size != 0: - raise ValueError( - "tensor column shape must be divisble " - f"by the given group_size {group_size}" - ) - - # permute groups + # activation ordering if g_idx is not None: perm = torch.argsort(g_idx) x = x.index_select(-1, perm) - # Maintain all dimensions except the last dim, which is divided by group_size - reshaped_dims = ( - ceil(x.shape[-1] / group_size), - group_size, - ) - # we should potentially be folding reshaped_dims[0] into x.shape[-2] - # in order to allow for multi-headed activations + # reshape into groups + num_groups = strategy_cdiv(x.size(-1), group_size, args.strategy, strict=True) + reshaped_dims = (num_groups, group_size) x = x.unflatten(-1, reshaped_dims) if do_quantize: - output = _quantize( + x = _quantize( x=x, scale=scale.unsqueeze(-1), zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, @@ -307,24 +288,25 @@ def _process_quantization( ) if do_dequantize: - input = output if do_quantize else x - output = _dequantize( - x_q=input, + x = _dequantize( + x_q=x, scale=scale.unsqueeze(-1), zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, global_scale=global_scale, ) - output = output.flatten(-2, -1) - output = output.to(output_dtype) + # undo reshape into groups + x = x.flatten(-2, -1) + x = x.to(output_dtype) - # unpermute groups + # undo activation ordering if g_idx is not None: - output = 
output.index_select(-1, g_idx) + inv_perm = torch.argsort(perm) + x = x.index_select(-1, inv_perm) else: # covers channel, token and tensor strategies if do_quantize: - output = _quantize( + x = _quantize( x=x, scale=scale, zero_point=zero_point, @@ -335,14 +317,14 @@ def _process_quantization( global_scale=global_scale, ) if do_dequantize: - output = _dequantize( - output if do_quantize else x, + x = _dequantize( + x_q=x, scale=scale, zero_point=zero_point, global_scale=global_scale, ) - return output + return x def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme): diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index cfa85319..4821f51c 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -482,16 +482,23 @@ def generate_gparam( def strategy_cdiv( - value: int, divisor: int, strategy: Optional[QuantizationStrategy] = None + value: int, + divisor: int, + strategy: Optional[QuantizationStrategy], + strict: bool = False, ) -> int: dividend = math.ceil(value / divisor) if dividend * divisor != value: - if strategy is not None: - logger.bind(log_once=True).warning( - f"{strategy} quantization strategy requires strict division of " - f"weight/activation size {value} and group/block size {divisor}. " - "consider reducing the group/block size or ignoring modules with " - f"weights not divisible by {divisor}" - ) + message = ( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {value} and group/block size {divisor}. " + "consider reducing the group/block size or ignoring modules with " + f"weights not divisible by {divisor}" + ) + if strict: + raise ValueError(message) + + else: + logger.bind(log_once=True).warning(message) return dividend From 4718d79f01c945affdc8119e4a7e9dbc459ecb17 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 13:58:12 -0400 Subject: [PATCH 20/20] fix style Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/forward.py | 1 - src/compressed_tensors/quantization/quant_scheme.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 0a031dd2..f6cc8255 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -13,7 +13,6 @@ # limitations under the License. from functools import wraps -from math import ceil from typing import Optional import torch diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 8f8fb81c..5aeb9f7f 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -23,7 +23,6 @@ QuantizationStrategy, QuantizationType, ) -from loguru import logger from pydantic import BaseModel, ConfigDict, model_validator
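
Note on the scale shapes targeted by PATCH 15, 16, and 18: after the refactor, _initialize_scale_zero_point derives every expected scale/zero-point shape from a ceiling division of the observed dimension by the group or block size. A standalone sketch of that arithmetic follows; it mirrors the strategy_cdiv helper from the diff but is illustrative only, with the sizes (512 x 1024 weight, group size 128, 128 x 128 blocks) chosen for the example rather than taken from the series.

import math

def cdiv(value: int, divisor: int) -> int:
    # ceiling division, as strategy_cdiv computes before warning or raising
    # when the dimension is not evenly divisible
    return math.ceil(value / divisor)

out_features, in_features, group_size = 512, 1024, 128

# group / tensor-group strategies: one scale per group along the input dim
group_scale_shape = (out_features, cdiv(in_features, group_size))      # (512, 8)

# block strategy: one scale per (rows, cols) block of the weight
block_structure = (128, 128)
block_scale_shape = (
    cdiv(out_features, block_structure[0]),
    cdiv(in_features, block_structure[1]),
)                                                                       # (4, 8)

# tensor strategy keeps a single scale; token scales are placeholders
tensor_scale_shape = (1,)
token_scale_shape = (1, 1)

print(group_scale_shape, block_scale_shape, tensor_scale_shape, token_scale_shape)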
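
Note on the PATCH 19 rewrite of _process_quantization: the simplified group branch leans on two identities, namely that unflatten/flatten is a lossless reshape into per-group slices, and that a gather by perm = argsort(g_idx) is undone by a gather with argsort(perm). A minimal self-contained check of both, illustrative only and not part of the series:

import torch

x = torch.randn(4, 16)
group_size = 4
num_groups = x.size(-1) // group_size

# grouping for per-group scales round-trips exactly
grouped = x.unflatten(-1, (num_groups, group_size))   # shape (4, 4, 4)
assert torch.equal(grouped.flatten(-2, -1), x)

# activation ordering: index_select by perm = argsort(g_idx) is inverted
# by index_select with inv_perm = argsort(perm)
g_idx = torch.randperm(x.size(-1))
perm = torch.argsort(g_idx)
permuted = x.index_select(-1, perm)
inv_perm = torch.argsort(perm)
assert torch.equal(permuted.index_select(-1, inv_perm), x)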