
Commit 009865b

Addressing comments
Signed-off-by: Riyad Islam <[email protected]>
1 parent e68721d commit 009865b

File tree

2 files changed (+15 −13 lines)


modelopt/torch/quantization/export_onnx.py

Lines changed: 10 additions & 7 deletions
@@ -106,13 +106,6 @@
 
 import onnx
 import torch
-from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
-
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils

@@ -297,6 +290,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )

@@ -400,6 +398,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 5 additions & 6 deletions
@@ -32,12 +32,6 @@
 from packaging.version import Version
 from torch import nn
 
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
-else:
-    from torch.onnx._globals import GLOBALS
-
-
 from modelopt.torch.utils import standardize_constructor_args
 from modelopt.torch.utils.distributed import DistributedProcessGroup
 

@@ -885,6 +879,11 @@ def forward(self, inputs):
         Returns:
             outputs: A Tensor of type output_dtype
         """
+        if hasattr(torch.onnx, "_globals"):
+            from torch.onnx._globals import GLOBALS
+        else:
+            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+
         if DTensor is not None and isinstance(inputs, DTensor):
            # TensorQuantizer only handles regular non-DTensor inputs
            device_mesh, placements = inputs.device_mesh, inputs.placements
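
The same probe-and-fallback is applied to GLOBALS, moved from import time into TensorQuantizer.forward. A quick standalone way to see how the fallback resolves on a given torch install (illustrative snippet, not part of the commit; it assumes only the two import paths shown in the diff):

import torch

def resolve_globals():
    # Illustrative helper mirroring the commit's fallback: prefer the legacy
    # torch.onnx._globals location when present, otherwise the
    # torchscript_exporter location used by torch >= 2.9.
    if hasattr(torch.onnx, "_globals"):
        from torch.onnx._globals import GLOBALS
    else:
        from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
    return GLOBALS

print(torch.__version__, resolve_globals().in_onnx_export)

GLOBALS.in_onnx_export is the flag such code typically consults to detect whether a TorchScript ONNX export is in progress; resolving it at forward time keeps tensor_quantizer.py importable on either torch layout.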
