Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion modelopt/torch/quantization/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,14 @@

import onnx
import torch
from torch.onnx import _type_utils, symbolic_helper
from packaging.version import Version

if Version(torch.__version__) >= Version("2.9.0"):
from torch.onnx._internal.torchscript_exporter import _type_utils
else:
from torch.onnx import _type_utils

Copy link

@coderabbitai coderabbitai bot Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Prefer feature-detection over version checks for _type_utils (drop packaging dep).

Version-guarding is brittle (nightlies/alphas may differ). Try known locations in order and fall back, avoiding a hard dependency on packaging.

Apply:

-from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
+try:
+    # Newer PyTorch layouts first
+    from torch.onnx._internal.torchscript_exporter import _type_utils  # type: ignore[attr-defined]
+except Exception:
+    try:
+        from torch.onnx._internal.exporter import _type_utils  # type: ignore[attr-defined]
+    except Exception:
+        # Older layout
+        from torch.onnx import _type_utils
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from packaging.version import Version
if Version(torch.__version__) >= Version("2.9.0"):
from torch.onnx._internal.torchscript_exporter import _type_utils
else:
from torch.onnx import _type_utils
try:
# Newer PyTorch layouts first
from torch.onnx._internal.torchscript_exporter import _type_utils # type: ignore[attr-defined]
except Exception:
try:
from torch.onnx._internal.exporter import _type_utils # type: ignore[attr-defined]
except Exception:
# Older layout
from torch.onnx import _type_utils
🤖 Prompt for AI Agents
In modelopt/torch/quantization/export_onnx.py around lines 109 to 115, replace
the packaging.version based import with feature-detection: attempt to import
_type_utils from torch.onnx._internal.torchscript_exporter in a try/except and
on ImportError fall back to importing _type_utils from torch.onnx; remove the
packaging.version import and any Version checks so the code tries known
locations in order and gracefully falls back for nightlies/alphas.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kevalmorabia97 which approach is better?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Copy link
Collaborator

@kevalmorabia97 kevalmorabia97 Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If torch is installed from source before official 2.9 release (e.g. in DLFW Pytorch containers), it will have version like 2.9.0a0...nv which using version check would show as <2.9.0 even though it has 2.9.0 changes. Lets change the version check to >2.8 instead of >=2.9

from torch.onnx import symbolic_helper
from torch.onnx import symbolic_helper as sym_help
from torch.onnx._internal import jit_utils
from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
import torch.nn.functional as F
from packaging.version import Version
from torch import nn
from torch.onnx._globals import GLOBALS

if Version(torch.__version__) >= Version("2.9.0"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer moving this import down in the function where it is used since we are importing from torch.onnx._internal which is a private function and may again change in the future potentially causing import to break

from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
else:
from torch.onnx._globals import GLOBALS


from modelopt.torch.utils import standardize_constructor_args
from modelopt.torch.utils.distributed import DistributedProcessGroup
Expand Down
Loading