Skip to content

Commit 682bf6d

Browse files
authored
import fix for torch 2.9 (#315)
Signed-off-by: Riyad Islam <[email protected]>
1 parent 3524732 commit 682bf6d

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

modelopt/torch/quantization/export_onnx.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106

107107
import onnx
108108
import torch
109-
from torch.onnx import _type_utils, symbolic_helper
109+
from torch.onnx import symbolic_helper
110110
from torch.onnx import symbolic_helper as sym_help
111111
from torch.onnx._internal import jit_utils
112112
from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
@@ -290,6 +290,11 @@ def scaled_dot_product_attention(
290290
enable_gqa: bool = False,
291291
):
292292
"""Perform scaled dot product attention."""
293+
if hasattr(torch.onnx, "_type_utils"):
294+
from torch.onnx import _type_utils
295+
else:
296+
from torch.onnx._internal.torchscript_exporter import _type_utils
297+
293298
assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
294299
"is_causal and attn_mask cannot be set at the same time"
295300
)
@@ -393,6 +398,11 @@ def export_fp8_mha(
393398
"""
394399
from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
395400

401+
if hasattr(torch.onnx, "_type_utils"):
402+
from torch.onnx import _type_utils
403+
else:
404+
from torch.onnx._internal.torchscript_exporter import _type_utils
405+
396406
# Pass all arguments, including x, to the custom ONNX operator
397407
assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
398408
"is_causal and attn_mask cannot be set at the same time"

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
import torch.nn.functional as F
3232
from torch import nn
33-
from torch.onnx._globals import GLOBALS
3433

3534
from modelopt.torch.utils import standardize_constructor_args
3635
from modelopt.torch.utils.distributed import DistributedProcessGroup
@@ -879,6 +878,11 @@ def forward(self, inputs):
879878
Returns:
880879
outputs: A Tensor of type output_dtype
881880
"""
881+
if hasattr(torch.onnx, "_globals"):
882+
from torch.onnx._globals import GLOBALS
883+
else:
884+
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
885+
882886
if DTensor is not None and isinstance(inputs, DTensor):
883887
# TensorQuantizer only handles regular non-DTensor inputs
884888
device_mesh, placements = inputs.device_mesh, inputs.placements

0 commit comments

Comments (0)