diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index bf801646b..eeee8b998 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -59,7 +59,53 @@ if TYPE_CHECKING:
     from collections.abc import Callable
 
 
-__all__ = ["SequentialQuantizer", "TensorQuantizer"]
+__all__ = [
+    "SequentialQuantizer",
+    "TensorQuantizer",
+    "get_fake_quant_backend",
+    "register_fake_quant_backend",
+    "set_fake_quant_backend",
+    "unregister_fake_quant_backend",
+]
+
+
+# Minimal backend registry for fake-quantization dispatch in TensorQuantizer._fake_quantize.
+# A backend is a callable with signature: fn(inputs: torch.Tensor, quantizer: TensorQuantizer) -> torch.Tensor
+_FAKE_QUANT_BACKEND_REGISTRY: dict[str, Any] = {"modelopt": None}
+_CURRENT_FAKE_QUANT_BACKEND: str = "modelopt"
+
+
+def register_fake_quant_backend(name: str, fn):
+    """Register a fake-quantization backend under the given name.
+
+    The backend callable must accept (inputs: torch.Tensor, quantizer: TensorQuantizer) and return a tensor.
+    """
+    _FAKE_QUANT_BACKEND_REGISTRY[name] = fn
+
+
+def set_fake_quant_backend(name: str):
+    """Set the active fake-quantization backend by name. Use "modelopt" for the built-in logic."""
+    global _CURRENT_FAKE_QUANT_BACKEND
+    if name not in _FAKE_QUANT_BACKEND_REGISTRY:
+        raise ValueError(f"Unknown fake quant backend: {name!r}")
+    _CURRENT_FAKE_QUANT_BACKEND = name
+
+
+def get_fake_quant_backend() -> str:
+    """Return the name of the currently active fake-quantization backend."""
+    return _CURRENT_FAKE_QUANT_BACKEND
+
+
+def unregister_fake_quant_backend(name: str):
+    """Unregister a previously registered fake-quantization backend."""
+    # Don't allow removing the built-in default backend.
+    if name == "modelopt":
+        return
+    # Fall back to the built-in backend if the active one is being removed.
+    global _CURRENT_FAKE_QUANT_BACKEND
+    if _CURRENT_FAKE_QUANT_BACKEND == name:
+        _CURRENT_FAKE_QUANT_BACKEND = "modelopt"
+    _FAKE_QUANT_BACKEND_REGISTRY.pop(name, None)
 
 
 class TensorQuantizer(nn.Module):
@@ -621,6 +667,12 @@ def _real_quantize(self, inputs):
 
     def _fake_quantize(self, inputs):
         """Fake quantization."""
+        # Dispatch to a registered backend when one other than the default is active.
+        backend_name = get_fake_quant_backend()
+        backend_fn = _FAKE_QUANT_BACKEND_REGISTRY.get(backend_name)
+        if backend_name != "modelopt" and backend_fn is not None:
+            return backend_fn(inputs, self)
+
         amax = None
         if not self.is_mx_format:
             amax = self._get_amax(inputs)
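
For reference, a minimal usage sketch of the registry API added by this diff. The import path mirrors the file touched above; `my_fake_quant` and the backend name `"my_backend"` are hypothetical stand-ins for a real custom kernel, not part of the change itself.

```python
import torch

from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
    TensorQuantizer,
    get_fake_quant_backend,
    register_fake_quant_backend,
    set_fake_quant_backend,
    unregister_fake_quant_backend,
)


def my_fake_quant(inputs: torch.Tensor, quantizer: TensorQuantizer) -> torch.Tensor:
    # Hypothetical backend: a passthrough stand-in for a custom fake-quant kernel.
    # A real backend would read the quantizer's attributes (e.g. num_bits, amax)
    # and return a fake-quantized tensor with the same shape and dtype as `inputs`.
    return inputs


register_fake_quant_backend("my_backend", my_fake_quant)
set_fake_quant_backend("my_backend")
assert get_fake_quant_backend() == "my_backend"

# From here on, TensorQuantizer._fake_quantize dispatches to my_fake_quant
# instead of the built-in modelopt path.

# Restore the default behavior and clean up.
set_fake_quant_backend("modelopt")
unregister_fake_quant_backend("my_backend")
```

One note on the design: because `"modelopt"` is a reserved name that cannot be unregistered, callers can always restore the built-in path with `set_fake_quant_backend("modelopt")`, and the `backend_fn is not None` guard in `_fake_quantize` falls through to the default logic whenever the active name has no callable registered.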