
Conversation

i-riyad
Contributor

@i-riyad i-riyad commented Sep 9, 2025

What does this PR do?

Type of change: Bug fix

Overview: ONNX models exported with torch.autocast enabled fail when multiple execution providers are passed to the ONNX Runtime inference session. ModelOpt enables autocast by default for INT8/FP8 ONNX export from torch models. This PR skips autocast when an FP32 model is requested from the export.
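
For illustration, the failure mode above arises with a multi-provider session like the following (a hedged sketch; "model.onnx" is a hypothetical file exported while autocast was active, and the provider list is the standard ONNX Runtime one):

import onnxruntime as ort

# Hypothetical file: an ONNX model exported from torch while torch.autocast was
# active (the ModelOpt default for INT8/FP8 export).
# Passing more than one execution provider is the scenario this PR addresses.
session = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)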

Usage

# Add a code snippet demonstrating how to use this
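
A minimal usage sketch (the import path and function name come from modelopt/torch/_deploy/utils/torch_onnx.py as referenced in this PR; the model, input shape, and exact argument order are placeholders, not verified against the real signature):

import torch
from modelopt.torch._deploy.utils.torch_onnx import get_onnx_bytes_and_metadata

model = MyQuantizedModel().eval()          # hypothetical quantized torch model
dummy_input = torch.randn(1, 3, 224, 224)  # placeholder input

# With weights_dtype="fp32", torch.autocast is now skipped during export,
# so the resulting ONNX graph keeps pure FP32 tensors.
onnx_bytes, metadata = get_onnx_bytes_and_metadata(
    model, dummy_input, weights_dtype="fp32"
)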

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • Bug Fixes
    • Fixed precision handling during ONNX export: automatic GPU autocasting is now disabled for FP32-weight models and certain quantized models, preserving intended dtypes. This improves export accuracy and stability.
    • No changes to public APIs or user-facing configuration.

@i-riyad i-riyad requested a review from a team as a code owner September 9, 2025 00:26
@i-riyad i-riyad requested a review from ajrasane September 9, 2025 00:26

coderabbitai bot commented Sep 9, 2025

Walkthrough

The ONNX export utility changes the autocast selection: autocast is disabled (nullcontext) when the model is FP4-quantized, MXFP8-quantized, or when weights_dtype == "fp32"; otherwise torch.autocast("cuda") is used. Public APIs remain unchanged.
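
In code form, the selection reduces to roughly the following (mirroring the lines quoted in the review diffs below):

from contextlib import nullcontext
import torch

use_torch_autocast = not (
    is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()

with autocast:
    ...  # run the torch-to-ONNX export under the chosen context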

Changes

Cohort / File(s): Autocast control in ONNX export (modelopt/torch/_deploy/utils/torch_onnx.py)
Change summary: Expanded the condition to use nullcontext() when the model is FP4-quantized, MXFP8-quantized, or weights_dtype == "fp32"; otherwise torch.autocast("cuda") is used. No API signature changes.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant ExportFunc as get_onnx_bytes_and_metadata
    participant Ctx as Context Manager
    participant ONNX as ONNX Export

    Caller->>ExportFunc: request ONNX bytes + metadata
    ExportFunc->>ExportFunc: evaluate is_fp4_quantized / is_mxfp8_quantized / weights_dtype=="fp32"
    alt FP4 or MXFP8 or FP32 weights
        ExportFunc->>Ctx: use nullcontext()
        Note right of Ctx: Autocast disabled (dtypes preserved as-is)
    else
        ExportFunc->>Ctx: use autocast("cuda")
        Note right of Ctx: Autocast enabled (CUDA mixed precision)
    end
    ExportFunc->>ONNX: run export under chosen context
    ONNX-->>ExportFunc: ONNX bytes + metadata
    ExportFunc-->>Caller: return results

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check (✅ Passed): The title accurately captures the main intent of the changeset: autocasting is bypassed during ONNX export when an FP32 model is requested, reflecting the conditional logic update introduced in the PR.
  • Docstring Coverage (✅ Passed): Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

Poem

I twitch my ears at types that flow,
FP4, MXFP8 — and FP32 in tow.
Autocast naps when those flags show,
ONNX bakes steady, soft and slow.
A hop, a whisker, off we go! 🐇

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
modelopt/torch/_deploy/utils/torch_onnx.py (4)

399-403: Make autocast device-aware and dtype-aligned; avoid CUDA-only assumption.

torch.autocast("cuda") will raise on CPU-only environments and ignores BF16 requests. Use the model’s device and set dtype accordingly.

Apply:

-    autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
+    # Choose a safe device/dtype for autocast
+    try:
+        device_type = next(model.parameters()).device.type
+    except StopIteration:
+        device_type = "cuda" if torch.cuda.is_available() else "cpu"
+    if use_torch_autocast and device_type in ("cuda", "cpu"):
+        dtype = torch.bfloat16 if weights_dtype == "bf16" else torch.float16
+        autocast = torch.autocast(device_type, dtype=dtype)
+    else:
+        autocast = nullcontext()

354-356: Clarify weights_dtype semantics in the docstring.

Explicitly state that FP32 disables autocast and that no upcast occurs automatically.

-        weights_dtype: The dtype of the weights in the onnx model.
+        weights_dtype: Target ONNX weight dtype ("fp32", "fp16", "bf16").
+            When "fp32", torch.autocast is disabled during export to avoid mixed-precision casts.
+            Note: this does not upcast model parameters automatically; ensure your model params are
+            float32 if a pure FP32 graph is desired.

333-336: Avoid mutable default for dynamic_axes; gate insertion only when provided.

Prevents surprises from a shared dict default and keeps kwargs compact when not needed.

-    dynamic_axes: dict = {},
+    dynamic_axes: dict | None = None,
@@
-        if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
-            additional_kwargs["dynamic_axes"] = dynamic_axes
+        if (
+            not dynamo_export
+            and Version(torch.__version__) >= Version("2.8")
+            and dynamic_axes is not None
+        ):
+            additional_kwargs["dynamic_axes"] = dynamic_axes

Also applies to: 435-437


485-491: Optional: warn (or upcast) when weights_dtype="fp32" but model params aren’t FP32.

Helps users avoid unexpected mixed dtypes in the exported graph.

If you prefer a lightweight approach, add a warning:

     try:
         # TODO: Single-precision torch model assumed
         param_dtype = next(model.parameters()).dtype
     except StopIteration:
         param_dtype = torch.float32
-    if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
+    if weights_dtype == "fp32" and param_dtype != torch.float32:
+        print("Warning: weights_dtype='fp32' requested but model params are "
+              f"{param_dtype}. Export may contain non-FP32 tensors.")
+    if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:

Alternatively, we can temporarily upcast a copy of the model before export—say the word if you want that patch.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 512dbb7 and b5f912e.

📒 Files selected for processing (1)
  • modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)

396-398: LGTM: correctly bypasses autocast for FP32 export.

This aligns with the PR goal: FP32 requests won’t be mixed-precision due to autocast. Also preserves the existing exclusions for FP4/MXFP8.

One nuance: this only disables autocast; it doesn’t upcast model params. If a user supplies a half-precision model with weights_dtype="fp32", the export may still contain non-FP32 tensors. Is that intended? If not, we should either document it or upcast before export.
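
If upcasting is the preferred resolution, a lightweight sketch (illustrative only, not part of this PR; it reuses the hypothetical export call from the PR description) would be:

# Guarantee FP32 parameters before requesting an FP32 export.
model = model.float()  # casts parameters and buffers to torch.float32
onnx_bytes, metadata = get_onnx_bytes_and_metadata(
    model, dummy_input, weights_dtype="fp32"
)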


codecov bot commented Sep 9, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.88%. Comparing base (85b309f) to head (ad8efc2).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #304   +/-   ##
=======================================
  Coverage   73.87%   73.88%           
=======================================
  Files         172      172           
  Lines       17439    17439           
=======================================
+ Hits        12883    12884    +1     
+ Misses       4556     4555    -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@i-riyad i-riyad enabled auto-merge (squash) September 9, 2025 22:53

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)

396-400: Align autocast dtype with requested weights dtype and guard for CPU-only environments.

Today autocast defaults to FP16 on CUDA. When weights_dtype == "bf16", aligning the autocast dtype avoids needless casts/mismatch; also, hard-coding CUDA can error on CPU-only runs. Suggest:

 use_torch_autocast = not (
     is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
 )
-autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
+autocast_dtype = torch.bfloat16 if weights_dtype == "bf16" else torch.float16
+autocast = (
+    torch.autocast("cuda", dtype=autocast_dtype)
+    if use_torch_autocast and torch.cuda.is_available()
+    else nullcontext()
+)

Optionally, if INT4 export has similar EP issues, consider also disabling autocast for INT4:

-    is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
+    is_fp4_quantized(model) or is_mxfp8_quantized(model) or is_int4_quantized(model) or weights_dtype == "fp32"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b5f912e and ad8efc2.

📒 Files selected for processing (1)
  • modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)

396-400: Autocast disabled for FP32 export — this directly addresses the ONNX Runtime multi-EP failure.

Including weights_dtype == "fp32" in the condition ensures pure FP32 export avoids unintended mixed-precision compute during tracing. LGTM.

@i-riyad i-riyad merged commit 4716131 into main Sep 10, 2025
22 checks passed
@i-riyad i-riyad deleted the rislam/autocast-avoid-for-fp32 branch September 10, 2025 16:28
jingyu-ml pushed a commit that referenced this pull request Sep 10, 2025
benchislett pushed a commit that referenced this pull request Sep 15, 2025