diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py
index 984dba0b..fe9bd927 100644
--- a/modelopt/torch/quantization/export_onnx.py
+++ b/modelopt/torch/quantization/export_onnx.py
@@ -106,7 +106,7 @@
 
 import onnx
 import torch
-from torch.onnx import _type_utils, symbolic_helper
+from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
 from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
@@ -290,6 +290,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -393,6 +398,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 9846f355..0635b7c9 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -30,7 +30,6 @@
 
 import torch.nn.functional as F
 from torch import nn
-from torch.onnx._globals import GLOBALS
 
 from modelopt.torch.utils import standardize_constructor_args
 from modelopt.torch.utils.distributed import DistributedProcessGroup
@@ -879,6 +878,11 @@ def forward(self, inputs):
         Returns:
             outputs: A Tensor of type output_dtype
         """
+        if hasattr(torch.onnx, "_globals"):
+            from torch.onnx._globals import GLOBALS
+        else:
+            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+
         if DTensor is not None and isinstance(inputs, DTensor):
            # TensorQuantizer only handles regular non-DTensor inputs
            device_mesh, placements = inputs.device_mesh, inputs.placements
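
Note on the pattern: all of the hunks above apply the same compatibility shim. The fallback paths in the patch indicate that newer PyTorch builds relocate the private torch.onnx._type_utils and torch.onnx._globals modules under torch.onnx._internal.torchscript_exporter, so each import site probes the old location first and falls back to the new one. Below is a minimal standalone sketch of that pattern, assuming only that one of the two module paths exists in the installed PyTorch; the helper name _load_onnx_globals is illustrative and not part of the patch.

# Sketch of the version-compatibility fallback; _load_onnx_globals is an
# illustrative name, the patch itself inlines this check at each call site.
import torch


def _load_onnx_globals():
    """Return torch.onnx's private GLOBALS object across PyTorch versions."""
    if hasattr(torch.onnx, "_globals"):
        # Old layout: the private module is exposed directly under torch.onnx.
        from torch.onnx._globals import GLOBALS
    else:
        # New layout: the module lives under the TorchScript exporter internals.
        from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
    return GLOBALS


GLOBALS = _load_onnx_globals()

Performing the check lazily at each call site, as the patch does, keeps the module-level imports free of version branches and defers importing the private modules until the functions that need them actually run.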