1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -20,6 +20,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add support for PyTorch Geometric quantization.
- Add per tensor and per channel MSE calibrator support.
- Added support for PTQ/QAT checkpoint export and loading for running fakequant evaluation in vLLM. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
- Add support for Transformer Engine quantization for Megatron Core models.

**Documentation**

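For context on the Transformer Engine entry above, here is a minimal sketch of quantizing a Megatron Core model built on TE layers with ModelOpt. The config choice, calibration loop, and model construction are illustrative assumptions; `mtq.quantize` and `mtq.print_quant_summary` are existing ModelOpt APIs.

import modelopt.torch.quantization as mtq

# Assumed setup: `model` is a Megatron Core model whose linear layers are TE modules
# (e.g. TEColumnParallelLinear / TERowParallelLinear) and `calib_dataloader` yields batches.
def forward_loop(model):
    for batch in calib_dataloader:  # hypothetical calibration data
        model(batch)

# FP8_DEFAULT_CFG is one of the stock configs; any supported quantization config works here.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
mtq.print_quant_summary(model)  # now prints as one atomic block, per the model_quant.py change below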
2 changes: 2 additions & 0 deletions modelopt/torch/quantization/model_quant.py
@@ -30,6 +30,7 @@
from modelopt.torch.opt.utils import forward_with_reshard
from modelopt.torch.quantization.config import QuantizeConfig
from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
from modelopt.torch.utils import atomic_print

from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe
from .config import QuantizeAlgoCfgType
@@ -506,6 +507,7 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable):
set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": True})


@atomic_print
def print_quant_summary(model: nn.Module):
"""Print summary of all quantizer modules in the model."""
count = 0
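The new @atomic_print decorator is imported from modelopt.torch.utils, but its implementation is not part of this diff. Presumably it buffers whatever the wrapped function prints and emits it as a single uninterrupted block, so per-rank summaries do not interleave in distributed jobs. A rough sketch of that assumed behavior (the real atomic_print may differ):

import functools
import io
import sys
from contextlib import redirect_stdout


def atomic_print_sketch(func):
    """Hypothetical stand-in for modelopt.torch.utils.atomic_print."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        buf = io.StringIO()
        with redirect_stdout(buf):
            result = func(*args, **kwargs)
        sys.stdout.write(buf.getvalue())  # emit everything in one write
        sys.stdout.flush()
        return result

    return wrapper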
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/plugins/custom.py
@@ -59,7 +59,7 @@ class _QuantFunctionalMixin(QuantModule):
def functionals_to_replace(self) -> Iterator[tuple[ModuleType, str, Callable]]:
return (
(package, func_name, quantized_func)
for package, func_name, quantized_func in self._functionals_to_replace
for package, func_name, quantized_func in getattr(self, "_functionals_to_replace", [])
if hasattr(package, func_name)
)

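With the getattr fallback above, a quantized module that never defines _functionals_to_replace simply contributes no functional replacements instead of raising AttributeError. A small illustration with a made-up class (not from the codebase):

class _NoFunctionals:
    @property
    def functionals_to_replace(self):
        # Mirrors the mixin: fall back to an empty list when the attribute is absent.
        return (entry for entry in getattr(self, "_functionals_to_replace", []))


assert list(_NoFunctionals().functionals_to_replace) == []  # no AttributeError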
22 changes: 21 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -47,11 +47,14 @@
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelGroupedLinear,
TERowParallelLinear,
)

from .transformer_engine import _QuantTEGroupedLinear
from .transformer_engine import _QuantTEGroupedLinear, _QuantTELayerNormLinear, _QuantTELinear

HAS_TE = True
except ImportError:
@@ -549,6 +552,23 @@ def sync_moe_local_experts_amax(self):


if HAS_TE:

@QuantModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"})
class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear):
pass

@QuantModuleRegistry.register({TEColumnParallelLinear: "te_mcore_ColumnParallelLinear"})
class _QuantTEMCoreColumnParallelLinear(_QuantTELinear, _MegatronColumnParallelLinear):
pass

@QuantModuleRegistry.register(
{TELayerNormColumnParallelLinear: "te_mcore_LayerNormColumnParallelLinear"}
)
class _QuantTELayerNormColumnParallelLinear(
_QuantTELayerNormLinear, _MegatronColumnParallelLinear
):
pass

# Quantized subclasses to support TEGroupedMLP quantization
class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
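The three new registrations map the TE-backed Megatron Core linear classes (TERowParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear) to quantized subclasses that combine the TE patching logic with the existing Megatron parallel-linear handling. As a hedged sketch of what registration enables (the actual QuantModuleRegistry conversion logic lives elsewhere in ModelOpt and may differ), conversion roughly swaps each registered module's class for its quantized counterpart:

def convert_registered_modules(model, registry):
    """Hypothetical conversion pass; `registry` is assumed to map original to quantized classes."""
    for module in model.modules():
        quant_cls = registry.get(type(module))  # e.g. TERowParallelLinear -> _QuantTEMCoreRowParallelLinear
        if quant_cls is not None:
            module.__class__ = quant_cls  # quantized forward/_setup now apply
            module._setup()  # create input/weight/output quantizers
    return model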
212 changes: 190 additions & 22 deletions modelopt/torch/quantization/plugins/transformer_engine.py
@@ -15,44 +15,87 @@

"""Support quantization for Transformer Engine layers."""

import warnings

import torch
import transformer_engine as te
import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
import transformer_engine.pytorch.module.layernorm_linear as te_layernorm_linear
import transformer_engine.pytorch.module.linear as te_linear
from packaging.version import Version

from modelopt.torch.quantization.utils import replace_function

from ..nn import QuantModuleRegistry
from .custom import _ParallelLinear

_TE_VERSION = Version(te.__version__)


def _assert_te_fp8_enabled():
"""Check if Transformer Engine FP8 autocast is enabled and raise error if so."""
try:
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

if FP8GlobalStateManager.is_fp8_enabled():
raise RuntimeError(
"Transformer Engine FP8 training (fp8_autocast) is enabled, which conflicts with "
"ModelOpt quantization. Please disable TE FP8 autocast when using ModelOpt "
"quantization, or use ModelOpt's FP8 quantization instead."
)
except ImportError:
pass # Older TE versions may not have this API
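With this guard, calling a ModelOpt-quantized TE layer inside TE's own FP8 autocast fails fast instead of silently combining two FP8 paths. Illustration only; quantized_model and batch are placeholders:

import transformer_engine.pytorch as te_pt

with te_pt.fp8_autocast(enabled=True):
    out = quantized_model(batch)  # raises RuntimeError via _assert_te_fp8_enabled()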


@QuantModuleRegistry.register({te.pytorch.Linear: "te_Linear"})
class _QuantTELinear(_ParallelLinear):
_functionals_to_replace = [
(
te_linear._Linear,
"apply" if torch.is_grad_enabled() else "forward",
),
]
@property
def _functionals_to_replace(self):
return (
[(te_linear._Linear, "apply")]
if torch.is_grad_enabled()
else [(te_linear._Linear, "forward")]
)

@_functionals_to_replace.setter
def _functionals_to_replace(self, value):
self._functionals_to_replace = value

def _setup(self):
super()._setup()
if getattr(self, "fuse_wgrad_accumulation", False):
warnings.warn(
"fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
"Setting fuse_wgrad_accumulation to False."
)
self.fuse_wgrad_accumulation = False

@staticmethod
def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
"""Quantized version specifically for TE with weight first, then input."""
if te.__version__ >= "2.0":
weight, inputs = args[0], args[1]
remaining_args = args[2:]
_assert_te_fp8_enabled()
if Version("2.0") <= _TE_VERSION:
idx = 1 if func_name == "_forward" else 0
weight, inputs = args[idx], args[idx + 1]
remaining_args = args[idx + 2 :]
weight = self.weight_quantizer(weight)
inputs = self.input_quantizer(inputs)
new_args = (weight, inputs, *remaining_args)
new_args = (args[0], *new_args) if func_name == "_forward" else new_args
output = getattr(package, func_name)(
self.weight_quantizer(weight),
self.input_quantizer(inputs),
*remaining_args,
*new_args,
**kwargs,
)
else:
weight, weight_fp8, inputs = args[0], args[1], args[2]
remaining_args = args[3:]
idx = 1 if func_name == "_forward" else 0
weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
remaining_args = args[idx + 3 :]
weight = self.weight_quantizer(weight)
inputs = self.input_quantizer(inputs)
new_args = (weight, weight_fp8, inputs, *remaining_args)
new_args = (args[0], *new_args) if func_name == "_forward" else new_args
output = getattr(package, func_name)(
self.weight_quantizer(weight),
weight_fp8,
self.input_quantizer(inputs),
*remaining_args,
*new_args,
**kwargs,
)
return self.output_quantizer(output)
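The idx handling above appears to account for the patched entry point: when func_name is "_forward", args[0] is presumably the autograd ctx (or another leading argument that must pass through untouched), so weight and input start one slot later and that leading argument is re-prepended to new_args. This reading is inferred from the code rather than stated in the PR; a stripped-down illustration:

def _unpack(func_name, *args):
    # Toy version of the argument shuffling: skip the leading argument when patching _forward.
    idx = 1 if func_name == "_forward" else 0
    weight, inputs, *rest = args[idx:]
    new_args = (weight, inputs, *rest)
    return (args[0], *new_args) if func_name == "_forward" else new_args


assert _unpack("apply", "W", "X") == ("W", "X")
assert _unpack("_forward", "CTX", "W", "X") == ("CTX", "W", "X")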
@@ -64,10 +107,17 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
# Register the public te.pytorch.GroupedLinear class
@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
class _QuantTEGroupedLinear(_ParallelLinear):
_functionals_to_replace = [
(te_grouped_linear._GroupedLinear, "forward"),
(te_grouped_linear._GroupedLinear, "apply"),
]
@property
def _functionals_to_replace(self):
return (
[(te_grouped_linear._GroupedLinear, "apply")]
if torch.is_grad_enabled()
else [(te_grouped_linear._GroupedLinear, "forward")]
)

@_functionals_to_replace.setter
def _functionals_to_replace(self, value):
self._functionals_to_replace = value

def _setup(self):
# GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
@@ -93,6 +143,7 @@ def modelopt_post_restore(self, prefix: str = ""):

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
_assert_te_fp8_enabled()
idx = 1 if func_name == "_forward" else 0
inp = args[idx]
num_gemms = len(args[idx + 1])
@@ -116,3 +167,120 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):

# Override the quantized linear function
_quantized_linear_fn = te_grouped_quantized_linear_fn


class _QuantLayerNormLinearFunc(torch.autograd.Function):
"""Patched version of _LayerNormLinear to quantize the input to the GEMM operation."""

@staticmethod
def _get_original_gemm():
if Version("2.0") <= _TE_VERSION:
return te_layernorm_linear.general_gemm
else:
return te_layernorm_linear.tex.gemm

@staticmethod
def _gemm_replace_args():
if Version("2.0") <= _TE_VERSION:
return (te_layernorm_linear, "general_gemm")
else:
return (te_layernorm_linear.tex, "gemm")

@staticmethod
def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers

qweight = weight_quantizer(weight)
qweight.requires_grad = weight.requires_grad
if ctx is not None:
# We need to recompute the quantized input for the backward pass, so we save the input_quantizer
ctx.modelopt_input_quantizer = input_quantizer

original_gemm = _QuantLayerNormLinearFunc._get_original_gemm()

def _patched_general_gemm(weight, input, *gemm_args, **gemm_kwargs):
qinput = input_quantizer(input)
return original_gemm(weight, qinput, *gemm_args, **gemm_kwargs)

with replace_function(
*_QuantLayerNormLinearFunc._gemm_replace_args(),
_patched_general_gemm, # type: ignore[call-arg]
):
outputs = te_layernorm_linear._og_LayerNormLinear.forward(
ctx, inp, ln_weight, ln_bias, qweight, *args, **kwargs
)
return outputs

# TODO: Support non-pass-through backward behavior for activation quantization
@staticmethod
def backward(ctx, *grad_outputs):
"""Backward pass for _QuantLayerNormLinearFunc functional.

The backward pass input and weight gradient estimation uses straight through estimator (STE).
We should add support for advanced gradient estimation techniques like STE with clipping.
However this is a low priority item.
"""
gemm_call_counter = {"count": 0}

original_gemm = _QuantLayerNormLinearFunc._get_original_gemm()

def _patched_general_gemm(a, b, *gemm_args, **gemm_kwargs):
# The first time, gemm is used for dgrad calculation
# dgrad GEMM; dx = dy * qw; Called as gemm(qw, dy, ...)
if gemm_call_counter["count"] == 0:
gemm_call_counter["count"] += 1
return original_gemm(a, b, *gemm_args, **gemm_kwargs)

# The second time, gemm is used for wgrad calculation
# wgrad GEMM; dqw = dy^T * x; Called as gemm(x, dy, ..);

            # Per the chain rule, x should be the quantized input (qinput) for the backward
            # pass, but gemm is called with the unquantized input (a). So first recompute the
            # quantized input and then call the gemm.
qinput = ctx.modelopt_input_quantizer(a)
return original_gemm(qinput, b, *gemm_args, **gemm_kwargs)

with replace_function(
*_QuantLayerNormLinearFunc._gemm_replace_args(),
_patched_general_gemm, # type: ignore[call-arg]
):
# During backward, the patch does not exist; autograd will automatically use
# _QuantLayerNormLinearFunc.backward
outputs = te_layernorm_linear._LayerNormLinear.backward(ctx, *grad_outputs)

delattr(ctx, "modelopt_input_quantizer")
return outputs
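The backward pass above relies on the straight-through estimator, as its docstring notes: gradients flow through the quantizers as if quantization were the identity. A minimal, self-contained STE illustration (a toy rounding quantizer, not the ModelOpt implementation):

import torch


class _STERound(torch.autograd.Function):
    """Toy fake-quant: round in the forward pass, pass the gradient straight through."""

    @staticmethod
    def forward(ctx, x, scale):
        return torch.round(x / scale) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # d(quant(x))/dx is approximated as 1; no gradient w.r.t. the scale here.
        return grad_output, None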


@QuantModuleRegistry.register({te.pytorch.LayerNormLinear: "te_LayerNormLinear"})
class _QuantTELayerNormLinear(_ParallelLinear):
_functionals_to_replace = []

def _setup(self):
super()._setup()
if getattr(self, "fuse_wgrad_accumulation", False):
warnings.warn(
"fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
"Setting fuse_wgrad_accumulation to False."
)
self.fuse_wgrad_accumulation = False

def forward(self, *args, **kwargs):
"""Call ModelOpt patch for _LayerNormLinear functional."""
_assert_te_fp8_enabled()
        # This is multi-process safe (e.g. in torch distributed jobs) but not multi-thread safe.
_QuantLayerNormLinearFunc.modelopt_quantizers = (
self.input_quantizer,
self.weight_quantizer,
)
with replace_function(
te_layernorm_linear,
"_LayerNormLinear",
_QuantLayerNormLinearFunc,
"_og_LayerNormLinear",
):
outputs = super().forward(*args, **kwargs)
delattr(_QuantLayerNormLinearFunc, "modelopt_quantizers")
if isinstance(outputs, tuple):
return (self.output_quantizer(outputs[0]), *outputs[1:])
return self.output_quantizer(outputs)
8 changes: 5 additions & 3 deletions modelopt/torch/quantization/utils.py
@@ -337,14 +337,16 @@ def disable_lora_quantizers_in_config(config, layers):


@contextmanager
def replace_function(package, name, new_func):
def replace_function(package, name, new_func, og_func_cache_name=None):
"""Replace a function with a new one within a context."""
if og_func_cache_name is None:
og_func_cache_name = "_" + name
old_func = getattr(package, name)
setattr(package, name, new_func)
setattr(package, "_" + name, old_func)
setattr(package, og_func_cache_name, old_func)
yield
setattr(package, name, old_func)
delattr(package, "_" + name)
delattr(package, og_func_cache_name)


@contextmanager
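Based on the signature shown above, replace_function now accepts an optional og_func_cache_name so the original callable is stashed under a caller-chosen attribute while patched (the default stays "_" + name). A usage sketch with standard-library stand-ins:

import math

from modelopt.torch.quantization.utils import replace_function


def fake_sqrt(x):
    return -1.0


with replace_function(math, "sqrt", fake_sqrt, og_func_cache_name="_og_sqrt"):
    assert math.sqrt(4.0) == -1.0  # patched
    assert math._og_sqrt(4.0) == 2.0  # original reachable under the custom cache name
assert math.sqrt(4.0) == 2.0  # restored on exit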