Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
faa014d
refine transform config api.
lkk12014402 Mar 24, 2026
00d248a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
3842e33
Update auto_round/experimental/transform/helper.py
lkk12014402 Mar 25, 2026
f4ff447
Update auto_round/experimental/transform/helper.py
lkk12014402 Mar 25, 2026
1351df9
Update auto_round/experimental/transform/apply.py
lkk12014402 Mar 25, 2026
cb2624a
Update auto_round/experimental/transform/apply.py
lkk12014402 Mar 25, 2026
a5b72a1
Update auto_round/compressors/base.py
lkk12014402 Mar 25, 2026
351ea57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
58da3d2
add more ut.
lkk12014402 Mar 25, 2026
79b7cf9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
76d2806
fix typo.
lkk12014402 Mar 25, 2026
e9cccd6
replace to .
lkk12014402 Mar 25, 2026
93ad904
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
9982de6
fix typo
lkk12014402 Mar 25, 2026
90bb890
add initial hadamard transform document.
lkk12014402 Mar 25, 2026
eb8b302
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
53c460a
update hadamard transform api for better usage.
lkk12014402 Mar 25, 2026
bde783f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
03d5bc4
fix typo
lkk12014402 Mar 25, 2026
b7ec9d7
update hadamard transform doc.
lkk12014402 Mar 26, 2026
c8d505b
format code.
lkk12014402 Mar 26, 2026
1e05744
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2026
9d71ca6
update hadamard transform doc.
lkk12014402 Mar 26, 2026
52124fc
update doc.
lkk12014402 Mar 26, 2026
a6ca673
update doc.
lkk12014402 Mar 26, 2026
8a9347b
update doc.
lkk12014402 Mar 26, 2026
8a95b12
update doc.
lkk12014402 Mar 26, 2026
1ff5c34
Merge branch 'main' into refine_transform_api
lkk12014402 Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __new__(
enable_alg_ext: bool = False,
disable_opt_rtn: bool | None = None,
low_cpu_mem_usage: bool = True,
transform_config: dict | None = None,
hadamard_config: dict | None = None,
**kwargs,
) -> BaseCompressor:
"""Initialize AutoRound with quantization and tuning configuration.
Expand Down
9 changes: 7 additions & 2 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tqdm import tqdm
from transformers import AutoConfig, set_seed

import auto_round.experimental.transform.helper as transform_helper
from auto_round import envs
from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
from auto_round.compressors.shard_writer import shard_writer
Expand All @@ -56,6 +57,8 @@
)
from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, update_block_global_scale_if_needed
from auto_round.experimental.transform.hadamard_config import HadamardConfig
from auto_round.experimental.transform.helper import normalize_hadamard_config
from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG
from auto_round.formats import OutputFormat, get_formats
from auto_round.logger import logger
Expand Down Expand Up @@ -150,7 +153,7 @@
"super_bits",
"super_group_size",
"to_quant_block_names",
"transform_config",
"hadamard_config",
)


Expand Down Expand Up @@ -201,6 +204,7 @@ def __init__(
disable_opt_rtn: bool | None = None,
seed: int = 42,
low_cpu_mem_usage: bool = True,
hadamard_config: str | dict | HadamardConfig | None = None,
**kwargs,
):
"""Initialize AutoRound with quantization and tuning configuration.
Expand Down Expand Up @@ -552,7 +556,7 @@ def __init__(
except (ImportError, ModuleNotFoundError):
logger.error("algorithm extension import error, fallback to default mode")

self.transform_config = kwargs.pop("transform_config", {})
self.hadamard_config = normalize_hadamard_config(hadamard_config)

def _gen_auto_scheme(self) -> dict[str, dict]:
if self.mllm:
Expand Down Expand Up @@ -3362,6 +3366,7 @@ def save_quantized(
serialization_dict = {}
for key in SERIALIZATION_KEYS:
serialization_dict[key] = getattr(self, key)

from auto_round.version import __version__

serialization_dict["autoround_version"] = __version__
Expand Down
2 changes: 1 addition & 1 deletion auto_round/experimental/qmodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, TransformMXFP4QuantLinear
from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, HadamardMXFP4QuantLinear
from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear
4 changes: 2 additions & 2 deletions auto_round/experimental/qmodules/mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
return unpacked_data


class TransformMXFP4QuantLinear(MXFP4QuantLinear):
class HadamardMXFP4QuantLinear(MXFP4QuantLinear):
"""
Quantized linear layer using the MXFP4 quantization scheme.
"""
Expand All @@ -206,7 +206,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.enable_transform = True
self.register_buffer(
"transform_matrix",
"hadamard_matrix",
torch.empty(
self.group_size,
self.group_size,
Expand Down
113 changes: 77 additions & 36 deletions auto_round/experimental/transform/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,82 @@
import tqdm

from auto_round.experimental.qmodules.mx import MXQuantLinearBase
from auto_round.experimental.transform.transform_config import TransformConfig
from auto_round.experimental.transform.transforms import build_transform
from auto_round.experimental.transform.hadamard_config import HadamardConfig
from auto_round.experimental.transform.hadamards import build_hadamard_transform
from auto_round.experimental.transform.helper import normalize_hadamard_config

__all__ = ["apply_transform"]
__all__ = ["apply_hadamard_transform"]


def apply_transform(model: torch.nn.Module, config: TransformConfig, use_tqdm=True, desc=None):
def apply_hadamard_transform(
model: torch.nn.Module,
config: str | dict | HadamardConfig | None,
need_calibration: bool = False,
location: str = "weight",
use_tqdm=True,
desc=None,
):
"""
Apply a transform config to a model. Add weight transforms and
activation transforms are attached as submodules and trigger via pytorch hooks

:param model: model to apply config to
:param config: transform config to apply
Apply a transform configuration to a model.

Weight and activation transforms are attached as submodules and are
triggered via PyTorch hooks.

:param model: Model to which the transform configuration will be applied.
:param config: Transform configuration to apply. Supported values are:
* ``str``: A named/preset transform configuration. In this case,
``scheme`` is typically required so that the preset can be
resolved to a concrete quantization/transform configuration.
* ``dict``: A raw configuration mapping that will be normalized
(via :func:`normalize_hadamard_config`) and then passed to
:class:`HadamardConfig`.
* :class:`HadamardConfig`: An existing configuration instance.
This will be used to construct the final configuration after
normalization.
* ``None``: Uses the default behavior of
:func:`normalize_hadamard_config` (for example, inferring a
configuration from ``scheme`` or other project defaults), if
supported.
:param need_calibration: If ``True``, patch the tuning wrapper modules so
the transform is re-applied on every forward pass during calibration
(the weight changes between passes).
:param location: Where to apply the transform: ``"weight"`` eagerly fuses
it into the module weight; ``"input"`` registers a forward pre-hook
on the activations.
:param use_tqdm: If ``True``, wrap the per-module application in a
tqdm progress bar.
:param desc: Optional description string to show in the tqdm progress
bar. If ``None``, a description will be derived from
``config.transform_type``.
"""

config = normalize_hadamard_config(config)
if not isinstance(config, HadamardConfig):
config = HadamardConfig(**config)

modules_config = [
(name, module, config)
for name, module in model.named_modules()
if isinstance(module, torch.nn.Linear) or isinstance(module, MXQuantLinearBase)
]

desc = f"Applying {config.transform_type} transforms" if desc is None else desc
desc = f"Applying {config.hadamard_type} transforms" if desc is None else desc
for name, module, config in tqdm.tqdm(modules_config, desc=desc, disable=(not use_tqdm)):
if "lm_head" in name:
continue
_apply_to_module(model, module, config)
_apply_to_module(model, module, config, need_calibration, location)

# attach config to model for compression/serialization
setattr(model, "transform_config", config)
setattr(model, "hadamard_config", config)

return model


def _apply_to_module(
model: torch.nn.Module,
module: torch.nn.Module,
config: TransformConfig,
config: HadamardConfig,
need_calibration: bool = False,
location: str = "weight",
):
"""
Create transforms and apply them to the module
Expand All @@ -51,23 +90,24 @@ def _apply_to_module(
"""

# create transform as submodule
transform_name = "transform_matrix"
hadamard_name = config.hadamard_type

if config.location == "input":
from auto_round.experimental.transform.triton.utils import is_triton_kernel_available
if location == "input":
from auto_round.experimental.transform.helper import is_triton_kernel_available

# activation needs transpose
inp_transform = build_transform(
input_hadamard_transform = build_hadamard_transform(
**config.dict(),
location="input",
inverse=True,
device="cpu",
precision=module.dtype,
)

if config.transform_type != "random_hadamard":
transform_weight = inp_transform.weight
if config.hadamard_type != "random_hadamard":
hadamard_weight = input_hadamard_transform.weight
else:
transform_weight = None
hadamard_weight = None

if is_triton_kernel_available():
from auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper
Expand All @@ -80,7 +120,7 @@ def input_hook(self, args):
qdq_input, _ = mxfp4_forward_kernel_wrapper(
x_flat,
(
transform_weight if transform_weight is not None else self.transform_matrix.T
hadamard_weight if hadamard_weight is not None else self.hadamard_matrix.T
), # this matrix from w_transform, needs transpose
)
return qdq_input.reshape(orig_shape)
Expand All @@ -97,61 +137,62 @@ def input_hook(self, args):

ori_shape = input.shape

if transform_weight is not None:
input = input.view(-1, transform_weight.shape[0])
return _multihead_matmul(input, transform_weight.to(input.device)).view(ori_shape)
if hadamard_weight is not None:
input = input.view(-1, hadamard_weight.shape[0])
return _multihead_matmul(input, hadamard_weight.to(input.device)).view(ori_shape)
else:
input = input.view(-1, self.transform_matrix.shape[0])
return _multihead_matmul(input, self.transform_matrix.T).view(ori_shape)
input = input.view(-1, self.hadamard_matrix.shape[0])
return _multihead_matmul(input, self.hadamard_matrix.T).view(ori_shape)

# for fused transform + quantization kernel
module.pre_dequantized_input = False
module.register_forward_pre_hook(input_hook, prepend=True)

elif config.location == "weight":
elif location == "weight":
# eagerly apply transformation to weight
# fuse transform into weight
assert hasattr(module, "weight")

w_transform = build_transform(
weight_hadamard_transform = build_hadamard_transform(
**config.dict(),
location="weight",
device=module.weight.device,
precision=module.weight.dtype,
)

# need save random hadamard matrix needed when inference
if config.transform_type == "random_hadamard":
module.register_module(transform_name, w_transform)
if config.hadamard_type == "random_hadamard":
module.register_module(config.hadamard_type, weight_hadamard_transform)
# for saving transform weight
from auto_round.experimental.transform.patch_modules import patch_quantlinear

patch_quantlinear()
patch_quantlinear(config.hadamard_type)

if config.need_calibration:
if need_calibration:
# for training, the weight changes with every forward pass
# for autoround tuning: patch wrapper linear qdq_weight func
from auto_round.experimental.transform.patch_modules import (
patch_wrapperlinear_to_apply_transform,
patch_wrapperwalayer_forward_to_apply_transform,
)

inp_transform = build_transform(
input_hadamard_transform = build_hadamard_transform(
**config.dict(),
location="input",
inverse=True,
device=module.weight.device,
precision=module.weight.dtype,
)

patch_wrapperlinear_to_apply_transform(w_transform, inp_transform)
patch_wrapperwalayer_forward_to_apply_transform(inp_transform)
patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform)
patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform)

else:
# transform is no longer needed (unfusing is not supported)
# delattr(module, transform_name)
# fuse transform into weight
with torch.no_grad():
getattr(module, "weight").copy_(w_transform(module.weight).to(module.weight.device))
getattr(module, "weight").copy_(weight_hadamard_transform(module.weight).to(module.weight.device))

else:
# TODO: apply transform to output/q/k
Expand Down
28 changes: 28 additions & 0 deletions auto_round/experimental/transform/hadamard_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# # Copyright (C) 2026 Intel Corporation
# # SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field, field_validator

__all__ = ["HadamardConfig"]


class HadamardConfig(BaseModel):
    """
    Configuration for the Hadamard transforms applied to a model.

    Instances of this config are serialized into the model's
    ``config.json`` file, so field names and defaults are part of the
    saved-checkpoint format — change them with care.
    """

    # Size of the square Hadamard block. Presumably matches the per-group
    # quantization size of the target linear layers — TODO confirm against
    # the quantized-module buffer shapes.
    block_size: int = Field(default=32)

    # Which transform to build: "hadamard" (deterministic) or
    # "random_hadamard". Enforced by the validator below.
    hadamard_type: str = Field(default="hadamard")

    # for random hadamard transform
    # NOTE(review): typed ``bool`` but named like an integer seed — confirm
    # whether this should be an ``int``. ``exclude=True`` keeps it out of
    # the serialized config.
    random_seed: bool = Field(default=False, exclude=True)

    @field_validator("hadamard_type")
    @classmethod
    def validate_hadamard_type(cls, v: str) -> str:
        """Reject any ``hadamard_type`` outside the supported set."""
        allowed = {"hadamard", "random_hadamard"}
        if v not in allowed:
            raise ValueError(f"Unsupported hadamard_type: {v}. Supported values: {sorted(allowed)}")
        return v
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor):
return (
(
apply_transform_weight(
self.weight,
self.weight.to(x.device),
x.to(dtype=self.weight.dtype),
self.location,
self.module_type,
Expand Down Expand Up @@ -118,13 +118,13 @@ def _create_weight(
return nn.Parameter(data, requires_grad=False)


TRANSFORMS = {
# Registry mapping a config's ``hadamard_type`` string to the transform
# class that implements it. Consumed by ``build_hadamard_transform``.
# NOTE(review): "identity" is registered here but is rejected by
# ``HadamardConfig.validate_hadamard_type`` — confirm whether that is
# intentional (e.g. reachable only through a non-config code path).
HADAMARDS = {
    "identity": IdentityTransform,
    "hadamard": HadamardTransform,
    "random_hadamard": RandomHadamardTransform,
}


def build_transform(transform_type: str, **transform_kwargs):
transform = TRANSFORMS[transform_type]
return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs))
def build_hadamard_transform(hadamard_type: str, **hadamard_kwargs):
    """Instantiate the Hadamard transform registered under ``hadamard_type``.

    :param hadamard_type: Key into the ``HADAMARDS`` registry
        ("identity", "hadamard", or "random_hadamard").
    :param hadamard_kwargs: Candidate constructor arguments; anything the
        selected class's ``__init__`` does not accept is filtered out via
        ``filter_kwarg_dict``, so callers may pass a full config dict.
    :raises ValueError: If ``hadamard_type`` is not registered — message
        mirrors ``HadamardConfig.validate_hadamard_type`` instead of the
        bare ``KeyError`` a dict lookup would raise.
    """
    try:
        hadamard = HADAMARDS[hadamard_type]
    except KeyError:
        raise ValueError(
            f"Unsupported hadamard_type: {hadamard_type}. Supported values: {sorted(HADAMARDS)}"
        ) from None
    return hadamard(**filter_kwarg_dict(hadamard.__init__, hadamard_kwargs))
Loading