[Autocast] Fix edge case with cast producing network output #302
Signed-off-by: Ali Boubezari <[email protected]>
Walkthrough
Extends PrecisionConverter._bypass_cast_node to also rewire consumers of the pre-cast input when the removed Cast produced a model output. Adds unit tests covering models whose outputs are Cast nodes, parameterized over low-precision type and keep-io-types. Test file includes duplicated fixture and test blocks.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant PC as PrecisionConverter
participant G as Graph
participant P as Producer(s)
participant C as Cast (to be removed)
participant DC as Downstream Consumer(s)
participant IC as InputTensor Consumer(s)
Note over PC,G: _bypass_cast_node (is_output_producer == true)
PC->>G: Identify C, its input_tensor and output_tensor
PC->>G: Rewire P to feed around C (bypass)
PC->>G: For DC consuming C.output_tensor<br/>update inputs to bypassed tensor
rect rgba(230,245,255,0.6)
Note over PC,G: New step in this branch
PC->>G: Find IC consuming input_tensor
PC->>G: Replace input_tensor with output_tensor in IC inputs
end
PC->>G: Remove C
G-->>PC: Graph updated
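The steps in the diagram can be sketched on a toy graph. This is a minimal illustration, not the ModelOpt implementation: the `Node` dataclass and the `bypass_output_cast` helper are hypothetical stand-ins for ONNX nodes and for `_bypass_cast_node`, and the sketch assumes the Cast's output name is kept because it is a graph output.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for an ONNX node (hypothetical, for illustration only)."""
    name: str
    op_type: str
    input: list = field(default_factory=list)
    output: list = field(default_factory=list)


def bypass_output_cast(nodes, cast):
    """Remove `cast`, whose output is a graph output, and keep the graph connected."""
    input_tensor, output_tensor = cast.input[0], cast.output[0]

    for node in nodes:
        if node is cast:
            continue  # the Cast itself is deleted below, so leave it alone
        # Producers of the pre-cast tensor now write the graph output directly.
        node.output = [output_tensor if t == input_tensor else t for t in node.output]
        # The step added in this PR: other consumers of the pre-cast tensor
        # must follow the rename, otherwise they read a tensor nobody produces.
        node.input = [output_tensor if t == input_tensor else t for t in node.input]

    # Consumers of the Cast's output need no change here because the
    # output tensor name is preserved.
    return [n for n in nodes if n is not cast]


add = Node("add", "Add", ["x", "y"], ["sum"])
cast = Node("cast", "Cast", ["sum"], ["out"])   # "out" is the graph output
relu = Node("relu", "Relu", ["sum"], ["act"])   # second consumer of "sum"

nodes = bypass_output_cast([add, cast, relu], cast)
print([(n.name, n.input, n.output) for n in nodes])
# [('add', ['x', 'y'], ['out']), ('relu', ['out'], ['act'])]
```

Without the consumer rewrite, `relu` would still read `sum`, which no node produces after `add` is redirected to `out`.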
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: 2 passed, 1 warning
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/onnx/autocast/precisionconverter.py (1)
579-583: Avoid rewiring the Cast node itself (small safety nit).
When rewriting consumers of the pre-cast tensor, skip the Cast node we're about to remove to avoid a transient self-loop (input==output). It’s harmless since the node is deleted, but skipping reduces churn.
- for consumer in consumers:
+ for consumer in consumers:
+     if consumer is node:
+         # Skip the Cast being removed; it will be deleted shortly.
+         continue
      for i, input_name in enumerate(consumer.input):
          if input_name == input_tensor:
              consumer.input[i] = output_tensor
tests/unit/onnx/autocast/test_precisionconverter.py (1)
1071-1071: Remove unnecessary filesystem write in fixture.
Writing the model to /tmp during tests is unnecessary and can break in sandboxed CI. Please drop it.
- onnx.save(model, "/tmp/model_with_casted_output.onnx")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/onnx/autocast/precisionconverter.py (1 hunks)
tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1): check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (2): PrecisionConverter (55-1037), convert (113-202)
🔇 Additional comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
1028-1106: No duplicate definitions found for model_with_casted_output or test_casted_output_model. Verified exactly one fixture and one test definition in this file.
@coderabbitai generate docstrings
✅ Actions performed
Docstrings generation started.
Docstrings generation was requested by @aboubezari.
* #302 (comment)
The following files were modified:
* `modelopt/onnx/autocast/precisionconverter.py`
* `tests/unit/onnx/autocast/test_precisionconverter.py`
Note: Generated docstrings for this pull request at #303
Fixed by #308, closing.
What does this PR do?
Type of change: Bug Fix
Overview:
If a Cast node produces a network output and the Cast's input tensor is also consumed by other nodes in the graph, AutoCast creates a disconnected graph.
When it removes the Cast, AutoCast updates the node that feeds the Cast so that it produces the network output directly. However, it does not update the other nodes that consume that node's original output, so they are left reading a tensor that no longer exists. The fix is to update those consumers to read the new output tensor.
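The disconnected graph described above can be detected by looking for inputs that no node produces. The following is a minimal sketch on plain dictionaries, not on ONNX objects; `find_dangling_inputs` and the dictionary layout are hypothetical names chosen for illustration.

```python
def find_dangling_inputs(nodes, graph_inputs):
    """Return (node_name, tensor) pairs whose tensor has no producer."""
    produced = set(graph_inputs)
    for node in nodes:
        produced.update(node["outputs"])
    return [
        (node["name"], tensor)
        for node in nodes
        for tensor in node["inputs"]
        if tensor not in produced
    ]


# State before the fix: "add" was redirected to write the graph output "out",
# but "relu" still reads the old pre-cast tensor "sum".
broken = [
    {"name": "add", "inputs": ["x", "y"], "outputs": ["out"]},
    {"name": "relu", "inputs": ["sum"], "outputs": ["act"]},
]

# State after the fix: the consumer was rewired to the new tensor name.
fixed = [
    {"name": "add", "inputs": ["x", "y"], "outputs": ["out"]},
    {"name": "relu", "inputs": ["out"], "outputs": ["act"]},
]

print(find_dangling_inputs(broken, ["x", "y"]))  # [('relu', 'sum')]
print(find_dangling_inputs(fixed, ["x", "y"]))   # []
```

A real model would also need initializers counted as available tensors; the sketch omits them for brevity.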
Usage
Autocast precision converter
Testing
Added a unit test that fails before this change and passes after it.