12 changes: 11 additions & 1 deletion modelopt/torch/quantization/export_onnx.py
@@ -106,7 +106,7 @@

 import onnx
 import torch
-from torch.onnx import _type_utils, symbolic_helper
+from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
 from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
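For reference, one way to probe which of these two module paths a given Torch build actually ships, without the attribute lookups that lazy imports can confuse (a minimal standalone sketch, assuming the two paths named in this diff are the only candidates):

import importlib.util

for path in (
    "torch.onnx._type_utils",
    "torch.onnx._internal.torchscript_exporter._type_utils",
):
    try:
        present = importlib.util.find_spec(path) is not None
    except ModuleNotFoundError:
        # A parent package (e.g. torchscript_exporter) is absent on this build.
        present = False
    print(f"{path}: {'present' if present else 'missing'}")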
@@ -290,6 +290,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
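Note: the hasattr concern raised in the review comment further down (for GLOBALS) seems to apply to these _type_utils guards too. A sketch of the equivalent try/except form, assuming the two paths in this diff are the only candidates (not part of the committed change):

try:
    from torch.onnx import _type_utils
except ImportError:
    from torch.onnx._internal.torchscript_exporter import _type_utils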
@@ -393,6 +398,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask

+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -31,7 +31,6 @@
 import torch.nn.functional as F
 from packaging.version import Version
 from torch import nn
-from torch.onnx._globals import GLOBALS

 from modelopt.torch.utils import standardize_constructor_args
 from modelopt.torch.utils.distributed import DistributedProcessGroup
@@ -880,6 +879,11 @@ def forward(self, inputs):
         Returns:
             outputs: A Tensor of type output_dtype
         """
+        if hasattr(torch.onnx, "_globals"):
+            from torch.onnx._globals import GLOBALS
+        else:
+            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+
Comment on lines +881 to +885
⚠️ Potential issue

Make ONNX GLOBALS import robust (avoid hasattr; gracefully handle older Torch).

hasattr(torch.onnx, "_globals") can yield false positives or negatives depending on lazy imports; moreover, neither path may exist on some Torch versions, in which case the current code raises ImportError on every forward call even when ONNX export isn't in use. Use a try/except fallback and provide a safe sentinel.

Apply this diff:

-        if hasattr(torch.onnx, "_globals"):
-            from torch.onnx._globals import GLOBALS
-        else:
-            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+        try:
+            from torch.onnx._globals import GLOBALS
+        except Exception:
+            try:
+                from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+            except Exception:
+                # Fallback for Torch versions lacking both paths; skip ONNX readiness checks.
+                class GLOBALS:  # noqa: N801
+                    in_onnx_export = False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        if hasattr(torch.onnx, "_globals"):
-            from torch.onnx._globals import GLOBALS
-        else:
-            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+        try:
+            from torch.onnx._globals import GLOBALS
+        except Exception:
+            try:
+                from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+            except Exception:
+                # Fallback for Torch versions lacking both paths; skip ONNX readiness checks.
+                class GLOBALS:  # noqa: N801
+                    in_onnx_export = False
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 881
to 885, the current hasattr-based import for ONNX GLOBALS is brittle and may
raise ImportError on some Torch versions. Replace it with a try/except
ImportError sequence: try to import GLOBALS from torch.onnx._globals; if that
fails, try torch.onnx._internal.torchscript_exporter._globals; and if both fail,
set GLOBALS to a safe sentinel (e.g., None or a stub object) so forward passes won't
error when ONNX isn't used.
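
A possible refinement of the suggestion above (a sketch only; the helper name _resolve_onnx_globals and the ONNX_GLOBALS constant are hypothetical and not part of this PR): resolving GLOBALS once at module scope keeps the fallback chain out of the per-call forward path.

def _resolve_onnx_globals():
    """Return Torch's ONNX GLOBALS, or a no-op sentinel if neither path exists."""
    try:
        from torch.onnx._globals import GLOBALS
    except ImportError:
        try:
            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
        except ImportError:

            class GLOBALS:  # sentinel: ONNX readiness checks become no-ops
                in_onnx_export = False

    return GLOBALS


ONNX_GLOBALS = _resolve_onnx_globals()  # resolved once, at module import

forward would then read ONNX_GLOBALS.in_onnx_export instead of running the import logic on every call.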

         if DTensor is not None and isinstance(inputs, DTensor):
             # TensorQuantizer only handles regular non-DTensor inputs
             device_mesh, placements = inputs.device_mesh, inputs.placements