115 changes: 65 additions & 50 deletions modelopt/torch/quantization/export_onnx.py
@@ -103,13 +103,18 @@
"""Utility to export a quantized torch model to quantized ONNX."""

import contextlib
from typing import TYPE_CHECKING

import onnx
import torch
from torch.onnx import symbolic_helper
from torch.onnx import symbolic_helper as sym_help
from torch.onnx._internal import jit_utils
from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask

if TYPE_CHECKING:
    if hasattr(torch.onnx._internal, "jit_utils"):
        from torch.onnx._internal.jit_utils import GraphContext
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext

onnx_dtype_map = {
    "BFloat16": onnx.TensorProto.BFLOAT16,
@@ -125,7 +130,7 @@


def export_int8(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    amax: torch.Tensor,
    num_bits: int,
@@ -184,7 +189,7 @@ def export_int8(


def export_int4(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    amax: torch.Tensor,
    num_bits: int,
@@ -208,7 +213,7 @@ def export_int4(


def _fp8_quantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    scale_inv: float,
    trt_high_precision_dtype: str,
@@ -236,7 +241,7 @@ def _fp8_quantize(


def _fp8_dequantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    scale_inv: float,
    trt_high_precision_dtype: str,
@@ -263,7 +268,7 @@ def _fp8_dequantize(


def export_fp8(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    amax: float,
    trt_high_precision_dtype: str | None,
@@ -279,21 +284,29 @@ def export_fp8(


def scaled_dot_product_attention(
    g: jit_utils.GraphContext,
    query: torch._C.Value,
    key: torch._C.Value,
    value: torch._C.Value,
    attn_mask: torch._C.Value | None = None,
    g: "GraphContext",
    query: "torch._C.Value",
    key: "torch._C.Value",
    value: "torch._C.Value",
    attn_mask: "torch._C.Value | None" = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: torch._C.Value | None = None,
    scale: "torch._C.Value | None" = None,
    enable_gqa: bool = False,
):
    """Perform scaled dot product attention."""
    if hasattr(torch.onnx, "_type_utils"):
        from torch.onnx import _type_utils
    else:
        from torch.onnx._internal.torchscript_exporter import _type_utils
        from torch.onnx._type_utils import JitScalarType
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter import JitScalarType

    if hasattr(torch.onnx, "symbolic_opset14"):
        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
            _attention_scale,
            _causal_attention_mask,
        )

    assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
        "is_causal and attn_mask cannot be set at the same time"
@@ -327,22 +340,20 @@ def scaled_dot_product_attention(

    if symbolic_helper._is_none(attn_mask):
        mul_qk_add = mul_qk
    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
        # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    elif _type_utils.JitScalarType.from_value(attn_mask) in (
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.BFLOAT16,
    elif JitScalarType.from_value(attn_mask) in (
        JitScalarType.FLOAT,
        JitScalarType.HALF,
        JitScalarType.BFLOAT16,
    ):
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    else:
        raise ValueError(
            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
        )
        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")

    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

@@ -357,14 +368,14 @@


def export_fp8_mha(
    g: torch.onnx._internal.jit_utils.GraphContext,
    query: torch._C.Value,
    key: torch._C.Value,
    value: torch._C.Value,
    attn_mask: torch._C.Value | None = None,
    g: "GraphContext",
    query: "torch._C.Value",
    key: "torch._C.Value",
    value: "torch._C.Value",
    attn_mask: "torch._C.Value | None" = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: torch._C.Value | None = None,
    scale: "torch._C.Value | None" = None,
    q_quantized_scale: float = 1.0,
    k_quantized_scale: float = 1.0,
    v_quantized_scale: float = 1.0,
@@ -396,12 +407,18 @@ def export_fp8_mha(
            |
           Cast
    """
    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask

    if hasattr(torch.onnx, "_type_utils"):
        from torch.onnx import _type_utils
    else:
        from torch.onnx._internal.torchscript_exporter import _type_utils
        from torch.onnx._type_utils import JitScalarType
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter import JitScalarType

    if hasattr(torch.onnx, "symbolic_opset14"):
        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
            _attention_scale,
            _causal_attention_mask,
        )

    # Pass all arguments, including x, to the custom ONNX operator
    assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
@@ -452,22 +469,20 @@ def export_fp8_mha(

    if sym_help._is_none(attn_mask):
        mul_qk_add = mul_qk
    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
        # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    elif _type_utils.JitScalarType.from_value(attn_mask) in (
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.BFLOAT16,
    elif JitScalarType.from_value(attn_mask) in (
        JitScalarType.FLOAT,
        JitScalarType.HALF,
        JitScalarType.BFLOAT16,
    ):
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    else:
        raise ValueError(
            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
        )
        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")

    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

Expand Down Expand Up @@ -495,7 +510,7 @@ def export_fp8_mha(


def _fp4_dynamic_quantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    scale: float,
    trt_high_precision_dtype: str | None,
@@ -531,7 +546,7 @@ def _fp4_dynamic_quantize(


def _fp4_dequantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    scale: float | torch.Value,
    trt_high_precision_dtype: str | None,
@@ -546,7 +561,7 @@


def _fp4_dequantize_2(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    dyn_scale: torch.Value,
    block_size: int,
@@ -557,7 +572,7 @@


def _mxfp8_dynamic_quantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    block_size: int,
    axis: int = -1,
@@ -575,7 +590,7 @@


def _mxfp8_dequantize(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    scale: torch.Value,
    block_size: int,
@@ -593,7 +608,7 @@


def export_mxfp8(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Tensor,
    onnx_quantizer_type: str,
    block_size: int,
@@ -611,7 +626,7 @@


def export_fp4(
    g: torch.onnx._internal.jit_utils.GraphContext,
    g: "GraphContext",
    inputs: torch.Value,
    block_size: int,
    amax: torch.Value,
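Note on the recurring pattern in this file: torch 2.9 relocated the TorchScript exporter internals (jit_utils, _type_utils, symbolic_opset14) under torch.onnx._internal.torchscript_exporter, so every import above is gated on the old location with hasattr, and GraphContext becomes a type-only import referenced through string annotations. A minimal standalone sketch of the same pattern, assuming only the two module layouts seen in this diff (resolve_jit_scalar_type and my_symbolic are illustrative names, not modelopt APIs):

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Type-only import: the annotations below stay strings, so nothing is
    # imported from either layout at runtime.
    if hasattr(torch.onnx._internal, "jit_utils"):
        from torch.onnx._internal.jit_utils import GraphContext
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext


def resolve_jit_scalar_type():
    # Runtime imports are deferred into the function body and gated with
    # hasattr, mirroring the shims inside the symbolic functions above.
    if hasattr(torch.onnx, "_type_utils"):
        from torch.onnx._type_utils import JitScalarType
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter import JitScalarType
    return JitScalarType


def my_symbolic(g: "GraphContext", x: "torch._C.Value") -> "torch._C.Value":
    # Resolves on either torch layout without a module-level dependency.
    jit_scalar_type = resolve_jit_scalar_type()
    assert jit_scalar_type.from_value(x) is not None
    return g.op("Identity", x)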
6 changes: 3 additions & 3 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -548,7 +548,7 @@ def _get_amax(self, inputs):

    def _validate_amax(self, amax):
        # Dynamic control flow is not supported by torch dynamo
        if not is_torch_export_mode() and not torch._dynamo.is_compiling():
        if not is_torch_export_mode() and not torch.compiler.is_compiling():
            assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
                f"Got invalid amax: {amax}"
            )
@@ -880,7 +880,7 @@ def forward(self, inputs):
"""
if hasattr(torch.onnx, "_globals"):
from torch.onnx._globals import GLOBALS
else:
else: # torch >= 2.9
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS

if DTensor is not None and isinstance(inputs, DTensor):
@@ -914,7 +914,7 @@ def forward(self, inputs):

        if (
            not is_torch_export_mode()
            and not torch._dynamo.is_compiling()
            and not torch.compiler.is_compiling()
            and GLOBALS.in_onnx_export
        ):
            # GLOBALS could break TorchDynamo for some Pytorch versions (i.e., 2.3.0)
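The tensor_quantizer.py hunks swap the private torch._dynamo.is_compiling() for the public torch.compiler.is_compiling(), which reports True while a graph is being traced by torch.compile (and, on recent releases, by torch.export). A short sketch of why the guard exists, assuming a torch build with torch.compile available (validate_amax and fake_quant are illustrative stand-ins, not the modelopt methods):

import torch


def validate_amax(amax: torch.Tensor) -> None:
    # A data-dependent assert is dynamic control flow, which the compiler
    # cannot trace, so validation runs only outside of compilation.
    if not torch.compiler.is_compiling():
        assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
            f"Got invalid amax: {amax}"
        )


@torch.compile
def fake_quant(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    validate_amax(amax)  # traced as a no-op here; still enforced in eager code
    return torch.clamp(x, -amax, amax)


print(fake_quant(torch.randn(4), torch.tensor(1.0)))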
23 changes: 13 additions & 10 deletions modelopt/torch/quantization/plugins/diffusers.py
@@ -15,10 +15,10 @@

"""Support quantization of diffusers layers."""

import functools
from collections.abc import Callable, Iterator
from functools import partial
from types import ModuleType
from typing import TYPE_CHECKING

import onnx
import torch
Expand All @@ -27,7 +27,12 @@
from torch.autograd import Function
from torch.nn import functional as F
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration

if TYPE_CHECKING:
    if hasattr(torch.onnx._internal, "jit_utils"):
        from torch.onnx._internal.jit_utils import GraphContext
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext

from ..export_onnx import export_fp8_mha
from ..nn import (
@@ -40,8 +45,6 @@
)
from .custom import _QuantFunctionalMixin

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)

onnx_dtype_map = {
    "BFloat16": onnx.TensorProto.BFLOAT16,
    "Float": onnx.TensorProto.FLOAT,
@@ -205,14 +208,14 @@ def forward(
    @staticmethod
    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
    def symbolic(
        g: jit_utils.GraphContext,
        query: torch._C.Value,
        key: torch._C.Value,
        value: torch._C.Value,
        attn_mask: torch._C.Value | None = None,
        g: "GraphContext",
        query: "torch._C.Value",
        key: "torch._C.Value",
        value: "torch._C.Value",
        attn_mask: "torch._C.Value | None" = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: torch._C.Value | None = None,
        scale: "torch._C.Value | None" = None,
        q_quantized_scale: float = 1.0,
        k_quantized_scale: float = 1.0,
        v_quantized_scale: float = 1.0,
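Both symbolic functions in this PR lower a Boolean attention mask to an additive float mask before the Softmax. In eager terms, the Where node they emit computes the following (an illustrative sketch; the tensors are made up):

import torch

# True = attend, False = masked out, as in the Boolean branch above.
attn_mask = torch.tensor([[True, True, False]])

# ONNX Where(attn_mask, 0.0, -inf), then Add onto the scaled QK logits.
additive = torch.where(
    attn_mask,
    torch.tensor(0.0),
    torch.tensor(-float("inf")),
)

scores = torch.tensor([[0.5, 0.2, 0.9]]) + additive
print(torch.softmax(scores, dim=-1))  # the masked key gets probability 0.0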