50 changes: 49 additions & 1 deletion modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -59,7 +59,49 @@
if TYPE_CHECKING:
    from collections.abc import Callable

__all__ = ["SequentialQuantizer", "TensorQuantizer"]
__all__ = [
"SequentialQuantizer",
"TensorQuantizer",
"get_fake_quant_backend",
"register_fake_quant_backend",
"set_fake_quant_backend",
"unregister_fake_quant_backend",
]


# Minimal backend registry for fake-quantization dispatch in TensorQuantizer._fake_quantize.
# A backend is a callable with signature: fn(inputs: torch.Tensor, quantizer: TensorQuantizer) -> torch.Tensor
_FAKE_QUANT_BACKEND_REGISTRY: dict[str, Any] = {"modelopt": None}  # None marks the built-in path
_CURRENT_FAKE_QUANT_BACKEND: str = "modelopt"


def register_fake_quant_backend(name: str, fn):
    """Register a fake-quantization backend under a name.

    The backend callable must accept (inputs: torch.Tensor, quantizer: TensorQuantizer) and return a tensor.
    """
    # "modelopt" is reserved for the built-in path; dispatch never routes a
    # custom callable under that name, so registering one would be a silent no-op.
    if name == "modelopt":
        raise ValueError('"modelopt" is reserved for the built-in fake-quantization path')
    _FAKE_QUANT_BACKEND_REGISTRY[name] = fn


def set_fake_quant_backend(name: str):
    """Set the active fake-quantization backend by name. Use "modelopt" for the built-in logic."""
    if name not in _FAKE_QUANT_BACKEND_REGISTRY:
        raise ValueError(
            f"Unknown fake quant backend: {name!r}. "
            f"Registered backends: {sorted(_FAKE_QUANT_BACKEND_REGISTRY)}"
        )
    global _CURRENT_FAKE_QUANT_BACKEND
    _CURRENT_FAKE_QUANT_BACKEND = name


def get_fake_quant_backend() -> str:
    """Return the currently active fake-quantization backend name."""
    return _CURRENT_FAKE_QUANT_BACKEND


def unregister_fake_quant_backend(name: str):
    """Unregister a previously registered fake-quantization backend."""
    # Don't allow removing the default backend name.
    if name == "modelopt":
        return
    _FAKE_QUANT_BACKEND_REGISTRY.pop(name, None)
    # Fall back to the built-in path if the active backend was just removed,
    # so _CURRENT_FAKE_QUANT_BACKEND never points at a missing entry.
    global _CURRENT_FAKE_QUANT_BACKEND
    if _CURRENT_FAKE_QUANT_BACKEND == name:
        _CURRENT_FAKE_QUANT_BACKEND = "modelopt"
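
# Usage sketch (illustrative only, not part of this change): how a caller
# would register, activate, and tear down a custom backend. The names
# `_example_usage` and `my_backend` are hypothetical.
def _example_usage():
    def my_backend(inputs, quantizer):
        # A no-op backend; a real one would quantize `inputs` using the
        # state held on `quantizer` (amax, num_bits, ...).
        return inputs

    register_fake_quant_backend("my_backend", my_backend)
    set_fake_quant_backend("my_backend")
    assert get_fake_quant_backend() == "my_backend"

    # Restore the built-in path and clean up.
    set_fake_quant_backend("modelopt")
    unregister_fake_quant_backend("my_backend")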


class TensorQuantizer(nn.Module):
@@ -621,6 +663,12 @@ def _real_quantize(self, inputs):

    def _fake_quantize(self, inputs):
        """Fake quantization."""
        # Backend dispatch: use the registered backend when a non-default backend is active.
        backend_name = get_fake_quant_backend()
        backend_fn = _FAKE_QUANT_BACKEND_REGISTRY.get(backend_name)
        if backend_name != "modelopt" and backend_fn is not None:
            return backend_fn(inputs, self)

        amax = None
        if not self.is_mx_format:
            amax = self._get_amax(inputs)
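
A minimal sketch of what a registered backend could look like, assuming the quantizer exposes a calibrated amax tensor; the backend name int8_torch and the symmetric int8 scheme are illustrative, not part of this change:

import torch

def int8_torch_backend(inputs: torch.Tensor, quantizer) -> torch.Tensor:
    # Symmetric per-tensor int8 fake quantization driven by the quantizer's
    # calibrated amax (assumed attribute), via torch.fake_quantize_per_tensor_affine.
    scale = max(float(quantizer.amax.max()) / 127.0, 1e-12)  # guard against amax == 0
    return torch.fake_quantize_per_tensor_affine(inputs, scale, 0, -128, 127)

register_fake_quant_backend("int8_torch", int8_torch_backend)
set_fake_quant_backend("int8_torch")

Once activated, every TensorQuantizer whose _fake_quantize runs will take the dispatch branch above and route through int8_torch_backend instead of the built-in path.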