Avoid autocast at onnx export if fp32 model is desired #304
Conversation
Signed-off-by: Riyad Islam <[email protected]>
Walkthrough
The ONNX export utility changes the autocast selection: autocast is disabled (nullcontext) when the model is FP4-quantized, MXFP8-quantized, or when weights_dtype is "fp32"; otherwise torch.autocast("cuda") is used. A minimal sketch of this selection logic follows the sequence diagram below.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant ExportFunc as get_onnx_bytes_and_metadata
participant Ctx as Context Manager
participant ONNX as ONNX Export
Caller->>ExportFunc: request ONNX bytes + metadata
ExportFunc->>ExportFunc: evaluate is_fp4_quantized / is_mxfp8_quantized / weights_dtype=="fp32"
alt FP4 or MXFP8 or FP32 weights
ExportFunc->>Ctx: use nullcontext()
Note right of Ctx: Autocast disabled (cpu-like path)
else
ExportFunc->>Ctx: use autocast("cuda")
Note right of Ctx: Autocast enabled (CUDA mixed precision)
end
ExportFunc->>ONNX: run export under chosen context
ONNX-->>ExportFunc: ONNX bytes + metadata
ExportFunc-->>Caller: return results
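To make the diagram concrete, here is a minimal sketch of the selection logic under the assumptions above; the is_fp4_quantized / is_mxfp8_quantized helpers below are hypothetical stand-ins for modelopt's real checks, and choose_export_context is not an actual function in the codebase:

```python
from contextlib import nullcontext

import torch


def is_fp4_quantized(model) -> bool:
    # Hypothetical stand-in for modelopt's FP4 quantization check (sketch only).
    return getattr(model, "_is_fp4_quantized", False)


def is_mxfp8_quantized(model) -> bool:
    # Hypothetical stand-in for modelopt's MXFP8 quantization check (sketch only).
    return getattr(model, "_is_mxfp8_quantized", False)


def choose_export_context(model: torch.nn.Module, weights_dtype: str):
    """Pick the context to trace under: skip autocast for FP4, MXFP8, or FP32 exports."""
    use_torch_autocast = not (
        is_fp4_quantized(model)
        or is_mxfp8_quantized(model)
        or weights_dtype == "fp32"  # the condition added by this PR
    )
    return torch.autocast("cuda") if use_torch_autocast else nullcontext()
```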
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks: ✅ 3 passed
Actionable comments posted: 0
🧹 Nitpick comments (4)
modelopt/torch/_deploy/utils/torch_onnx.py (4)
399-403: Make autocast device-aware and dtype-aligned; avoid CUDA-only assumption.
torch.autocast("cuda") will raise on CPU-only environments and ignores BF16 requests. Use the model's device and set dtype accordingly. Apply:
- autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
+ # Choose a safe device/dtype for autocast
+ try:
+     device_type = next(model.parameters()).device.type
+ except StopIteration:
+     device_type = "cuda" if torch.cuda.is_available() else "cpu"
+ if use_torch_autocast and device_type in ("cuda", "cpu"):
+     dtype = torch.bfloat16 if weights_dtype == "bf16" else torch.float16
+     autocast = torch.autocast(device_type, dtype=dtype)
+ else:
+     autocast = nullcontext()
354-356: Clarify weights_dtype semantics in the docstring.
Explicitly state that FP32 disables autocast and that no upcast occurs automatically.
- weights_dtype: The dtype of the weights in the onnx model.
+ weights_dtype: Target ONNX weight dtype ("fp32", "fp16", "bf16").
+     When "fp32", torch.autocast is disabled during export to avoid mixed-precision casts.
+     Note: this does not upcast model parameters automatically; ensure your model params are
+     float32 if a pure FP32 graph is desired.
333-336: Avoid mutable default for dynamic_axes; gate insertion only when provided.
Prevents surprises from a shared dict default and keeps kwargs compact when not needed.
- dynamic_axes: dict = {},
+ dynamic_axes: dict | None = None,
@@
- if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
-     additional_kwargs["dynamic_axes"] = dynamic_axes
+ if (
+     not dynamo_export
+     and Version(torch.__version__) >= Version("2.8")
+     and dynamic_axes is not None
+ ):
+     additional_kwargs["dynamic_axes"] = dynamic_axes
Also applies to: 435-437
485-491: Optional: warn (or upcast) when weights_dtype="fp32" but model params aren't FP32.
Helps users avoid unexpected mixed dtypes in the exported graph.
If you prefer a lightweight approach, add a warning:
 try:
     # TODO: Single-precision torch model assumed
     param_dtype = next(model.parameters()).dtype
 except StopIteration:
     param_dtype = torch.float32
-if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
+if weights_dtype == "fp32" and param_dtype != torch.float32:
+    print("Warning: weights_dtype='fp32' requested but model params are "
+          f"{param_dtype}. Export may contain non-FP32 tensors.")
+if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
Alternatively, we can temporarily upcast a copy of the model before export; say the word if you want that patch.
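For reference, that heavier alternative could look roughly like the sketch below; export_fn is a placeholder for whatever export entry point is used, not an actual modelopt API:

```python
import copy

import torch


def export_from_fp32_copy(model: torch.nn.Module, export_fn, *args, **kwargs):
    """Illustrative only: export from a temporary FP32 copy so the caller's
    (possibly half-precision) model is left untouched."""
    fp32_model = copy.deepcopy(model).float()  # upcast floating-point params/buffers
    fp32_model.eval()
    return export_fn(fp32_model, *args, **kwargs)
```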
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)
396-398: LGTM: correctly bypasses autocast for FP32 export.
This aligns with the PR goal: FP32 requests won't be mixed-precision due to autocast. Also preserves the existing exclusions for FP4/MXFP8.
One nuance: this only disables autocast; it doesn't upcast model params. If a user supplies a half-precision model with weights_dtype="fp32", the export may still contain non-FP32 tensors. Is that intended? If not, we should either document it or upcast before export.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@ Coverage Diff @@
## main #304 +/- ##
=======================================
Coverage 73.87% 73.88%
=======================================
Files 172 172
Lines 17439 17439
=======================================
+ Hits 12883 12884 +1
+ Misses 4556 4555 -1
☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)
396-400: Align autocast dtype with requested weights dtype and guard for CPU-only environments.
Today autocast defaults to FP16 on CUDA. When weights_dtype == "bf16", aligning the autocast dtype avoids needless casts/mismatch; also, hard-coding CUDA can error on CPU-only runs. Suggest:
 use_torch_autocast = not (
     is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
 )
-autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
+autocast_dtype = torch.bfloat16 if weights_dtype == "bf16" else torch.float16
+autocast = (
+    torch.autocast("cuda", dtype=autocast_dtype)
+    if use_torch_autocast and torch.cuda.is_available()
+    else nullcontext()
+)
Optionally, if INT4 export has similar EP issues, consider also disabling autocast for INT4:
- is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
+ is_fp4_quantized(model) or is_mxfp8_quantized(model) or is_int4_quantized(model) or weights_dtype == "fp32"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)
396-400: Autocast disabled for FP32 export: this directly addresses the ONNX Runtime multi-EP failure.
Including weights_dtype == "fp32" in the condition ensures pure FP32 export avoids unintended mixed-precision compute during tracing. LGTM.
Signed-off-by: Riyad Islam <[email protected]>
Signed-off-by: Jingyu Xin <[email protected]>
Signed-off-by: Riyad Islam <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview: ONNX models exported with torch.autocast enabled fail when multiple execution providers are passed to the ONNX Runtime inference session. ModelOpt enables autocast by default for INT8/FP8 ONNX export from torch models. This PR skips autocast when an FP32 model is requested at export.
Usage
# Add a code snippet demonstrating how to use this
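A hedged illustrative snippet: the weights_dtype keyword and the get_onnx_bytes_and_metadata entry point come from the change under review, but the other argument names, the tuple unpacking, and the model/dummy_input objects are assumptions, not the verified API.

```python
import onnxruntime as ort

from modelopt.torch._deploy.utils.torch_onnx import get_onnx_bytes_and_metadata

# Request a pure FP32 export; with this change, torch.autocast is skipped during tracing.
onnx_bytes, metadata = get_onnx_bytes_and_metadata(
    model,        # your torch.nn.Module (placeholder)
    dummy_input,  # sample input used for tracing (placeholder)
    weights_dtype="fp32",
)

# The FP32 graph can then be used with multiple execution providers.
session = ort.InferenceSession(
    onnx_bytes,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```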
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit