Skip to content

Commit e68721d

Browse files
committed
Making conditional import for torch 2.9
Signed-off-by: Riyad Islam <[email protected]>
1 parent c001c2c commit e68721d

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

modelopt/torch/quantization/export_onnx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,14 @@
106106

107107
import onnx
108108
import torch
109-
from torch.onnx import _type_utils, symbolic_helper
109+
from packaging.version import Version
110+
111+
if Version(torch.__version__) >= Version("2.9.0"):
112+
from torch.onnx._internal.torchscript_exporter import _type_utils
113+
else:
114+
from torch.onnx import _type_utils
115+
116+
from torch.onnx import symbolic_helper
110117
from torch.onnx import symbolic_helper as sym_help
111118
from torch.onnx._internal import jit_utils
112119
from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
import torch.nn.functional as F
3232
from packaging.version import Version
3333
from torch import nn
34-
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
34+
35+
if Version(torch.__version__) >= Version("2.9.0"):
36+
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
37+
else:
38+
from torch.onnx._globals import GLOBALS
39+
3540

3641
from modelopt.torch.utils import standardize_constructor_args
3742
from modelopt.torch.utils.distributed import DistributedProcessGroup

0 commit comments

Comments
 (0)