From babea8f1508f3099778863e68bc72aeb2d517a6b Mon Sep 17 00:00:00 2001
From: realAsma
Date: Wed, 15 Oct 2025 15:36:51 +0000
Subject: [PATCH 1/3] [1/2] Registry interface for custom quantization functional backend

Signed-off-by: realAsma

minor

minor

moved external changes to this PR

addressed PR comments; clean ups

minor test fix

rebasing external changes from [2/2]

minor unrelated fix

Signed-off-by: realAsma
---
 modelopt/torch/quantization/config.py         |  35 +++++-
 modelopt/torch/quantization/model_quant.py    |   2 +-
 .../nn/modules/tensor_quantizer.py            | 100 +++++++++++++++---
 .../torch/quantization/plugins/megatron.py    |   8 ++
 modelopt/torch/quantization/tensor_quant.py   |  65 +++++-------
 modelopt/torch/utils/tensor.py                |  12 +++
 .../torch/quantization/test_custom_backend.py |  53 ++++++++++
 7 files changed, 221 insertions(+), 54 deletions(-)
 create mode 100644 tests/unit/torch/quantization/test_custom_backend.py

diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 0642ed910..3458be6e8 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -665,7 +665,7 @@ class QuantizerAttributeConfig(ModeloptBaseConfig):
         description="""If True, enables the quantizer. If False, by-pass the quantizer and returns the input tensor.""",
     )
 
-    num_bits: int | tuple[int, int] = ModeloptField(
+    num_bits: int | tuple[int, int] | str = ModeloptField(
         default=8,
         title="An integer or a tuple of two integers specifying the number of quantization bits.",
         description="""`num_bits` can be:
@@ -675,7 +675,9 @@ class QuantizerAttributeConfig(ModeloptBaseConfig):
 
         #. Constant integer tuple (E,M) for floating point quantization emulating
         Nvidia's FPx quantization. E is the number of exponent bits and M is the number
-        of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).""",
+        of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).
+
+        #. String specifying the quantization format. This is currently used only for custom backends.""",
     )
 
     @model_validator(mode="before")
@@ -707,10 +709,16 @@ def _validate_recursive(value):
     @model_validator(mode="after")
     def validate_num_bits(self):
         """Validate `num_bits`."""
+        if self.backend is not None:
+            # For custom backends, we don't need to validate num_bits
+            return self
+
         num_bits = self.num_bits
 
         if isinstance(num_bits, int) and num_bits < 1:
-            raise ValueError("num_bits must be a positive integer or a tuple of positive integers.")
+            raise ValueError(
+                f"num_bits must be a positive integer or a tuple of positive integers. Got {num_bits}."
+            )
 
         if not isinstance(num_bits, tuple):
             return self
@@ -952,6 +960,27 @@ def validate_calibrator(cls, v, info: ValidationInfo):
         """,
     )
 
+    backend: str | None = ModeloptField(
+        default=None,
+        title="Name of custom quantization functional backend.",
+        description="""
+        Selects a non-default quantization functional backend by name. See
+        :meth:`register_quant_backend `
+        for more details on how to register a custom quantization backend.
+        """,
+    )
+    backend_extra_args: dict | None = ModeloptField(
+        default=None,
+        title="Extra arguments for the selected backend.",
+        description="""The extra arguments are saved on the quantizer instance - they are not
+        passed directly to the backend entrypoint. Can be any serializable dictionary.
+
+        Please use `backend_extra_args` to pass arguments that are not already supported by
+        `QuantizerAttributeConfig`.
This will ensure maximum compatibility with the other modelopt + features such as modelopt's calibration algorithms. + """, + ) + class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index deace8e0c..3b4e6eff5 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -228,7 +228,7 @@ def forward_loop(model) -> None: Returns: A pytorch model which has been quantized and calibrated. """ model = apply_mode(model, mode=[("quantize", config)], registry=QuantizeModeRegistry) - return calibrate(model, config["algorithm"], forward_loop=forward_loop) + return calibrate(model, config.get("algorithm"), forward_loop=forward_loop) def auto_quantize( diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index bf801646b..e19f34559 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -18,7 +18,8 @@ import contextlib import math import warnings -from typing import TYPE_CHECKING, Any +from collections.abc import Callable +from typing import Any import torch import torch.distributed as dist @@ -36,7 +37,7 @@ import torch.nn.functional as F from torch import nn -from modelopt.torch.utils import standardize_constructor_args +from modelopt.torch.utils import same_device_as, standardize_constructor_args from modelopt.torch.utils.distributed import DistributedProcessGroup from ... import calib @@ -56,10 +57,58 @@ from ...utils import is_torch_export_mode from ..functional import normalized_hadamard_transform -if TYPE_CHECKING: - from collections.abc import Callable +__all__ = [ + "SequentialQuantizer", + "TensorQuantizer", + "is_registered_quant_backend", + "register_quant_backend", + "unregister_quant_backend", +] -__all__ = ["SequentialQuantizer", "TensorQuantizer"] + +QuantBackendEntrypoint = Callable[[torch.Tensor, "TensorQuantizer"], torch.Tensor] + +_QUANT_FUNCTIONAL_BACKENDS: dict[str, QuantBackendEntrypoint] = {} + + +def register_quant_backend(name: str, entrypoint: QuantBackendEntrypoint) -> None: + """Register a custom quantization backend. + + Args: + name: The name of the backend. + entrypoint: The entrypoint of the backend. The entrypoint should be a callable that takes in + the inputs and the tensor quantizer as arguments and returns the quantized tensor. + See :class:`modelopt.torch.quantization.config.QuantizerAttributeConfig` + for details on choosing from the registered backends via the ``backend`` and + ``backend_extra_args`` fields. + """ + if not isinstance(name, str) or not name: + raise ValueError("Backend name must be a non-empty string.") + if not callable(entrypoint): + raise TypeError("Entrypoint must be callable.") + if name in _QUANT_FUNCTIONAL_BACKENDS: + warnings.warn(f"Overwriting existing backend: {name}") + _QUANT_FUNCTIONAL_BACKENDS[name] = entrypoint + + +def unregister_quant_backend(name: str) -> None: + """Unregister a custom quantization backend. + + Args: + name: The name of the backend to unregister. + """ + if not isinstance(name, str) or not name: + raise ValueError("Backend name must be a non-empty string.") + _QUANT_FUNCTIONAL_BACKENDS.pop(name, None) + + +def is_registered_quant_backend(name: str) -> bool: + """Check if a custom quantization backend is registered. + + Args: + name: The name of the backend to check. 
+ """ + return name in _QUANT_FUNCTIONAL_BACKENDS class TensorQuantizer(nn.Module): @@ -153,6 +202,8 @@ def _calibrator_setter(val): "enable": ("_disabled", lambda val: val is False), "type": ("_dynamic", lambda val: val == "dynamic"), "calibrator": ("_calibrator", _calibrator_setter), + "backend": ("backend", lambda val: val), + "backend_extra_args": ("backend_extra_args", lambda val: val or {}), } for attribute, val in attribute_cfg.items(): @@ -621,6 +672,12 @@ def _real_quantize(self, inputs): def _fake_quantize(self, inputs): """Fake quantization.""" + if self.backend is not None: + if self.backend not in _QUANT_FUNCTIONAL_BACKENDS: + raise KeyError(f"Quant backend '{self.backend}' is not registered.") + entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend] + return entrypoint(inputs, self) + amax = None if not self.is_mx_format: amax = self._get_amax(inputs) @@ -927,7 +984,8 @@ def forward(self, inputs): if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): inputs.data = inputs.data.contiguous() if self.fake_quant: - outputs = self._fake_quantize(inputs) + with same_device_as(inputs): + outputs = self._fake_quantize(inputs) elif not self._dequantize: outputs = self._real_quantize(inputs) else: @@ -961,16 +1019,23 @@ def _short_amax(self, fmt=".4f"): return "None" if self._amax.is_meta: return "meta" - if self._amax.numel() == 1: - return f"{self._amax.item():{fmt}}" - return ( - f"[{self._amax.min().item():{fmt}}," - f" {self._amax.max().item():{fmt}}]({self._amax.numel()})" - ) + return self._short_tensor(self._amax, fmt) + + def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"): + """Short description of tensor.""" + if tensor.numel() == 1: + return f"{tensor.item():{fmt}}" + return f"[{tensor.min().item():{fmt}}, {tensor.max().item():{fmt}}]({tensor.numel()})" def extra_repr(self): """Set the extra information about this module.""" if self._disabled: + s = "disabled" + s += ( + f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" + if self.pre_quant_scale is not None + else "" + ) return "disabled" s = f"{'unsigned ' if self._unsigned else ''}{self._num_bits} bit" s += " narrow" if (self._narrow_range) else "" @@ -980,7 +1045,11 @@ def extra_repr(self): else: s += f" axis={self._axis}" if self._axis is not None else " per-tensor" s += f" amax={self._short_amax()}" - s += " pre_quant_scale" if self.pre_quant_scale is not None else "" + s += ( + f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" + if self.pre_quant_scale is not None + else "" + ) s += " rotated" if self._rotate else "" s += ( f" calibrator={self._calibrator.__class__.__name__}" @@ -992,6 +1061,11 @@ def extra_repr(self): s += " quant" if (self._if_quant) else "" s += " calib" if (self._if_calib) else "" + s += ( + f" backend={self.backend}, extra_args={self.backend_extra_args}" + if self.backend is not None + else "" + ) return s def _get_properties_for_modelopt_state(self): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 85784d2fe..77cd72093 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -231,6 +231,14 @@ def _setup(self): data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), ) + + if getattr(self, "gradient_accumulation_fusion", False): + warnings.warn( + "gradient_accumulation_fusion is not supported with ModelOpt quantization. " + "Setting gradient_accumulation_fusion to False." 
+ ) + self.gradient_accumulation_fusion = False + super()._setup() def _process_quantizer_amax(self, k, v, quantizer_state_dict): diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 5f69e3999..7c35af75e 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -79,14 +79,11 @@ def scaled_e4m3_impl( if cuda_ext_fp8 is None: return fp8_eager(inputs, amax) - with torch.cuda.device( - None if inputs.device.index == torch.cuda.current_device() else inputs.device.index - ): - if amax.numel() == 1: - outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) - elif amax.squeeze().ndim == 1: - axis = amax.shape.index(amax.numel()) - outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) + if amax.numel() == 1: + outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) + elif amax.squeeze().ndim == 1: + axis = amax.shape.index(amax.numel()) + outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) return outputs @@ -100,17 +97,14 @@ def fake_quant_impl( """Implementation of fake quantizing input according to number of bits.""" cuda_ext = get_cuda_ext() - with torch.cuda.device( - None if inputs.device.index == torch.cuda.current_device() else inputs.device.index - ): - if amax.numel() == 1: - outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) - else: - axis = amax.shape.index(amax.numel()) - outputs = cuda_ext.fake_tensor_quant_with_axis( - inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range - ) - return outputs + if amax.numel() == 1: + outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) + else: + axis = amax.shape.index(amax.numel()) + outputs = cuda_ext.fake_tensor_quant_with_axis( + inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range + ) + return outputs def _quantize_impl( @@ -173,25 +167,22 @@ def _dynamic_block_quantize_impl( assert amax.is_cuda, "amax must be a CUDA tensor for dynamic block quantization." if amax.numel() != 1: amax = amax.amax() - with torch.cuda.device( - None if inputs.device.index == torch.cuda.current_device() else inputs.device.index + if ( + num_bits == (2, 1) # type: ignore[comparison-overlap] + and scale_bits == (4, 3) + and triton_kernel.IS_AVAILABLE + and not DISABLE_TRITON_KERNEL + and amax is not None ): - if ( - num_bits == (2, 1) # type: ignore[comparison-overlap] - and scale_bits == (4, 3) - and triton_kernel.IS_AVAILABLE - and not DISABLE_TRITON_KERNEL - and amax is not None - ): - return triton_kernel.fp4_fake_quant_block(inputs, amax) - cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True) - return cuda_ext_mx.fused_amax_convert( - inputs, - block_size, - getattr(cuda_ext_mx.Types, mx_format_map[num_bits]), - getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]), - amax, - ) + return triton_kernel.fp4_fake_quant_block(inputs, amax) + cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True) + return cuda_ext_mx.fused_amax_convert( + inputs, + block_size, + getattr(cuda_ext_mx.Types, mx_format_map[num_bits]), + getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]), + amax, + ) else: raise NotImplementedError( f"Unsupported num_bits: {num_bits}, scale_bits: {scale_bits} for dynamic block quantization." 
diff --git a/modelopt/torch/utils/tensor.py b/modelopt/torch/utils/tensor.py index 3ff208628..71f801f35 100644 --- a/modelopt/torch/utils/tensor.py +++ b/modelopt/torch/utils/tensor.py @@ -16,12 +16,14 @@ """Utility functions for PyTorch tensors.""" from collections import abc +from contextlib import nullcontext import numpy as np import torch __all__ = [ "numpy_to_torch", + "same_device_as", "to_empty_if_meta_device", "torch_detach", "torch_to", @@ -29,6 +31,16 @@ ] +def same_device_as(inputs: torch.Tensor): + """Return a context manager that sets the CUDA device to be the same as the input tensor. + + Returns a null context if the tensor is on CPU or on the same device as the current CUDA device. + """ + if not inputs.is_cuda or inputs.device.index == torch.cuda.current_device(): + return nullcontext() + return torch.cuda.device(inputs.device.index) + + def torch_to(data, *args, **kwargs): """Try to recursively move the data to the specified args/kwargs.""" if isinstance(data, torch.Tensor): diff --git a/tests/unit/torch/quantization/test_custom_backend.py b/tests/unit/torch/quantization/test_custom_backend.py new file mode 100644 index 000000000..106b91ada --- /dev/null +++ b/tests/unit/torch/quantization/test_custom_backend.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for registering and using a custom quantization backend via model quantize(). + +This test uses a dummy backend and does NOT rely on PSX formats. It quantizes a test +model with a custom config that enables only the output quantizer to ensure the backend +is invoked and the model output is shifted by a known offset. +""" + +import torch + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.nn import register_quant_backend, unregister_quant_backend + + +def test_custom_backend_via_quantize(): + # Define and register a simple dummy backend that adds a constant to inputs + def dummy_backend(inputs: torch.Tensor, tq) -> torch.Tensor: + extra = getattr(tq, "backend_extra_args", None) or {} + offset = extra.get("offset", 1.0) + return inputs + offset + + register_quant_backend("dummy_backend", dummy_backend) + + model = torch.nn.Linear(16, 16, bias=False) + + cfg = { + "quant_cfg": { + "*weight_quantizer": { + "enable": True, + "num_bits": 8, + "axis": None, + "backend": "dummy_backend", + "backend_extra_args": {"offset": 2.5}, + }, + "default": {"enable": False}, + }, + "algorithm": "max", + } + + inputs = torch.randn(1, 16) + + def forward_loop(m): + m(inputs) + + mtq.quantize(model, cfg, forward_loop=forward_loop) + output_test = model(inputs) + + assert torch.allclose(output_test, inputs @ (model.weight.T + 2.5)) + + # Unregister the backend to avoid impacting other tests + unregister_quant_backend("dummy_backend") From 1b2a57db873c26227da50f16a4b389601369d688 Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 17 Oct 2025 18:11:22 +0000 Subject: [PATCH 2/3] updated CHANGELOG Signed-off-by: realAsma --- CHANGELOG.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 065913464..b6d9bbe9b 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,9 +8,10 @@ Model Optimizer Changelog (Linux) **New Features** -- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. 
Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
 - Add LoRA mode support for MCore in a new peft submodule: ``modelopt.torch.peft.update_model(model, LORA_CFG)``.
 - Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.
+- Add support for custom emulated quantization backends. See :meth:`register_quant_backend ` for more details and ``tests/unit/torch/quantization/test_custom_backend.py`` for an example.
+- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
 - Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` if no dataset is specified.
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.

From b119fd187c3262448d9bb526e7eb21bc01db9fab Mon Sep 17 00:00:00 2001
From: realAsma
Date: Fri, 17 Oct 2025 22:23:55 +0000
Subject: [PATCH 3/3] fixed minor tests

Signed-off-by: realAsma
---
 .../quantization/test_tensor_quant_cuda.py     | 17 -----------------
 .../quantization/test_tensor_quantizer_cuda.py |  8 ++++++++
 2 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py
index 9519824bb..88e1f3b47 100644
--- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py
+++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py
@@ -30,14 +30,6 @@ class TestFakeTensorQuantCuda(FakeTensorQuantTester):
     device = "cuda"
 
-    def test_non_current_gpu(self, need_2_gpus):
-        device = torch.cuda.device_count() - 1
-        assert torch.cuda.current_device() != device
-        x = torch.randn(3, 4).cuda(device)
-        quant_x = tensor_quant.fake_tensor_quant(x, torch.max(torch.abs(x)), None)
-        quant_x_ref = quant(x, torch.max(torch.abs(x)), fake=True)
-        assert torch.allclose(quant_x, quant_x_ref)
-
 
 class TestCudaExt:
     @pytest.mark.parametrize("num_bits", [3, 4, 5, 7, 8, 11])
@@ -145,15 +137,6 @@ def test_backward(self, device):
         loss.backward()
         assert torch.allclose(quant_x.grad, x.grad)
 
-    def test_non_current_gpu(self, need_2_gpus):
-        torch.cuda.set_device(0)
-        device = torch.cuda.device_count() - 1
-        x = torch.randn(3, 4).cuda()
-        quant_x_ref = tensor_quant.fp8_eager(x, torch.tensor(448.0, device=x.device))
-        x = x.cuda(device)
-        quant_x = tensor_quant.scaled_e4m3(x, None, None, 4, 3)
-        assert torch.allclose(quant_x.cuda(), quant_x_ref)
-
     @pytest.mark.parametrize("axis", [0, 1, 2])
     def test_e4m3_per_channel(self, axis):
         x = torch.randn(4, 4, 4, dtype=torch.float32).cuda()
diff --git a/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py b/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py
index 2ca68df23..a2c4f3190 100644
--- a/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py
+++ b/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py
@@ -55,6 +55,14 @@ def test_e4m3(self, E, M, axis):  # noqa: N803
         ref = tensor_quant.scaled_e4m3(x, e4m3_quantizer._get_amax(x), None, E, M)
         assert torch.allclose(e4m3_x, ref)
 
+    def test_non_current_gpu(self, need_2_gpus):
+        x = torch.randn(3, 4)
+        e4m3_desc = QuantizerAttributeConfig(num_bits=(4, 3), axis=None)
+        quantizer
= tensor_quantizer.TensorQuantizer(e4m3_desc).cuda() + xq_ref = quantizer(x.to("cuda:0")) + xq_test = quantizer(x.to("cuda:1")) + assert torch.allclose(xq_ref, xq_test.to("cuda:0")) + @pytest.mark.skipif(get_cuda_ext_mx() is None, reason="cuda_ext_mx is not available") class TestTensorQuantizerfp4:
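
Taken together, the series exposes three public pieces: the register_quant_backend()/unregister_quant_backend() helpers (plus is_registered_quant_backend()) added in tensor_quantizer.py, the new "backend" and "backend_extra_args" fields on QuantizerAttributeConfig, and the dispatch in TensorQuantizer._fake_quantize() that routes to the registered entrypoint. The sketch below is a minimal end-to-end usage example distilled from the unit test added in patch 1/3; the backend name "clip_and_round" and its "levels" argument are illustrative placeholders, not names defined by these patches.

    import torch

    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.nn import register_quant_backend, unregister_quant_backend


    def clip_and_round(inputs: torch.Tensor, tq) -> torch.Tensor:
        """Toy emulated-quantization entrypoint: symmetric round-to-grid, computed on the fly."""
        # backend_extra_args is stored on the quantizer instance, not passed as a function argument.
        extra = getattr(tq, "backend_extra_args", None) or {}
        levels = extra.get("levels", 127)
        scale = inputs.abs().amax().clamp_min(1e-8) / levels
        return torch.round(inputs / scale).clamp_(-levels, levels) * scale


    register_quant_backend("clip_and_round", clip_and_round)

    cfg = {
        "quant_cfg": {
            "*weight_quantizer": {
                "enable": True,
                "num_bits": 8,  # a free-form string is also accepted once "backend" is set
                "axis": None,
                "backend": "clip_and_round",
                "backend_extra_args": {"levels": 127},
            },
            "default": {"enable": False},
        },
        "algorithm": "max",
    }

    model = torch.nn.Linear(16, 16, bias=False)
    calib_data = torch.randn(4, 16)
    mtq.quantize(model, cfg, forward_loop=lambda m: m(calib_data))
    # The weight quantizer's repr should now mention backend=clip_and_round, extra_args={'levels': 127}.

    unregister_quant_backend("clip_and_round")

Because the entrypoint receives the TensorQuantizer instance itself, a backend can read whatever it needs (backend_extra_args, axis, block sizes, or the calibrated amax) without any further plumbing in the forward path.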