
Commit 5f54471

Addressing comments
Signed-off-by: Riyad Islam <[email protected]>
1 parent e68721d commit 5f54471

File tree

2 files changed, +11 −7 lines


modelopt/torch/quantization/export_onnx.py

Lines changed: 10 additions & 6 deletions
@@ -107,12 +107,6 @@
 import onnx
 import torch
 from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
-
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
@@ -297,6 +291,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -400,6 +399,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 from packaging.version import Version
 from torch import nn
 
-if Version(torch.__version__) >= Version("2.9.0"):
+if Version(torch.__version__) > Version("2.8.0"):
    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 else:
    from torch.onnx._globals import GLOBALS
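
One likely reason for loosening the check from `>= Version("2.9.0")` to `> Version("2.8.0")` (the commit message does not state it, so this is an inference): with `packaging`, a 2.9 pre-release sorts below the final 2.9.0 but above 2.8.0, so the relaxed comparison also matches nightly builds. A small illustration with a made-up nightly version string:

```python
from packaging.version import Version

nightly = Version("2.9.0.dev20250101")  # hypothetical torch nightly version

# Pre-releases compare below the corresponding final release...
print(nightly >= Version("2.9.0"))  # False: the old check would skip nightlies
# ...but above the previous release, so the relaxed check picks them up.
print(nightly > Version("2.8.0"))   # True
```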

0 commit comments
