diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py
index fe9bd927b..e8a38b162 100644
--- a/modelopt/torch/quantization/export_onnx.py
+++ b/modelopt/torch/quantization/export_onnx.py
@@ -103,13 +103,18 @@
 """Utility to export a quantized torch model to quantized ONNX."""
 
 import contextlib
+from typing import TYPE_CHECKING
 
 import onnx
 import torch
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
-from torch.onnx._internal import jit_utils
-from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+
+if TYPE_CHECKING:
+    if hasattr(torch.onnx._internal, "jit_utils"):
+        from torch.onnx._internal.jit_utils import GraphContext
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext
 
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
@@ -125,7 +130,7 @@
 
 
 def export_int8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -184,7 +189,7 @@ def export_int8(
 
 
 def export_int4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: torch.Tensor,
     num_bits: int,
@@ -208,7 +213,7 @@ def export_int4(
 
 
 def _fp8_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -236,7 +241,7 @@ def _fp8_quantize(
 
 
 def _fp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale_inv: float,
     trt_high_precision_dtype: str,
@@ -263,7 +268,7 @@ def _fp8_dequantize(
 
 
 def export_fp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     amax: float,
     trt_high_precision_dtype: str | None,
@@ -279,21 +284,29 @@ def export_fp8(
 
 
 def scaled_dot_product_attention(
-    g: jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
    enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -327,22 +340,20 @@ def scaled_dot_product_attention(
 
     if symbolic_helper._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
 
@@ -357,14 +368,14 @@
 
 
 def export_fp8_mha(
-    g: torch.onnx._internal.jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
+    g: "GraphContext",
+    query: "torch._C.Value",
+    key: "torch._C.Value",
+    value: "torch._C.Value",
+    attn_mask: "torch._C.Value | None" = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale: torch._C.Value | None = None,
+    scale: "torch._C.Value | None" = None,
     q_quantized_scale: float = 1.0,
     k_quantized_scale: float = 1.0,
     v_quantized_scale: float = 1.0,
@@ -396,12 +407,18 @@
 
     | Cast
     """
-    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
-
     if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+        from torch.onnx._type_utils import JitScalarType
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType
+
+    if hasattr(torch.onnx, "symbolic_opset14"):
+        from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
+            _attention_scale,
+            _causal_attention_mask,
+        )
 
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
@@ -452,22 +469,20 @@
 
     if sym_help._is_none(attn_mask):
         mul_qk_add = mul_qk
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+    elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
         # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
         const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
+    elif JitScalarType.from_value(attn_mask) in (
+        JitScalarType.FLOAT,
+        JitScalarType.HALF,
+        JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
     else:
-        raise ValueError(
-            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
-        )
+        raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
     attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
 
@@ -495,7 +510,7 @@
 
 
 def _fp4_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float,
     trt_high_precision_dtype: str | None,
@@ -531,7 +546,7 @@ def _fp4_dynamic_quantize(
 
 
 def _fp4_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: float | torch.Value,
     trt_high_precision_dtype: str | None,
@@ -546,7 +561,7 @@ def _fp4_dequantize(
 
 
 def _fp4_dequantize_2(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     dyn_scale: torch.Value,
     block_size: int,
@@ -557,7 +572,7 @@ def _fp4_dequantize_2(
 
 
 def _mxfp8_dynamic_quantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     axis: int = -1,
@@ -575,7 +590,7 @@ def _mxfp8_dynamic_quantize(
 
 
 def _mxfp8_dequantize(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     scale: torch.Value,
     block_size: int,
@@ -593,7 +608,7 @@ def _mxfp8_dequantize(
 
 
 def export_mxfp8(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Tensor,
     onnx_quantizer_type: str,
     block_size: int,
@@ -611,7 +626,7 @@ def export_mxfp8(
 
 
 def export_fp4(
-    g: torch.onnx._internal.jit_utils.GraphContext,
+    g: "GraphContext",
     inputs: torch.Value,
     block_size: int,
     amax: torch.Value,
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 0635b7c9b..6e431dce9 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -548,7 +548,7 @@ def _get_amax(self, inputs):
 
     def _validate_amax(self, amax):
         # Dynamic control flow is not supported by torch dynamo
-        if not is_torch_export_mode() and not torch._dynamo.is_compiling():
+        if not is_torch_export_mode() and not torch.compiler.is_compiling():
             assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
                 f"Got invalid amax: {amax}"
             )
@@ -880,7 +880,7 @@ def forward(self, inputs):
         """
         if hasattr(torch.onnx, "_globals"):
             from torch.onnx._globals import GLOBALS
-        else:
+        else:  # torch >= 2.9
             from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 
         if DTensor is not None and isinstance(inputs, DTensor):
@@ -914,7 +914,7 @@ def forward(self, inputs):
 
         if (
             not is_torch_export_mode()
-            and not torch._dynamo.is_compiling()
+            and not torch.compiler.is_compiling()
             and GLOBALS.in_onnx_export
         ):
             # GLOBALS could break TorchDynamo for some Pytorch versions (i.e., 2.3.0)
diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py
index 5f1ab5db1..7c018e1bb 100644
--- a/modelopt/torch/quantization/plugins/diffusers.py
+++ b/modelopt/torch/quantization/plugins/diffusers.py
@@ -15,10 +15,10 @@
 
 """Support quantization of diffusers layers."""
 
-import functools
 from collections.abc import Callable, Iterator
 from functools import partial
 from types import ModuleType
+from typing import TYPE_CHECKING
 
 import onnx
 import torch
@@ -27,7 +27,12 @@
 from torch.autograd import Function
 from torch.nn import functional as F
 from torch.onnx import symbolic_helper
-from torch.onnx._internal import jit_utils, registration
+
+if TYPE_CHECKING:
+    if hasattr(torch.onnx._internal, "jit_utils"):
+        from torch.onnx._internal.jit_utils import GraphContext
+    else:  # torch >= 2.9
+        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext
 
 from ..export_onnx import export_fp8_mha
 from ..nn import (
@@ -40,8 +45,6 @@
 )
 from .custom import _QuantFunctionalMixin
 
-_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
-
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
@@ -205,14 +208,14 @@ def forward(
     @staticmethod
     @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
     def symbolic(
-        g: jit_utils.GraphContext,
-        query: torch._C.Value,
-        key: torch._C.Value,
-        value: torch._C.Value,
-        attn_mask: torch._C.Value | None = None,
+        g: "GraphContext",
+        query: "torch._C.Value",
+        key: "torch._C.Value",
+        value: "torch._C.Value",
+        attn_mask: "torch._C.Value | None" = None,
         dropout_p: float = 0.0,
         is_causal: bool = False,
-        scale: torch._C.Value | None = None,
+        scale: "torch._C.Value | None" = None,
         q_quantized_scale: float = 1.0,
         k_quantized_scale: float = 1.0,
         v_quantized_scale: float = 1.0,
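
A note on the typing pattern used throughout this patch: GraphContext lives in torch.onnx._internal.jit_utils before torch 2.9 and in torch.onnx._internal.torchscript_exporter.jit_utils afterwards, so the import is confined to a TYPE_CHECKING block and every annotation that mentions it becomes a string. A minimal standalone sketch of why this keeps a module importable on both layouts (identity_symbolic is illustrative, not part of the patch):

# Sketch: version-proof typing for TorchScript ONNX symbolic functions.
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Evaluated only by static type checkers; at runtime this block is
    # skipped, so a module missing on either torch version cannot break
    # the import.
    if hasattr(torch.onnx._internal, "jit_utils"):
        from torch.onnx._internal.jit_utils import GraphContext
    else:  # torch >= 2.9
        from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext


def identity_symbolic(g: "GraphContext", x: "torch._C.Value") -> "torch._C.Value":
    # String annotations are not evaluated at definition time, so they stay
    # valid names whichever branch the type checker resolved.
    return g.op("Identity", x)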
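
The helpers needed at runtime (JitScalarType, _attention_scale, _causal_attention_mask, GLOBALS) use the runtime counterpart of the same idea: the version probe and import happen inside the function, at call time. A sketch under the same assumption that hasattr(torch.onnx, "_type_utils") distinguishes the pre-2.9 layout; scalar_type_name is a hypothetical helper, not from the patch:

import torch


def scalar_type_name(value) -> str:
    # `value` is expected to be a torch._C.Value from a TorchScript graph.
    if hasattr(torch.onnx, "_type_utils"):  # torch < 2.9
        from torch.onnx._type_utils import JitScalarType
    else:  # torch >= 2.9: exporter internals moved under torchscript_exporter
        from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType

    # JitScalarType is an IntEnum, e.g. JitScalarType.BOOL -> "BOOL".
    return JitScalarType.from_value(value).name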
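
Separately, the tensor_quantizer.py hunks replace the private torch._dynamo.is_compiling() with the public torch.compiler.is_compiling(), which returns True while torch.compile or torch.export is tracing. A simplified stand-in for the guarded check in _validate_amax, assuming only that amax must be non-negative and finite:

import torch


def validate_amax(amax: torch.Tensor) -> None:
    # A data-dependent assert is dynamic control flow, which Dynamo cannot
    # trace; run the check only in eager mode.
    if not torch.compiler.is_compiling():
        assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), (
            f"Got invalid amax: {amax}"
        )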