diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 9be537c96..3f46d40fc 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -90,7 +90,6 @@ def __new__( enable_alg_ext: bool = False, disable_opt_rtn: bool | None = None, low_cpu_mem_usage: bool = True, - transform_config: dict | None = None, **kwargs, ) -> BaseCompressor: """Initialize AutoRound with quantization and tuning configuration. diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 1db653783..50d1401ce 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -56,6 +56,7 @@ ) from auto_round.data_type import QUANT_FUNC_WITH_DTYPE from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, update_block_global_scale_if_needed +from auto_round.experimental.transform.hadamard_config import HadamardConfig from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG from auto_round.formats import OutputFormat, get_formats from auto_round.logger import logger @@ -150,7 +151,7 @@ "super_bits", "super_group_size", "to_quant_block_names", - "transform_config", + "hadamard_config", ) @@ -201,6 +202,7 @@ def __init__( disable_opt_rtn: bool | None = None, seed: int = 42, low_cpu_mem_usage: bool = True, + hadamard_config: str | dict | HadamardConfig | None = None, **kwargs, ): """Initialize AutoRound with quantization and tuning configuration. 
@@ -555,7 +557,18 @@ def __init__( except (ImportError, ModuleNotFoundError): logger.error("algorithm extension import error, fallback to default mode") - self.transform_config = kwargs.pop("transform_config", {}) + # apply hadamard transform + if hadamard_config: + from auto_round.experimental.transform.apply import apply_hadamard_transform + from auto_round.experimental.utils import check_supported_schemes, normalize_hadamard_config + + check_supported_schemes(self.scheme) + + self.model = apply_hadamard_transform( + self.model, hadamard_config, need_calibration=True if self.iters > 0 else False + ) + + self.hadamard_config = normalize_hadamard_config(hadamard_config) def _gen_auto_scheme(self) -> dict[str, dict]: if self.mllm: @@ -3365,6 +3378,7 @@ def save_quantized( serialization_dict = {} for key in SERIALIZATION_KEYS: serialization_dict[key] = getattr(self, key) + from auto_round.version import __version__ serialization_dict["autoround_version"] = __version__ diff --git a/auto_round/experimental/qmodules/__init__.py b/auto_round/experimental/qmodules/__init__.py index cf805645b..3862e0293 100644 --- a/auto_round/experimental/qmodules/__init__.py +++ b/auto_round/experimental/qmodules/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, TransformMXFP4QuantLinear +from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, HadamardMXFP4QuantLinear from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear diff --git a/auto_round/experimental/qmodules/mx.py b/auto_round/experimental/qmodules/mx.py index a5109b441..b5bc3e939 100644 --- a/auto_round/experimental/qmodules/mx.py +++ b/auto_round/experimental/qmodules/mx.py @@ -196,7 +196,7 @@ def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor: return unpacked_data -class TransformMXFP4QuantLinear(MXFP4QuantLinear): +class HadamardMXFP4QuantLinear(MXFP4QuantLinear): """ Quantized linear layer using the MXFP4 quantization scheme. """ @@ -206,7 +206,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.enable_transform = True self.register_buffer( - "transform_matrix", + "hadamard_matrix", torch.empty( self.group_size, self.group_size, diff --git a/auto_round/experimental/transform/apply.py b/auto_round/experimental/transform/apply.py index dd5289dcc..6980d75e4 100644 --- a/auto_round/experimental/transform/apply.py +++ b/auto_round/experimental/transform/apply.py @@ -5,35 +5,72 @@ import tqdm from auto_round.experimental.qmodules.mx import MXQuantLinearBase -from auto_round.experimental.transform.transform_config import TransformConfig -from auto_round.experimental.transform.transforms import build_transform +from auto_round.experimental.transform.hadamard_config import HadamardConfig +from auto_round.experimental.transform.hadamards import build_hadamard_transform +from auto_round.experimental.utils import is_triton_kernel_available, normalize_hadamard_config -__all__ = ["apply_transform"] +__all__ = ["apply_hadamard_transform"] -def apply_transform(model: torch.nn.Module, config: TransformConfig, 
use_tqdm=True, desc=None): +def apply_hadamard_transform( + model: torch.nn.Module, + config: str | dict | HadamardConfig | None, + need_calibration: bool = False, + location: str = "weight", + use_tqdm=True, + desc=None, +): """ - Apply a transform config to a model. Add weight transforms and - activation transforms are attached as submodules and trigger via pytorch hooks - - :param model: model to apply config to - :param config: transform config to apply + Apply a Hadamard transform configuration to a model. + + Weight and activation transforms are attached as submodules and are + triggered via PyTorch hooks. + + :param model: Model to which the transform configuration will be applied. + :param config: Transform configuration to apply. Supported values are: + * ``str``: Either ``"default"`` or a hadamard type name listed in + ``HADAMARDS`` (``"hadamard"`` / ``"random_hadamard"``), resolved + to a concrete configuration. + * ``dict``: A raw configuration mapping that will be normalized + (via :func:`normalize_hadamard_config`) and then passed to + :class:`HadamardConfig`. + * :class:`HadamardConfig`: An existing configuration instance. + This will be used to construct the final configuration after + normalization. + * ``None``: Falls back to :func:`normalize_hadamard_config`'s + default behavior (an empty configuration, i.e. project defaults). + :param need_calibration: If ``True``, also patch the tuning wrappers so + the transform is re-applied on every forward pass during calibration. + :param location: Where to apply the transform: ``"weight"`` fuses it + eagerly into module weights, ``"input"`` registers a pre-forward + hook on activations. + :param use_tqdm: If ``True``, wrap the per-module application in a + tqdm progress bar. + :param desc: Optional description string to show in the tqdm progress + bar. If ``None``, a description will be derived from + ``config.hadamard_type``.
""" + config = normalize_hadamard_config(config) + if not isinstance(config, HadamardConfig): + config = HadamardConfig(**config) + modules_config = [ (name, module, config) for name, module in model.named_modules() if isinstance(module, torch.nn.Linear) or isinstance(module, MXQuantLinearBase) ] - desc = f"Applying {config.transform_type} transforms" if desc is None else desc + desc = f"Applying {config.hadamard_type} transforms" if desc is None else desc for name, module, config in tqdm.tqdm(modules_config, desc=desc, disable=(not use_tqdm)): if "lm_head" in name: continue - _apply_to_module(model, module, config) + _apply_to_module(model, module, config, need_calibration, location) # attach config to model for compression/serialization - setattr(model, "transform_config", config) + setattr(model, "hadamard_config", config) return model @@ -41,7 +78,9 @@ def apply_transform(model: torch.nn.Module, config: TransformConfig, use_tqdm=Tr def _apply_to_module( model: torch.nn.Module, module: torch.nn.Module, - config: TransformConfig, + config: HadamardConfig, + need_calibration: bool = False, + location: str = "weight", ): """ Create transforms and apply them to the module @@ -51,23 +90,23 @@ def _apply_to_module( """ # create transform as submodule - transform_name = "transform_matrix" + hadamard_name = config.hadamard_type - if config.location == "input": - from auto_round.experimental.transform.triton.utils import is_triton_kernel_available + if location == "input": # activation needs transpose - inp_transform = build_transform( + input_hadamard_transform = build_hadamard_transform( **config.dict(), + location="input", inverse=True, device="cpu", precision=module.dtype, ) - if config.transform_type != "random_hadamard": - transform_weight = inp_transform.weight + if config.hadamard_type != "random_hadamard": + hadamard_weight = input_hadamard_transform.weight else: - transform_weight = None + hadamard_weight = None if is_triton_kernel_available(): from 
auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper @@ -80,7 +119,7 @@ def input_hook(self, args): qdq_input, _ = mxfp4_forward_kernel_wrapper( x_flat, ( - transform_weight if transform_weight is not None else self.transform_matrix.T + hadamard_weight if hadamard_weight is not None else self.hadamard_matrix.T ), # this matrix from w_transform, needs transpose ) return qdq_input.reshape(orig_shape) @@ -97,37 +136,38 @@ def input_hook(self, args): ori_shape = input.shape - if transform_weight is not None: - input = input.view(-1, transform_weight.shape[0]) - return _multihead_matmul(input, transform_weight.to(input.device)).view(ori_shape) + if hadamard_weight is not None: + input = input.view(-1, hadamard_weight.shape[0]) + return _multihead_matmul(input, hadamard_weight.to(input.device)).view(ori_shape) else: - input = input.view(-1, self.transform_matrix.shape[0]) - return _multihead_matmul(input, self.transform_matrix.T).view(ori_shape) + input = input.view(-1, self.hadamard_matrix.shape[0]) + return _multihead_matmul(input, self.hadamard_matrix.T).view(ori_shape) # for fused transform + quantization kernel module.pre_dequantized_input = False module.register_forward_pre_hook(input_hook, prepend=True) - elif config.location == "weight": + elif location == "weight": # eagerly apply transformation to weight # fuse transform into weight assert hasattr(module, "weight") - w_transform = build_transform( + weight_hadamard_transform = build_hadamard_transform( **config.dict(), + location="weight", device=module.weight.device, precision=module.weight.dtype, ) # need save random hadamard matrix needed when inference - if config.transform_type == "random_hadamard": - module.register_module(transform_name, w_transform) + if config.hadamard_type == "random_hadamard": + module.register_module(config.hadamard_type, weight_hadamard_transform) # for saving transform weight from auto_round.experimental.transform.patch_modules import 
patch_quantlinear - patch_quantlinear() + patch_quantlinear(config.hadamard_type) - if config.need_calibration: + if need_calibration: # for training, the weight changes with every forward pass # for autoround tuning: patch wrapper linear qdq_weight func from auto_round.experimental.transform.patch_modules import ( @@ -135,7 +175,7 @@ def input_hook(self, args): patch_wrapperwalayer_forward_to_apply_transform, ) - inp_transform = build_transform( + input_hadamard_transform = build_hadamard_transform( **config.dict(), location="input", inverse=True, @@ -143,15 +183,15 @@ def input_hook(self, args): precision=module.weight.dtype, ) - patch_wrapperlinear_to_apply_transform(w_transform, inp_transform) - patch_wrapperwalayer_forward_to_apply_transform(inp_transform) + patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform) + patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform) else: # transform is no longer needed (unfusing is not supported) # delattr(module, transform_name) # fuse transform into weight with torch.no_grad(): - getattr(module, "weight").copy_(w_transform(module.weight).to(module.weight.device)) + getattr(module, "weight").copy_(weight_hadamard_transform(module.weight).to(module.weight.device)) else: # TODO: apply transform to output/q/k diff --git a/auto_round/experimental/transform/hadamard_config.py b/auto_round/experimental/transform/hadamard_config.py new file mode 100644 index 000000000..b58cce9dc --- /dev/null +++ b/auto_round/experimental/transform/hadamard_config.py @@ -0,0 +1,28 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +from pydantic import BaseModel, Field, field_validator + +__all__ = ["HadamardConfig"] + + +class HadamardConfig(BaseModel): + """ + Configuration of transforms to be applied to a model. 
This config is to be + serialized within a model's `config.json` file + """ + + block_size: int = Field(default=32) + + hadamard_type: str = Field(default="hadamard") + + # for random hadamard transform + random_seed: bool = Field(default=False, exclude=True) + + @field_validator("hadamard_type") + @classmethod + def validate_hadamard_type(cls, v: str) -> str: + allowed = {"hadamard", "random_hadamard"} + if v not in allowed: + raise ValueError(f"Unsupported hadamard_type: {v}. Supported values: {sorted(allowed)}") + return v diff --git a/auto_round/experimental/transform/transforms.py b/auto_round/experimental/transform/hadamards.py similarity index 86% rename from auto_round/experimental/transform/transforms.py rename to auto_round/experimental/transform/hadamards.py index 38bb1acb7..712232a9a 100644 --- a/auto_round/experimental/transform/transforms.py +++ b/auto_round/experimental/transform/hadamards.py @@ -28,23 +28,11 @@ def filter_kwarg_dict(fn_or_method: Callable, kwarg_dict: Dict[str, Any]) -> Dic return {k: v for k, v in kwarg_dict.items() if k in fn_or_method_keys} -class IdentityTransform(nn.Module): - - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, x: torch.Tensor): - return x - - def remove_parametrizations(self) -> None: - pass - - class HadamardTransform(nn.Module): def __init__( self, - transform_block_size: int = 32, + block_size: int = 32, device: torch.device = None, precision: torch.dtype = None, location: str = "weight", @@ -52,7 +40,7 @@ def __init__( inverse: bool = False, ): super().__init__() - self.size = transform_block_size + self.size = block_size self.scale = 1 / math.sqrt(self.size) self.location = location self.module_type = module_type @@ -76,7 +64,7 @@ def forward(self, x: torch.Tensor): return ( ( apply_transform_weight( - self.weight, + self.weight.to(x.device), x.to(dtype=self.weight.dtype), self.location, self.module_type, @@ -118,13 +106,12 @@ def _create_weight( return nn.Parameter(data, 
requires_grad=False) -TRANSFORMS = { - "identity": IdentityTransform, +HADAMARDS = { "hadamard": HadamardTransform, "random_hadamard": RandomHadamardTransform, } -def build_transform(transform_type: str, **transform_kwargs): - transform = TRANSFORMS[transform_type] - return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs)) +def build_hadamard_transform(hadamard_type: str, **hadamard_kwargs): + hadamard = HADAMARDS[hadamard_type] + return hadamard(**filter_kwarg_dict(hadamard.__init__, hadamard_kwargs)) diff --git a/auto_round/experimental/transform/patch_modules.py b/auto_round/experimental/transform/patch_modules.py index cf15b46e0..934ebea9d 100644 --- a/auto_round/experimental/transform/patch_modules.py +++ b/auto_round/experimental/transform/patch_modules.py @@ -139,7 +139,7 @@ def _forward_patched(self, x): WrapperWALayer._hadamard_forward_patched = True -def patch_quantlinear(): +def patch_quantlinear(hadamard_type): """ """ if getattr(QuantLinear, "_pack_patched", False): @@ -202,8 +202,8 @@ def _pack_patched( self.input_global_scale = input_global_scale.to(torch.float32).to(device).reshape([1]) # add transform weight - transform = getattr(linear, "transform_matrix") - self.register_buffer("transform_matrix", transform.weight.to(device)) + transform = getattr(linear, hadamard_type) + self.register_buffer("hadamard_matrix", transform.weight.to(device)) return QuantLinear.pack = _pack_patched diff --git a/auto_round/experimental/transform/transform_config.py b/auto_round/experimental/transform/transform_config.py deleted file mode 100644 index 91713666b..000000000 --- a/auto_round/experimental/transform/transform_config.py +++ /dev/null @@ -1,43 +0,0 @@ -# # Copyright (C) 2026 Intel Corporation -# # SPDX-License-Identifier: Apache-2.0 - -from pydantic import BaseModel, Field, field_validator - -__all__ = ["TransformConfig"] - - -class TransformConfig(BaseModel): - """ - Configuration of transforms to be applied to a model. 
This config is to be - serialized within a model's `config.json` file - """ - - # required, currently only supports mxfp4 - quant_scheme: str = Field(..., description="Quantization scheme. Currently supports 'MXFP4/MXFP8'.") - - transform_block_size: int = Field(default=32) - - transform_type: str = Field(default="hadamard") - - location: str = Field(default="weight", exclude=True) - - # apply transform inside modules for nvfp4, autoround tuning etc. - need_calibration: bool = Field(default=False, exclude=True) - - # for random hadamard transform - random_seed: bool = Field(default=False, exclude=True) - - @field_validator("quant_scheme") - @classmethod - def validate_quant_scheme(cls, v: str) -> str: - if v not in ["MXFP4", "MXFP8"]: - raise ValueError(f"Unsupported quant_scheme: {v}. Currently 'mxfp4/mxfp8' are supported.") - return v - - @field_validator("transform_type") - @classmethod - def validate_transform_type(cls, v: str) -> str: - allowed = {"hadamard", "random_hadamard"} - if v not in allowed: - raise ValueError(f"Unsupported transform_type: {v}. Supported values: {sorted(allowed)}") - return v diff --git a/auto_round/experimental/transform/triton/utils.py b/auto_round/experimental/transform/triton/utils.py deleted file mode 100644 index 8aa5da743..000000000 --- a/auto_round/experimental/transform/triton/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -# # Copyright (C) 2026 Intel Corporation -# # SPDX-License-Identifier: Apache-2.0 - -import torch - - -def is_triton_kernel_available() -> bool: - """ - Best-effort check for whether Triton kernel path can be used. 
- """ - try: - import triton # pylint: disable=E0401 - except Exception: - return False - - if not torch.cuda.is_available(): - return False - - try: - from auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper # pylint: disable=E0401 - except Exception: - return False - - return True diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py index e90f9c0d5..39a7ff135 100644 --- a/auto_round/experimental/utils.py +++ b/auto_round/experimental/utils.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import torch +from auto_round.experimental.transform.hadamard_config import HadamardConfig +from auto_round.experimental.transform.hadamards import HADAMARDS from auto_round.utils import logger +SUPPORTED_QUANTIZATION_SCHEMES = ["MXFP4"] + def per_tensor_fp8_qdq( tensor: torch.Tensor, tensor_max: None | torch.Tensor = None @@ -106,3 +112,89 @@ def clean_model_parameters_and_buffers_(model: torch.nn.Module, name_tuple: tupl """ for module in model.modules(): _clean_param_or_buff_if_exists(module, name_tuple) + + +def is_triton_kernel_available() -> bool: + """ + Best-effort check for whether Triton kernel path can be used. + """ + try: + import triton # pylint: disable=E0401 + except Exception: + return False + + if not torch.cuda.is_available(): + return False + + try: + from auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper # pylint: disable=E0401 + except Exception: + return False + + return True + + +def normalize_hadamard_config(hadamard_config: str | dict | HadamardConfig | None) -> dict[str, Any]: + """ + Normalize and validate `hadamard_config`. 
+ + Supported input types: + - None -> {} + - dict -> validated via HadamardConfig + - HadamardConfig -> validated & converted to dict + - str -> shorthand for `hadamard_type` in HADAMARDS keys + + On any validation failure, raises ValueError/TypeError. + """ + # 1) None -> {} + if hadamard_config is None: + return {} + + # 2) Already a HadamardConfig instance + if isinstance(hadamard_config, HadamardConfig): + # Ensure it passes its own validation and convert to dict + cfg = HadamardConfig.model_validate(hadamard_config).model_dump() + return cfg + + # 3) dict -> validate via HadamardConfig + if isinstance(hadamard_config, dict): + try: + cfg = HadamardConfig.model_validate(hadamard_config).model_dump() + except Exception as e: + raise ValueError(f"Invalid hadamard_config dict: {e}") from e + return cfg + + # 4) str -> shorthand for hadamard_type + if isinstance(hadamard_config, str): + key = hadamard_config.strip() + if not key: + return {} + + if key == "default": + cfg = HadamardConfig() + return cfg.model_dump() + + if key not in HADAMARDS: + raise ValueError( + f"Invalid hadamard_config string: {key!r}. " f"Expected one of {sorted(HADAMARDS.keys())}." + ) + + cfg_dict = {"hadamard_type": key} + + try: + cfg = HadamardConfig.model_validate(cfg_dict).model_dump() + except Exception as e: + raise ValueError(f"hadamard_config built from string {key!r} is invalid for HadamardConfig: {e}") from e + + return cfg + + raise TypeError( + "hadamard_config must be one of: None, dict, HadamardConfig, or str " f"(got {type(hadamard_config).__name__})" + ) + + +def check_supported_schemes(scheme: str): + if scheme not in SUPPORTED_QUANTIZATION_SCHEMES: + raise ValueError( + f"Unsupported quantization scheme: {scheme}. Currently {SUPPORTED_QUANTIZATION_SCHEMES} are supported."
+ ) diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index 00e2950b3..d98545679 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -775,10 +775,10 @@ def dynamic_import_inference_linear(backend, config): if "torch_mxfp8" in backend: return ar_qmodules.MXFP8QuantLinear if "torch_mxfp4" in backend: - transform_config = getattr(config, "transform_config", None) - if transform_config is not None and transform_config: - if transform_config["transform_type"] == "random_hadamard": - return ar_qmodules.TransformMXFP4QuantLinear + hadamard_config = getattr(config, "hadamard_config", None) + if hadamard_config is not None and hadamard_config: + if hadamard_config["hadamard_type"] == "random_hadamard": + return ar_qmodules.HadamardMXFP4QuantLinear return ar_qmodules.MXFP4QuantLinear if "torch_nvfp4" in backend: return ar_qmodules.NVFP4QuantLinear diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index afea71b80..a5b9096b3 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -211,7 +211,7 @@ def get_layer_config(model, quantization_config): act_data_type = getattr(quantization_config, "act_data_type", None) act_dynamic = getattr(quantization_config, "act_dynamic", False) - transform_config = getattr(quantization_config, "transform_config", None) + hadamard_config = getattr(quantization_config, "hadamard_config", None) default_quant_scheme = QuantizationScheme( bits=bits, @@ -223,7 +223,7 @@ def get_layer_config(model, quantization_config): act_sym=act_sym, act_data_type=act_data_type, act_dynamic=act_dynamic, - transform_config=transform_config, + hadamard_config=hadamard_config, ) # Determine the quantization block list @@ -676,19 +676,19 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M layer_configs = get_layer_config(model, quantization_config) used_backends = _replace_by_quant_layers(model, 
layer_configs, backend, target_device, packing_format) - transform_config = getattr(quantization_config, "transform_config", None) - if transform_config is not None and transform_config: - from auto_round.experimental.transform.apply import apply_transform - from auto_round.experimental.transform.transform_config import TransformConfig + hadamard_config = getattr(quantization_config, "hadamard_config", None) + if hadamard_config is not None and hadamard_config: + from auto_round.experimental.transform.apply import apply_hadamard_transform + from auto_round.experimental.transform.hadamard_config import HadamardConfig # apply forward hook - act_transform_config = TransformConfig( - quant_scheme=transform_config["quant_scheme"], - transform_block_size=transform_config["transform_block_size"], - transform_type=transform_config["transform_type"], - location="input", + act_hadamard_config = HadamardConfig( + block_size=hadamard_config["block_size"], + hadamard_type=hadamard_config["hadamard_type"], ) # apply to activation - model = apply_transform(model, act_transform_config, desc="Register pre forward hook for transform") + model = apply_hadamard_transform( + model, act_hadamard_config, location="input", desc="Register pre forward hook for hadamard transform" + ) # Suggest a better backend if available if backend == "auto": diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 49fc6fa76..5318b1fec 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -36,7 +36,7 @@ class QuantizationScheme: act_dynamic: Optional[bool] = None super_bits: Optional[int] = None super_group_size: Optional[int] = None - transform_config: Optional[dict] = None + hadamard_config: Optional[dict] = None @classmethod def from_dict(cls, config: dict): @@ -88,7 +88,7 @@ def __eq__(self, other: "QuantizationScheme") -> bool: continue self_val = getattr(self, field) other_val = getattr(other, field) - # Treat None and empty dict as equivalent for dict fields like 
transform_config + # Treat None and empty dict as equivalent for dict fields like hadamard_config if self_val != other_val: if isinstance(self_val, dict) and not self_val and other_val is None: continue diff --git a/docs/step_by_step.md b/docs/step_by_step.md index b8fa54992..36e6db7aa 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -32,6 +32,7 @@ This document presents step-by-step instructions for auto-round llm quantization + [Device/Multi-GPU setting in Quantization](#devicemulti-gpu-setting-in-quantization) - [Enable multiple gpus calibration in lm_head quantization](#enable-multiple-gpus-calibration-in-lm_head-quantization) + [Adjust Hyperparameters](#adjust-hyperparameters) + + [Hadamard Transform](#hadamard-transform) * [4 Inference](#4-inference) + [CPU](#cpu) + [Intel GPU](#intel-gpu) @@ -621,6 +622,53 @@ autoround.save_quantized(format="auto_awq", output_dir="tmp_autoround") Include the flag `--adam`. Note that AdamW is less effective than sign gradient descent in many scenarios we tested. +### Hadamard Transform + +AutoRound supports Hadamard transform as an optional weight/activation transformation technique, which can improve quantization accuracy by rotating the weight/activation matrix. This is particularly useful for certain quantization scenarios. + +#### Overview + +The Hadamard transform is particularly useful in scenarios where activation outliers hurt quantization accuracy. In practice, it helps suppress such outliers, making it especially effective when `act_bits < 8`. Users can enable this feature when they need more stable activation distributions and better accuracy in low‑bit quantization settings. + +#### Implementation + +AutoRound provides two types of Hadamard transforms: + +1. **Deterministic Hadamard Transform** (`hadamard`): Uses Sylvester's construction to create a deterministic Hadamard matrix. The size must be a power of 2. + +2. 
**Random Hadamard Transform** (`random_hadamard`): Uses known Hadamard matrices from N. J. A. Sloane's Library of Hadamard Matrices. Supports non-power-of-2 sizes and deterministic seeding. + + +#### Quantization with Hadamard Transform + +```python +from auto_round import AutoRound + +# Load a model (supports FP8/BF16/FP16/FP32) +model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct" +output_dir = "./Llama-3.1-8B-Instruct-mxfp4-ht" + +# hadamard_config="default": block_size=32, hadamard_type="hadamard" +ar = AutoRound(model_name_or_path, scheme="MXFP4", hadamard_config="default") + +ar.quantize_and_save(output_dir=output_dir, format="auto_round") +``` + +#### Transform Classes + +| Class | Description | +|-------|-------------| +| `HadamardTransform` | Applies deterministic Hadamard transform | +| `RandomHadamardTransform` | Applies random Hadamard transform with optional seeding | + +#### Parameters + +| Parameter | Description | +|-----------|-------------| +| `block_size` | Size of the transformation block (default: 32) | +| `seed` | Random seed (for RandomHadamardTransform) | + + ## 4 Inference AutoRound automatically selects the best available backend based on the installed libraries and prompts the user to install additional libraries when a better backend is found. 
diff --git a/docs/step_by_step_CN.md b/docs/step_by_step_CN.md index a3097ddca..8a1327e78 100644 --- a/docs/step_by_step_CN.md +++ b/docs/step_by_step_CN.md @@ -33,6 +33,7 @@ - [lm_head 量化中开启多 GPU 标定](#lm_head-量化中开启多-gpu-标定) - [手动配置设备映射](#手动配置设备映射) + [超参数调整](#超参数调整) + + [Hadamard变换](#hadamard变换) * [4 推理部署](#4-推理部署) + [CPU](#cpu) + [英特尔 GPU](#英特尔-gpu) @@ -584,6 +585,51 @@ auto-round --model_name Qwen/Qwen3-0.6B --scheme "W4A16" --quant_lm_head --form +### Hadamard变换 + +AutoRound 支持将 Hadamard 变换作为可选的权重/激活变换技术,通过旋转权重/激活矩阵来提升量化精度。这在某些量化场景中尤其有用。 + +#### 概述 + +Hadamard 变换在激活值存在离群点且影响量化精度的场景中特别有效。在实践中,它能够抑制这些离群点,因此在 `act_bits < 8` 的低比特激活量化设置中效果尤为显著。用户可以在需要更稳定的激活分布和更高低比特精度时启用该功能。 + +#### 实现方式 + +AutoRound 提供两种类型的 Hadamard 变换: + +1. **确定性 Hadamard 变换**(`hadamard`):使用 Sylvester 构造法生成确定性的 Hadamard 矩阵,尺寸必须为 2 的幂次。 +2. **随机 Hadamard 变换**(`random_hadamard`):使用 N. J. A. Sloane 的 Hadamard 矩阵库中已知的矩阵。支持非 2 的幂次的尺寸,并支持确定性随机种子。 + +#### 使用 Hadamard 变换进行量化 + +```python +from auto_round import AutoRound + +# 加载模型(支持 FP8/BF16/FP16/FP32) +model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct" +output_dir = "./Llama-3.1-8B-Instruct-mxfp4-ht" + +# hadamard_config="default": block_size=32, hadamard_type="hadamard" +ar = AutoRound(model_name_or_path, scheme="MXFP4", hadamard_config="default") + +ar.quantize_and_save(output_dir=output_dir, format="auto_round") +``` + +#### 变换类 + +| 类名 | 描述 | +| ------------------------- | --------------------------- | +| `HadamardTransform` | 应用确定性 Hadamard 变换 | +| `RandomHadamardTransform` | 应用随机 Hadamard 变换,并可选随机种子 | + +#### 参数说明 + +| 参数 | 描述 | +| ------------ | -------------------------------- | +| `block_size` | 变换块大小(默认:32) | +| `seed` | 随机种子(用于 RandomHadamardTransform) | + + ## 4 推理部署 AutoRound 支持十余种推理后端,并会根据已安装的库自动选择最优后端;如果检测到系统中存在更优后端但缺少相关依赖,也会主动提示用户安装。 diff --git a/setup.py b/setup.py index dff0d7a8f..1b759fe83 100644 --- a/setup.py +++ b/setup.py @@ -186,5 +186,5 @@ def fetch_requirements(path): "License :: OSI Approved :: Apache Software License", ], 
include_package_data=True, - package_data={"": ["mllm/templates/*.json"]}, + package_data={"": ["mllm/templates/*.json", "experimental/transform/utils/hadamards.safetensors"]}, ) diff --git a/test/test_cuda/transform/test_mxfp4_transform.py b/test/test_cuda/transform/test_mxfp4_transform.py index 8bbf52aaf..fba84e678 100644 --- a/test/test_cuda/transform/test_mxfp4_transform.py +++ b/test/test_cuda/transform/test_mxfp4_transform.py @@ -31,36 +31,53 @@ def test_transform_mxfp4_quant_infer(self): model_name = get_model_path("qwen/Qwen3-0.6B") scheme = "MXFP4" - from auto_round.utils import llm_load_model - - model, tokenizer = llm_load_model( - model_name, - platform="hf", - device="cpu", # always load cpu first - model_dtype=None, - trust_remote_code=True, + ar = AutoRound( + model=model_name, + iters=0, + seqlen=2, + scheme=scheme, + hadamard_config="default", ) + compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") + + model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + from ...helpers import generate_prompt - from auto_round.experimental.transform.apply import apply_transform - from auto_round.experimental.transform.transform_config import TransformConfig + generate_prompt(model, tokenizer) - transform_config = TransformConfig(quant_scheme="MXFP4") - model = apply_transform( - model, - transform_config, + def test_transform_mxfp4_tuning_quant_infer(self): + model_name = get_model_path("qwen/Qwen3-0.6B") + scheme = "MXFP4" + + ar = AutoRound( + model=model_name, + iters=2, + seqlen=2, + scheme=scheme, + hadamard_config="default", ) + compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") + + model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", device_map="cuda") + tokenizer = AutoTokenizer.from_pretrained(self.save_dir) + from ...helpers import generate_prompt 
+ + generate_prompt(model, tokenizer) + + def test_random_transform_mxfp4_quant_infer(self): + model_name = get_model_path("qwen/Qwen3-0.6B") + scheme = "MXFP4" ar = AutoRound( - model=model, + model=model_name, iters=0, seqlen=2, scheme=scheme, - transform_config=transform_config.dict(), + hadamard_config="random_hadamard", ) compressed_model, _ = ar.quantize_and_save(output_dir=self.save_dir, format="auto_round") - tokenizer.save_pretrained(self.save_dir) - model = AutoModelForCausalLM.from_pretrained(self.save_dir, torch_dtype="auto", device_map="cuda") tokenizer = AutoTokenizer.from_pretrained(self.save_dir) from ...helpers import generate_prompt