Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 210 additions & 8 deletions auto_round/alg_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Algorithm extensions for AutoRound quantization.

This module provides enhanced quantization functions and wrapper classes that
extend the base AutoRound algorithm with additional capabilities primarily
targeted at INT2, MX-FP4, NV-FP4 and GGUF (double-quantized) formats.

Key components:
- :func:`wrapper_autoround`: Patches an ``AutoRound`` instance to use the
extended loss and wrapper functions when ``enable_alg_ext`` is set.
- :class:`WrapperLinearV2`: An enhanced version of :class:`WrapperLinear`
that initialises scale values using importance-matrix-weighted search.
- :class:`DQWrapperLinear`: A wrapper supporting GGUF double-quantization
(``int_asym_dq`` / ``int_sym_dq`` data types).
- Standalone quantization helpers: :func:`quant_tensor_sym`,
:func:`quant_mx`, :func:`nv_fp4`, :func:`quant_tensor_gguf_asym_dq`,
:func:`quant_tensor_gguf_sym_dq`.
- Iterative search functions adapted from llama.cpp:
:func:`iterative_wls_quant_search`, :func:`make_qp_quants`,
:func:`make_qp_new_quants`.
"""

import logging
import types
Expand All @@ -37,6 +57,17 @@


def wrapper_autoround(cls: AutoRound):
"""Patches an AutoRound compressor instance with algorithm-extension methods.

``_register_act_max_hook`` is always replaced with an imatrix-aware version.
Depending on the compressor configuration this function may additionally:
* Replace ``_get_loss`` with a top-percent masked loss.
* Replace ``wrapper_block`` with :func:`wrapper_block_v2` (for INT/MX/NV)
or :func:`dq_wrapper_block` (for double-quant data types ending in ``dq``).

Args:
cls (AutoRound): The compressor instance to patch.
"""
cls._register_act_max_hook = types.MethodType(_register_act_max_hook_ext, cls)
if (
cls.sym
Expand Down Expand Up @@ -89,6 +120,23 @@ def _get_loss_ext(
mse_loss: Callable,
device: Union[str, torch.device] = "cpu",
):
"""Compute the masked MSE loss for quantization-aware tuning with algorithm extensions.

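Computes the loss on the elements selected by the mask that
``get_abs_top_percent_mask`` returns for the absolute difference between the
quantized output and the current output, optionally applying an attention mask.
NOTE(review): only the second mask returned by ``get_abs_top_percent_mask`` is
used; confirm whether it keeps or excludes the top-percent differences.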

Args:
self (AutoRound): The AutoRound instance providing configuration (e.g. attention mask,
AMP settings).
output_q (torch.Tensor): Quantized model output tensor.
current_output (torch.Tensor): Reference (float) model output tensor.
indices (torch.Tensor): Sample indices used to select the corresponding attention masks.
mse_loss (Callable): MSE loss function (unused; retained for API compatibility).
device (str | torch.device): Device on which to run computations. Defaults to ``"cpu"``.

Returns:
torch.Tensor: Scalar loss value.
"""
_, mask = get_abs_top_percent_mask(torch.abs(output_q - current_output))
autocast_ctx = nullcontext() if self.amp else autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype)
if self.attention_mask:
Expand Down Expand Up @@ -151,6 +199,19 @@ def quant_tensor_sym(

@torch.inference_mode()
def qdq_mxfp(tensor, max_val, max_norm, emax, ebits, mbits):
"""Quantizes and dequantizes a tensor using the MX-FP shared-exponent format.

Args:
tensor (torch.Tensor): Input tensor (float32).
max_val (torch.Tensor): Per-group maximum absolute value tensor.
max_norm (float): Maximum representable value in the target format.
emax (int): Maximum exponent for the format.
ebits (int): Number of exponent bits.
mbits (int): Number of mantissa bits.

Returns:
torch.Tensor: Quantized-dequantized tensor (same shape as ``tensor``).
"""
shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val))
shared_exp = torch.floor(shared_exp)
scale_emax = 2 ** (8 - 1) - 1
Expand All @@ -166,6 +227,21 @@ def qdq_mxfp(tensor, max_val, max_norm, emax, ebits, mbits):


def mx_init(tensor, bits, qw=None):
"""Initialises per-group scale factors for MX-FP quantization via grid search.

Searches scales in [0.5, 1.5] to find the scale that minimises the
weighted squared error between the input tensor and its quantized version.

Args:
tensor (torch.Tensor): Weight tensor to initialise scales for.
bits (int): MX-FP bit width (e.g. 4 for MXFP4).
qw (torch.Tensor | None): Per-element importance weights; if ``None``,
uniform weighting is used.

Returns:
torch.Tensor: Per-group optimal scale factors (same shape as the
per-group max-value tensor).
"""
data_type = "mx_fp" + str(bits)
ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type]
tensor = tensor.to(torch.float32)
Expand All @@ -188,6 +264,27 @@ def mx_init(tensor, bits, qw=None):


def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, max_scale=1.0, init_scale=1.0, **kwargs):
"""Quantizes and dequantizes a tensor using the NV-FP4 format.

Args:
tensor (torch.Tensor): Input weight tensor.
bits (int, optional): Bit width (must be 4). Defaults to 4.
group_size (int, optional): Group size for per-group quantization.
Defaults to 16.
v (float | torch.Tensor, optional): Rounding perturbation. Defaults to 0.
global_scale (torch.Tensor | None, optional): Pre-computed global FP8
scale factor; computed from the tensor if ``None``.
max_scale (float | torch.Tensor, optional): Per-group scale coefficient.
Defaults to 1.0.
init_scale (float | torch.Tensor, optional): Initial scale multiplier.
Defaults to 1.0.
**kwargs: Additional unused keyword arguments.

Returns:
tuple[torch.Tensor, torch.Tensor, None]: ``(qdq_result, scale, None)``
where ``qdq_result`` is the quantized-dequantized tensor and ``scale``
is the per-group FP4 scale.
"""
orig_dtype = tensor.dtype
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
if global_scale is None:
Expand All @@ -203,6 +300,20 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, max_scale=1.0,


def nv_init(tensor, bits, qw=None):
"""Initialises per-group scale factors for NV-FP4 quantization via grid search.

Searches scales in [0.5, 1.5] to find the scale that minimises the
weighted squared error between the input tensor and its NV-FP4 quantized
version.

Args:
tensor (torch.Tensor): Weight tensor to initialise scales for.
bits (int): Bit width (must be 4).
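qw (torch.Tensor | None): Per-element importance weights. The signature
defaults this to ``None``, but the loss is computed as
``(qdq_t - tensor) ** 2 * qw`` without a ``None`` check, so callers must
pass a tensor.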

Returns:
torch.Tensor: Per-group optimal scale factors.
"""
tensor = tensor.to(torch.float32)
qdq_t, dummy_scale, _ = nv_fp4(tensor, bits=4, group_size=16, v=0, max_scale=1.0)
best_loss = torch.sum((qdq_t - tensor) ** 2 * qw, dim=-1)
Expand Down Expand Up @@ -421,14 +532,22 @@ def _qdq_weight(self, value, min_scale, max_scale):


def wrapper_block_v2(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
"""Wraps the layers in the given block with a custom Wrapper module.
"""Wraps quantizable layers in a block with :class:`WrapperLinearV2`.

This variant uses :class:`WrapperLinearV2` instead of :class:`WrapperLinear`
to leverage importance-matrix-based scale initialisation.

Args:
block: The input block containing linear and conv1d layers to be wrapped.
enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
block (torch.nn.Module): Block whose layers should be wrapped.
enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
enable_norm_bias_tuning (bool): Whether to enable norm/bias tuning.
device (str, optional): Computation device. Defaults to ``"cpu"``.
**kwargs: Additional keyword arguments forwarded to
:class:`WrapperLinearV2`.

Returns:
list: A list of names of the wrapped layers and unwrapped layers.
tuple[list[str], list[str]]: ``(quantized_layers, unquantized_layers)``
where each is a list of layer name strings.
"""
quantized_layers = []
unquantized_layers = []
Expand Down Expand Up @@ -466,7 +585,31 @@ def wrapper_block_v2(block, enable_minmax_tuning, enable_norm_bias_tuning, devic


def _register_act_max_hook_ext(self, model):
"""Registers forward hooks to collect importance matrices and activation maxima.

This extended version also registers an imatrix (importance matrix) hook on
all supported layer types and an activation-max hook for static activation
quantization.

Args:
self (AutoRound): The compressor instance on which hooks are registered.
model (torch.nn.Module): The model to register hooks on.

Returns:
list: List of :class:`torch.utils.hooks.RemovableHandle` hook handles.
"""

def get_act_max_hook(module, input, output):
"""Forward hook that tracks the per-group activation maximum for a module.

Updates ``module.act_max`` in-place with the element-wise maximum of
absolute activation values observed so far.

Args:
module (torch.nn.Module): The module being observed.
input: Module input(s); the first element is used when a tuple/list.
output: Module output (unused).
"""
if isinstance(input, (tuple, list)):
input = input[0]
if input.numel() == 0:
Expand All @@ -483,6 +626,16 @@ def get_act_max_hook(module, input, output):
module.act_max = torch.max(act_max, module.act_max)

def get_imatrix_hook(module, input, output):
"""Forward hook that accumulates the importance matrix (sum of squared activations).

Updates ``module.imatrix`` in-place by accumulating the column-wise sum of
squared input activations across all calibration batches.

Args:
module (torch.nn.Module): The module being observed.
input: Module input(s); the first element is used when a tuple/list.
output: Module output (unused).
"""
input = input[0] if isinstance(input, (tuple, list)) else input
flattened = input.reshape(-1, input.shape[-1]).to(torch.float32)
squared = torch.sum(torch.pow(flattened, 2), dim=0).to(torch.float32)
Expand Down Expand Up @@ -531,6 +684,22 @@ def get_imatrix_hook(module, input, output):


def make_qp_quants(nmax, data, quant_weights, v=0):
"""Performs quantization scale/offset search using positive-only quantization.

Finds the optimal scale for quantizing ``data`` to integer values in
``[0, nmax]`` by minimising the weighted sum of squared errors across
candidate scales.

Args:
nmax (int): Maximum quantized integer value.
data (torch.Tensor): Input tensor to quantize (shape: ``[n_groups, gs]``).
quant_weights (torch.Tensor): Per-element importance weights.
v (float | torch.Tensor, optional): Rounding perturbation. Defaults to 0.

Returns:
tuple[torch.Tensor, torch.Tensor]: ``(scale, L)`` where ``scale`` is the
optimal dequantization scale and ``L`` is the quantized integer tensor.
"""
data = data.to(torch.float32)
quant_weights = quant_weights.to(torch.float32)
group_max = torch.max(data, dim=-1, keepdim=True)[0]
Expand Down Expand Up @@ -640,6 +809,32 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u


def make_qp_new_quants(data, orig_scale, orig_mins, quant_weights, bits=4, super_bits=6, data_v=0, scale_v=0, min_v=0):
"""Secondary (super) scale quantization for double-quantization schemes.

Quantizes per-group primary scales using a secondary ``super_bits``-bit
quantization and returns the optimal secondary scale and quantized scale
indices.

Args:
data (torch.Tensor): Original weight tensor (float32).
orig_scale (torch.Tensor): Per-group primary scale tensor.
orig_mins (torch.Tensor | None): Per-group minimum value tensor (for
asymmetric schemes).
quant_weights (torch.Tensor): Per-element importance weights.
bits (int, optional): Primary quantization bits. Defaults to 4.
super_bits (int, optional): Secondary quantization bits. Defaults to 6.
data_v (float | torch.Tensor, optional): Data rounding perturbation.
Defaults to 0.
scale_v (float | torch.Tensor, optional): Scale rounding perturbation.
Defaults to 0.
min_v (float | torch.Tensor, optional): Min rounding perturbation.
Defaults to 0.

Returns:
tuple[torch.Tensor, torch.Tensor]: ``(d_scale, L)`` where ``d_scale`` is
the secondary dequantization scale and ``L`` is the quantized scale
index tensor.
"""
nmax = 2**super_bits - 1
maxq = 2**bits - 1
minq = 0
Expand Down Expand Up @@ -1109,14 +1304,21 @@ def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, wmin_v=None, it


def dq_wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
"""Wraps the layers in the given block with a custom Wrapper module.
"""Wraps quantizable layers in a block with :class:`DQWrapperLinear`.

This variant is used for GGUF double-quantized data types.

Args:
block: The input block containing linear and conv1d layers to be wrapped.
enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
block (torch.nn.Module): Block whose layers should be wrapped.
enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
enable_norm_bias_tuning (bool): Whether to enable norm/bias tuning.
device (str, optional): Computation device. Defaults to ``"cpu"``.
**kwargs: Additional keyword arguments forwarded to
:class:`DQWrapperLinear`.

Returns:
list: A list of names of the wrapped layers and unwrapped layers.
tuple[list[str], list[str]]: ``(quantized_layers, unquantized_layers)``
where each is a list of layer name strings.
"""
quantized_layers = []
unquantized_layers = []
Expand Down
30 changes: 24 additions & 6 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoRound entry point providing a unified factory for model quantization.

This module defines the ``AutoRound`` class, which serves as a factory that
automatically selects the appropriate backend compressor (LLM, MLLM, Diffusion,
or Adam-based) depending on the model type and configuration, as well as
deprecated backward-compatible subclass aliases.
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union
Expand Down Expand Up @@ -221,14 +228,21 @@ def _sampling_inputs(
"""Samples inputs based on the given indices and sequence length.

Args:
input_ids: The list of input tensor containing input_ids.
input_others: A dictionary containing other input data.
indices: The indices to sample from the input.
seqlen: The sequence length.
input_ids (list[torch.Tensor]): List of input tensors containing token IDs.
input_others (dict): Dictionary containing additional input data (e.g.,
attention masks, positional inputs).
indices (list[int]): Indices to select from the input list.
seqlen (int): Target sequence length for slicing/truncating.
batch_dim (int, optional): Dimension along which to concatenate batches.
Defaults to 0.
share_cache_keys (tuple, optional): Keys in ``input_others`` whose values
are shared across all samples and should not be per-sample indexed.
Defaults to empty tuple.

Returns:
current_input_ids: The sampled input IDs.
current_input_others: The sampled other input data.
tuple[torch.Tensor, dict]: A tuple of ``(current_input_ids, current_input_others)``
where ``current_input_ids`` is the concatenated input-ID tensor and
``current_input_others`` is the sampled auxiliary-input dictionary.
"""
current_input_ids = [input_ids[i] for i in indices]

Expand Down Expand Up @@ -350,6 +364,7 @@ def __init__(
seed: int = 42,
**kwargs,
):
"""Initialize AutoRoundLLM. Delegates all arguments to LLMCompressor.__init__."""
local_args = {k: v for k, v in locals().items() if k not in ("local_args", "kwargs", "self")}
super().__init__(
**local_args,
Expand Down Expand Up @@ -440,6 +455,7 @@ def __init__(
optimizer="AdamW",
**kwargs,
):
"""Initialize AutoRoundAdam. Delegates all arguments to AdamCompressor.__init__."""
local_args = {k: v for k, v in locals().items() if k not in ("local_args", "kwargs", "self")}
super().__init__(
**local_args,
Expand Down Expand Up @@ -531,6 +547,7 @@ def __init__(
seed: int = 42,
**kwargs,
):
"""Initialize AutoRoundMLLM. Delegates all arguments to MLLMCompressor.__init__."""
local_args = {k: v for k, v in locals().items() if k not in ("local_args", "kwargs", "self")}
super().__init__(
**local_args,
Expand Down Expand Up @@ -597,6 +614,7 @@ def __init__(
seed: int = 42,
**kwargs,
):
"""Initialize AutoRoundDiffusion. Delegates all arguments to DiffusionCompressor.__init__."""
local_args = {k: v for k, v in locals().items() if k not in ("local_args", "kwargs", "self")}
super().__init__(
**local_args,
Expand Down
Loading
Loading