diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 2e539b070..f6cc82554 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -13,7 +13,6 @@ # limitations under the License. from functools import wraps -from math import ceil from typing import Optional import torch @@ -28,6 +27,7 @@ from compressed_tensors.quantization.utils import ( calculate_range, compute_dynamic_scales_and_zp, + strategy_cdiv, ) from torch.nn import Module @@ -257,53 +257,25 @@ def _process_quantization( global_scale=global_scale, ) # restore original shape - output = x_blocks.transpose(1, 2).reshape(original_shape) + x = x_blocks.transpose(1, 2).reshape(original_shape) elif args.strategy in ( QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, ): - output_dtype = dtype if dtype is not None else x.dtype - output = torch.zeros_like(x).to(output_dtype) - columns = output.shape[-1] - - # TODO: make validation step for inputs - - while scale.ndim < 2: - # pad scale and zero point dims for slicing - scale = scale.unsqueeze(1) - zero_point = zero_point.unsqueeze(1) if zero_point is not None else None - - if columns >= group_size: - if columns % group_size != 0: - raise ValueError( - "tensor column shape must be divisble " - f"by the given group_size {group_size}" - ) - - # support column-order (default) quantization as well as other orderings - # such as activation ordering. 
Below checks if g_idx has been initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - num_groups = int(ceil(columns / group_size)) - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] + # activation ordering + if g_idx is not None: perm = torch.argsort(g_idx) x = x.index_select(-1, perm) - # Maintain all dimensions except the last dim, which is divided by group_size - reshaped_dims = ( - ceil(x.shape[-1] / group_size), - group_size, - ) + # reshape into groups + num_groups = strategy_cdiv(x.size(-1), group_size, args.strategy, strict=True) + reshaped_dims = (num_groups, group_size) x = x.unflatten(-1, reshaped_dims) if do_quantize: - output = _quantize( + x = _quantize( x=x, scale=scale.unsqueeze(-1), zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, @@ -315,24 +287,25 @@ def _process_quantization( ) if do_dequantize: - input = output if do_quantize else x - output = _dequantize( - x_q=input, + x = _dequantize( + x_q=x, scale=scale.unsqueeze(-1), zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, global_scale=global_scale, ) - output = output.flatten(start_dim=-2) - output = output.to(output_dtype) + # undo reshape into groups + x = x.flatten(-2, -1) + x = x.to(output_dtype) - if not is_column_order: + # undo activation ordering + if g_idx is not None: inv_perm = torch.argsort(perm) - output = output.index_select(-1, inv_perm) + x = x.index_select(-1, inv_perm) else: # covers channel, token and tensor strategies if do_quantize: - output = _quantize( + x = _quantize( x=x, scale=scale, zero_point=zero_point, @@ -343,14 +316,14 @@ def _process_quantization( global_scale=global_scale, ) if do_dequantize: - output = _dequantize( - output if do_quantize else x, + x = _dequantize( + x_q=x, scale=scale, zero_point=zero_point, 
global_scale=global_scale, ) - return output + return x def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme): diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5350b4a2c..2e250ce92 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,10 +14,8 @@ import logging -import math -import warnings from enum import Enum -from typing import Optional +from typing import Optional, Tuple import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -26,12 +24,17 @@ from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, + DynamicType, QuantizationArgs, QuantizationStrategy, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme +from compressed_tensors.quantization.utils import ( + is_fp4, + is_kv_cache_quant_scheme, + strategy_cdiv, +) from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, @@ -73,10 +76,8 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ - # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: - # no scheme passed and layer not targeted for quantization - skip return if is_attention_module(module): @@ -84,38 +85,52 @@ def initialize_module_for_quantization( _initialize_attn_scales(module) else: + if not isinstance(module, torch.nn.Linear): + _LOGGER.warning(f"Attempting to quantize module of type {type(module)}") + + # use weight to determine observed shapes and dtype + if hasattr(module, "weight"): + weight = 
module.weight + assert isinstance(weight, torch.Tensor) + else: + # Note that a weight is required for both weight and activation + # quantization in order to know the dtype of activation scales + _LOGGER.warning( + f"module type {type(module)} targeted for quantization but " + f"has no attribute weight, skipping quantization for {type(module)}" + ) + return + if scheme.input_activations is not None: _initialize_scale_zero_point( module, "input", scheme.input_activations, + observed_shape=(1, weight.shape[-1]), + observed_dtype=weight.dtype, force_zero_point=force_zero_point, ) if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) - - if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations - ) + _initialize_scale_zero_point( + module, + "weight", + scheme.weights, + observed_shape=weight.shape, + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) + + output_is_kv_cache = is_kv_cache_quant_scheme(scheme) + if scheme.output_activations is not None and not output_is_kv_cache: + _initialize_scale_zero_point( + module, + "output", + scheme.output_activations, + observed_shape=weight.shape[:-1], + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED @@ -138,18 +153,21 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] 
= None, + observed_shape: Tuple[int], + observed_dtype: torch.dtype, force_zero_point: bool = True, ): - if quantization_args.dynamic is True: - return + strategy = quantization_args.strategy + dynamic = quantization_args.dynamic + actorder = quantization_args.actorder + device = get_execution_device(module) # avoid performing intialization ops on cpu - # initialize on execution device to avoid performing quantized ops on cpu - device = get_execution_device(module) + # Skip all intialization for fully dynamic quantization + if dynamic is True: + return - # 1. Create global_scales for tensor_group - generates - # a per tensor scale - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: + # 0. Create global scale for tensor-group quantization + if strategy == QuantizationStrategy.TENSOR_GROUP: init_global_scale = Parameter( torch.empty(1, dtype=torch.float32, device=device), requires_grad=False, @@ -158,56 +176,55 @@ def _initialize_scale_zero_point( module, f"{base_name}_global_scale", init_global_scale ) - # 2. Infer expected scale/zero point shape - if quantization_args.strategy == QuantizationStrategy.TOKEN: + # Skip scale/zp initialization for locally dynamic quantization + if dynamic == DynamicType.LOCAL: + return + + # 1. 
Infer expected scale/zp shape + if strategy == QuantizationStrategy.TENSOR: + expected_shape = (1,) + + elif strategy == QuantizationStrategy.TOKEN: expected_shape = (1, 1) + + elif strategy == QuantizationStrategy.CHANNEL: + if len(observed_shape) < 1: + raise ValueError("Channel quant requires at least 1 observed dimension") + + expected_shape = (observed_shape[-2], 1) + + elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + assert quantization_args.group_size is not None + if len(observed_shape) < 1: + raise ValueError("Group quant requires at least 1 observed dimension") + + group_size = quantization_args.group_size + num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy) + expected_shape = (*observed_shape[:-1], num_groups) + + # initialize activation ordering if applicable + if actorder == ActivationOrdering.GROUP: + init_g_idx = Parameter( + torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) + + elif strategy == QuantizationStrategy.BLOCK: + assert quantization_args.block_structure is not None + if len(observed_shape) < 2: + raise ValueError("Block quant requires at least 2 observed dimensions") + + block_structure = quantization_args.block_structure + num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy) + num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) + expected_shape = (num_rows, num_cols) + else: - expected_shape = 1 - - if base_name == "weight" and weight_shape is not None: - if quantization_args.strategy == QuantizationStrategy.CHANNEL: - # (output_channels, 1) - only for weights - expected_shape = (weight_shape[0], 1) - elif quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - # GROUP/TENSOR_GROUP for both weights and activations - num_groups = math.ceil(weight_shape[1] / 
quantization_args.group_size) - expected_shape = (weight_shape[0], max(num_groups, 1)) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only - # for weights - if quantization_args.block_structure is None: - raise ValueError( - "Block quantization requires block_structure to be specified" - ) - block_height, block_width = quantization_args.block_structure - rows, cols = weight_shape[-2], weight_shape[-1] - num_rows_blocks = math.ceil(rows / block_height) - num_cols_blocks = math.ceil(cols / block_width) - - # Warn if dimensions don't divide evenly - if rows % block_height != 0 or cols % block_width != 0: - warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide" - f"evenly by block structure {quantization_args.block_structure}. " - f"Some blocks will be incomplete which may affect quantization" - "quality.", - UserWarning, - ) - - expected_shape = (num_rows_blocks, num_cols_blocks) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - warnings.warn( - f"BLOCK quantization not supported for {base_name} activations. " - f"Falling back to tensor-level quantization.", - UserWarning, - ) - expected_shape = 1 + assert False, f"Unknown strategy {strategy}" - # 3. Identify quantization scale and zp dtype - scale_dtype = module.weight.dtype + # 2. Identify quantization scale and zp dtype + scale_dtype = observed_dtype if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype @@ -223,14 +240,12 @@ def _initialize_scale_zero_point( scale_dtype = torch.bfloat16 zp_dtype = quantization_args.pytorch_dtype() - # 4. 
Initializes empty scale, zero point, and g_idx parameters for the module - # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_scale", init_scale) + # 3. Initializes scale/zp for the module + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: init_zero_point = Parameter( @@ -239,16 +254,6 @@ def _initialize_scale_zero_point( ) register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) - # only grouped activation ordering has g_idx - if quantization_args.actorder == ActivationOrdering.GROUP: - g_idx_shape = (weight_shape[1],) - g_idx_dtype = torch.int - init_g_idx = Parameter( - torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) - def _initialize_attn_scales(module: Module) -> None: """Initlaize k_scale, v_scale for self_attn""" diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9e88353b..1ee294870 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -262,6 +262,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": actorder = model.actorder dynamic = model.dynamic observer = model.observer + block_structure = model.block_structure # infer strategy if strategy is None: @@ -277,23 +278,29 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": "strategy='group' and group_size = -1 for 'channel'" ) - # validate strategy and group - if strategy == 
QuantizationStrategy.GROUP: - if group_size is None or group_size <= 0: - raise ValueError( - f"strategy {strategy} requires group_size to be " - "set to a positive value" - ) - if ( - group_size is not None - and group_size > 0 - and strategy - not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP) - ): - raise ValueError("group_size requires strategy to be set to 'group'") - - # validate activation ordering and strategy - if actorder is not None and strategy != QuantizationStrategy.GROUP: + # validate block strategy and structure + has_block_strategy = strategy == QuantizationStrategy.BLOCK + has_block_structure = block_structure is not None + if has_block_strategy != has_block_structure: + raise ValueError( + "`strategy = block` requires `block_structure != None`, and vice versa." + f" Instead got `strategy={strategy}` and " + f"`block_structure={block_structure}`" + ) + + # validate group strategy + has_group_strategy = strategy in ( + QuantizationStrategy.GROUP, + QuantizationStrategy.TENSOR_GROUP, + ) + has_group_size = group_size is not None and group_size > 0 + has_actorder = actorder is not None + if has_group_strategy != has_group_size: + raise ValueError( + "`strategy = group` requires `group_size != None`, and vice versa. 
" + f"Instead got `strategy={strategy}` and `group_size={group_size}`" + ) + if has_actorder and not has_group_strategy: raise ValueError( "Must use group quantization strategy in order to apply " "activation ordering" diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index b11e3c0c0..5aeb9f7f4 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -60,24 +60,19 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": format = model.format if inputs is not None: - if inputs.strategy not in ( - QuantizationStrategy.TOKEN, - QuantizationStrategy.TENSOR, - QuantizationStrategy.GROUP, - QuantizationStrategy.TENSOR_GROUP, - ): - if ( - inputs.strategy == QuantizationStrategy.GROUP - and inputs.dynamic is True - ): - raise NotImplementedError( - "Static and local group-wise activation " - "quantization is not supported" - ) - - raise NotImplementedError( - f"Using {inputs.strategy} strategy is not supported for " - "activation quantization" + if inputs.strategy == QuantizationStrategy.CHANNEL: + raise ValueError( + "Channel-wise activation quantization is equivalent to " + "tensor/token-wise activation quantization, please use one of " + "those. If you mean to quantize each activation value " + "individually, please use group quantization with `group_size = 1`" + ) + + if inputs.strategy == QuantizationStrategy.BLOCK: + raise ValueError( + "Block-wise activation quantization is not supported. 
def strategy_cdiv(
    value: int,
    divisor: int,
    strategy: Optional["QuantizationStrategy"],
    strict: bool = False,
) -> int:
    """
    Ceiling-divide `value` by `divisor` and flag divisions that are not exact

    The quotient is computed with integer arithmetic only, so both the result
    and the divisibility check remain exact for arbitrarily large integers
    (float true division would round once operands exceed 2**53).

    :param value: weight/activation size to be split into groups or blocks
    :param divisor: group/block size to split `value` by
    :param strategy: quantization strategy, only interpolated into the
        error/warning message
    :param strict: if True, raise when `divisor` does not evenly divide
        `value`; otherwise emit a warning through `logger` (bound with
        `log_once=True`) and still return the rounded-up quotient
    :return: smallest integer `q` such that `q * divisor >= value`
        (for a positive `divisor`)
    :raises ValueError: if `strict` is True and the division is not exact
    """
    # negated floor division is ceiling division, without any float rounding
    quotient = -(-value // divisor)

    # exact division: nothing to report
    if quotient * divisor == value:
        return quotient

    message = (
        f"{strategy} quantization strategy requires strict division of "
        f"weight/activation size {value} and group/block size {divisor}. "
        "consider reducing the group/block size or ignoring modules with "
        f"weights not divisible by {divisor}"
    )
    if strict:
        raise ValueError(message)

    # non-strict callers only get a warning and keep the rounded-up count
    logger.bind(log_once=True).warning(message)
    return quotient
QuantizationArgs( num_bits=num_bits, type=type, strategy=strategy, group_size=group_size ) - x = torch.rand((512, 1024)) + if batch_size is None: + x = torch.rand((512, 1024)) + else: + x = torch.rand((batch_size, 512, 1024)) + fake_quantize( x=x, scale=scale, diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index b9f9754c6..d97e98cb3 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -28,43 +28,27 @@ @pytest.mark.parametrize( - "keepdims,strategy,exp_shape", + "keepdims,qargs,exp_shape", [ + (False, QuantizationArgs(strategy="tensor"), torch.Size([1])), + (True, QuantizationArgs(strategy="channel"), torch.Size([1, 1])), + (True, QuantizationArgs(strategy="group", group_size=2), torch.Size([1, 1])), ( False, - QuantizationStrategy.TENSOR, - torch.Size( - [ - 1, - ] - ), + QuantizationArgs(strategy="block", block_structure=[1, 1]), + torch.Size([1]), ), - (True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])), - (True, QuantizationStrategy.GROUP, torch.Size([1, 1])), - ( - False, - QuantizationStrategy.BLOCK, - torch.Size( - [ - 1, - ] - ), - ), - (True, QuantizationStrategy.TOKEN, torch.Size([1, 1])), + (True, QuantizationArgs(strategy="token"), torch.Size([1, 1])), ], ) -def test_calculate_qparams(keepdims, strategy, exp_shape): +def test_calculate_qparams(keepdims, qargs, exp_shape): value = torch.randn(14, 5) min_val = torch.amin(value, dim=tuple(), keepdims=keepdims) max_val = torch.amax(value, dim=tuple(), keepdims=keepdims) - if strategy == QuantizationStrategy.GROUP: - args = QuantizationArgs(strategy=strategy, group_size=2) - else: - args = QuantizationArgs(strategy=strategy) - scale, zp = calculate_qparams(min_val, max_val, args) - assert scale.shape == exp_shape - assert zp.shape == exp_shape + scale, zp = calculate_qparams(min_val, max_val, qargs) + assert scale.shape == exp_shape + assert zp.shape == exp_shape 
def test_fused_global_scales():