
Commit fda052f

Addressing comments

Signed-off-by: Riyad Islam <[email protected]>

1 parent e68721d · commit fda052f

File tree

2 files changed: +15, -12 lines

modelopt/torch/quantization/export_onnx.py

Lines changed: 10 additions & 6 deletions
@@ -107,12 +107,6 @@
 import onnx
 import torch
 from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
-
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
@@ -297,6 +291,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -400,6 +399,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask

+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 5 additions & 6 deletions
@@ -32,12 +32,6 @@
 from packaging.version import Version
 from torch import nn

-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
-else:
-    from torch.onnx._globals import GLOBALS
-
-
 from modelopt.torch.utils import standardize_constructor_args
 from modelopt.torch.utils.distributed import DistributedProcessGroup

@@ -885,6 +879,11 @@ def forward(self, inputs):
         Returns:
             outputs: A Tensor of type output_dtype
         """
+        if Version(torch.__version__) > Version("2.8.0"):
+            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+        else:
+            from torch.onnx._globals import GLOBALS
+
         if DTensor is not None and isinstance(inputs, DTensor):
            # TensorQuantizer only handles regular non-DTensor inputs
            device_mesh, placements = inputs.device_mesh, inputs.placements
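
The same deferral is applied to the GLOBALS import, now resolved inside TensorQuantizer.forward. The pre-release ordering that presumably motivates the "> 2.8.0" gate can be checked directly; this snippet is an illustration, not part of the commit:

from packaging.version import Version

nightly = Version("2.9.0a0")  # a torch 2.9 pre-release / nightly build
assert not (nightly >= Version("2.9.0"))  # old gate: a nightly fell through to the legacy import
assert nightly > Version("2.8.0")         # new gate: a nightly takes the relocated module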
