[Autocast] Fix edge case casting input directly to output #305
Conversation
Walkthrough

Detects Cast nodes that directly connect a model input to a model output, duplicates and re-routes them so the IO-facing Cast is preserved, excludes those duplicates from generic cast removal, and adds a unit test fixture covering the case.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant PC as PrecisionConverter
    participant G as ONNX Graph
    participant N as Cast Node
    participant C as Consumers
    PC->>G: enumerate nodes, collect model input/output names
    G-->>PC: node list + model IO names
    alt Cast input ∈ model_inputs and output ∈ model_outputs
        PC->>N: create duplicate Cast (new_cast) with same inputs/outputs/dtype
        PC->>N: rename original Cast output (add suffix)
        PC->>C: rewire original consumers → suffixed output
        PC->>PC: record original cast name in casts_to_skip
    else other Cast nodes
        PC->>PC: mark removable Cast candidates (fp16/fp32 rules, ignore initializers)
    end
    PC->>G: append duplicated IO Casts
    PC->>G: remove removable Casts excluding casts_to_skip
    G-->>PC: return updated graph
```
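The duplication-and-rewire step in the diagram can be sketched in plain Python. This is a stdlib-only illustration on a hypothetical `Node` record (not the real `onnx` protobuf types, and not the actual `PrecisionConverter` code); the `_io_special_case` suffix follows the naming shown in the review below, and all helper names are assumptions:

```python
from dataclasses import dataclass

@dataclass
class Node:
    op_type: str
    name: str
    inputs: list
    outputs: list

def preserve_io_casts(nodes, model_inputs, model_outputs):
    """Duplicate any Cast that bridges a model input straight to a model
    output, so the IO-facing copy survives generic cast removal."""
    casts_to_add, casts_to_skip = [], []
    for node in nodes:
        if (node.op_type == "Cast"
                and node.inputs[0] in model_inputs
                and node.outputs[0] in model_outputs):
            # The duplicate keeps the original name and the IO tensor names.
            casts_to_add.append(Node("Cast", node.name,
                                     list(node.inputs), list(node.outputs)))
            casts_to_skip.append(node.name)
            # Rename the original and suffix its output tensor...
            old_out = node.outputs[0]
            new_out = old_out + "_io_special_case"
            node.name += "_io_special_case"
            node.outputs[0] = new_out
            # ...then rewire internal consumers onto the suffixed tensor.
            for consumer in nodes:
                if consumer is not node:
                    consumer.inputs = [new_out if i == old_out else i
                                       for i in consumer.inputs]
    nodes.extend(casts_to_add)
    return nodes, set(casts_to_skip)

# X -> Cast -> Y (model output), with an internal Add also reading Y.
graph = [Node("Cast", "cast0", ["X"], ["Y"]),
         Node("Add", "add0", ["Y", "B"], ["Z"])]
graph, skip = preserve_io_casts(graph, {"X"}, {"Y"})
```

After the pass, the IO-facing duplicate still produces `Y` while the renamed original feeds internal consumers, matching the "alt" branch of the diagram.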
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks (2 passed, 1 warning)
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
624-634: Preexisting-cast removal misses BF16 targets

The comment says FP16/BF16/FP32 casts are removed, but `is_fp_cast` only matches `to ∈ {FLOAT16, FLOAT}`. Include BF16 as a target to honor the contract.

```diff
- is_fp_cast = cast_to_type in [
-     onnx.TensorProto.FLOAT16,
-     onnx.TensorProto.FLOAT,
- ] and cast_from_type in [
+ is_fp_cast = cast_to_type in [
+     onnx.TensorProto.FLOAT16,
+     onnx.TensorProto.FLOAT,
+     onnx.TensorProto.BFLOAT16,
+ ] and cast_from_type in [
      onnx.TensorProto.FLOAT16,
      onnx.TensorProto.FLOAT,
      onnx.TensorProto.BFLOAT16,
  ]
```
641-644: Guard for output-producing casts is ineffective

This condition checks whether BOTH the cast input and output are network outputs, which never happens. It should keep casts that produce a network output.

```diff
- # Keep cast nodes that are necessary producers of network outputs
- if any(node.input[0] == out.name for out in self.model.graph.output) and any(
-     node.output[0] == out.name for out in self.model.graph.output
- ):
+ # Keep casts that produce a network output
+ if node.output[0] in model_output_names:
      continue
```
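The reviewer's intended contract — both endpoints of a float-precision cast must be floating-point, with BF16 accepted as a target as well as a source — can be captured with a single membership set. A minimal sketch, using the ONNX `TensorProto` dtype codes as plain integers rather than importing `onnx` (the function name mirrors the `is_fp_cast` flag in the discussed code, but this is an illustration, not the project's implementation):

```python
# ONNX TensorProto dtype codes (values per the ONNX spec).
FLOAT, INT64, FLOAT16, BFLOAT16 = 1, 7, 10, 16
FP_TYPES = {FLOAT, FLOAT16, BFLOAT16}

def is_fp_cast(cast_from_type: int, cast_to_type: int) -> bool:
    # A float-precision cast has floating-point types on BOTH ends,
    # with BFLOAT16 included as a target as well as a source.
    return cast_to_type in FP_TYPES and cast_from_type in FP_TYPES
```

Using one shared set for source and target types avoids the asymmetry the review flags, where BF16 was accepted as a source but never as a target.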
🧹 Nitpick comments (3)

modelopt/onnx/autocast/precisionconverter.py (1)

618-621: Insert duplicate IO-bridge casts deterministically (top of graph)

Appending at the tail can shuffle topological order. Inserting at index 0 is more stable for input-driven casts.

```diff
- for cast in casts_to_add:
-     self.model.graph.node.append(cast)
+ for cast in casts_to_add:
+     self.model.graph.node.insert(0, cast)
```

tests/unit/onnx/autocast/test_precisionconverter.py (2)
1068-1071: Don't write artifacts to /tmp in unit tests

The saved model isn't used. Remove it to keep tests hermetic.

```diff
- onnx.save(model, "/tmp/model_with_casted_output.onnx")
```
1076-1091: Strengthen assertions: verify Y1 connectivity and dtype

Only checking `onnx.checker` is weak. Assert that Y1 remains produced by a Cast and retains FP32 (keep_io_types=True).

```diff
  converted_model = converter.convert(
      high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
  )
- onnx.checker.check_model(converted_model)
+ onnx.checker.check_model(converted_model)
+ # Y1 should remain connected and produced by a Cast
+ y1_producers = utils.get_producer_nodes(converted_model, "Y1")
+ assert len(y1_producers) == 1
+ assert y1_producers[0].op_type == "Cast"
+ # keep_io_types=True -> FP32 I/O preserved
+ y1_vi = next(y for y in converted_model.graph.output if y.name == "Y1")
+ assert y1_vi.type.tensor_type.elem_type == TensorProto.FLOAT
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/autocast/precisionconverter.py (1 hunks)
- tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
- modelopt/onnx/utils.py (1): check_model (557-569)
- modelopt/onnx/autocast/precisionconverter.py (1): convert (113-202)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
607-617: Confirm intent: skipping only the duplicated "new_cast", not the renamed original

`casts_to_skip` holds the original name (now assigned to `new_cast`) and will not skip the renamed original (`..._io_special_case`). If that original gets removed, `_bypass_cast_node` will reconnect its consumers directly to the model input. Is that intended? If not, add the renamed name to the skip list right after renaming.

```diff
  casts_to_skip.append(node.name)
  casts_to_add.append(new_cast)
  # Now adjust the old cast's name, consumers and producers
  node.name = f"{node.name}_io_special_case"
+ casts_to_skip.append(node.name)  # keep the internal IO-special-case cast as well
```
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (3)

600-606: Also preserve optional Cast attributes (e.g., saturate) to avoid semantic drift

If the model uses newer opsets where Cast may carry optional attributes (like saturate), the duplicate should copy them.

Apply:

```diff
- new_cast = helper.make_node(
-     "Cast",
-     name=node.name,
-     inputs=[node.input[0]],
-     outputs=[node.output[0]],
-     to=utils.get_cast_to_type(node),
- )
+ # Copy optional attributes (e.g., 'saturate' in newer opsets)
+ saturate = next((a.i for a in node.attribute if a.name == "saturate"), None)
+ cast_attrs = {"to": utils.get_cast_to_type(node)}
+ if saturate is not None:
+     cast_attrs["saturate"] = saturate
+ new_cast = helper.make_node(
+     "Cast",
+     name=node.name,
+     inputs=[node.input[0]],
+     outputs=[node.output[0]],
+     **cast_attrs,
+ )
```
618-621: Insert duplicate Cast adjacent to the original for better locality and readability

Appending at the end works but scatters IO nodes. Insert near the renamed source node to keep the topology readable.

```diff
- for cast in casts_to_add:
-     self.model.graph.node.append(cast)
+ # Preserve locality: insert duplicates next to their originals
+ for cast in casts_to_add:
+     target_idx = -1
+     for i, n in enumerate(self.model.graph.node):
+         if n.name == f"{cast.name}_io_special_case":
+             target_idx = i
+             break
+     if target_idx >= 0:
+         self.model.graph.node.insert(target_idx, cast)
+     else:
+         # Fallback to prepend to avoid end-append reordering
+         self.model.graph.node.insert(0, cast)
```
592-596: Use a set for casts_to_skip from the start

Minor nit for clarity and O(1) membership checks.

```diff
- casts_to_skip = []
+ casts_to_skip: set[str] = set()
  # Add casts as a separate step to avoid modifying the graph while iterating over it
  casts_to_add = []
@@
- casts_to_skip.append(node.name)
+ casts_to_skip.add(node.name)
@@
- casts_to_skip = set(casts_to_skip)
+ # already a set
```

Also applies to: 620-621
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/onnx/autocast/precisionconverter.py
(1 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
588-617: Solid IO-cast preservation strategy

Duplicating the IO-facing Cast, renaming the original, and rewiring consumers avoids disconnecting outputs while still enabling generic cast cleanup. This addresses the edge case cleanly.
Signed-off-by: Ali Boubezari <[email protected]>
```python
value_info_map,
initializer_map,
node_to_init_map,
keep_io_types=True,
```
This test fails when `keep_io_types=False`, because a graph input becomes a graph output directly, which violates ModelOpt's assertion that all original input and output names must be preserved in the quantized model.
What does this PR do?
Type of change: Bug fix
Overview: If a Cast node connects an input directly to an output, the output becomes completely disconnected due to naming issues. This fix creates specialized Cast nodes for such edge cases and avoids removing them in the initial pass.
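The disconnection described above can be reproduced in miniature. The sketch below is a stdlib-only simulation of what a naive cast-bypass does (the real logic lives in `_bypass_cast_node`; the `Node` record and `bypass_cast` helper here are hypothetical stand-ins): consumers of the cast's output are rewired to its input, so when the cast's output is itself a model output with no other producer, nothing produces that output name any more.

```python
from dataclasses import dataclass

@dataclass
class Node:
    op_type: str
    name: str
    inputs: list
    outputs: list

def bypass_cast(nodes, cast, model_outputs):
    """Naive cast removal: rewire consumers of the cast's output onto the
    cast's input, then drop the cast node entirely."""
    src, dst = cast.inputs[0], cast.outputs[0]
    for n in nodes:
        if n is not cast:
            n.inputs = [src if i == dst else i for i in n.inputs]
    nodes.remove(cast)
    # Any declared model output that no surviving node produces is dangling.
    produced = {o for n in nodes for o in n.outputs}
    return [o for o in model_outputs if o not in produced]

# X (model input) -> Cast -> Y (model output): removing the cast leaves
# nothing producing Y, so the output is disconnected.
graph = [Node("Cast", "cast0", ["X"], ["Y"])]
dangling = bypass_cast(graph, graph[0], ["Y"])
```

Here `dangling` ends up containing `Y`, which is exactly the failure the fix prevents by keeping a duplicate IO-facing Cast out of the removal pass.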
Usage
Autocast precision converter
Testing
Added a unit test that fails before my change and passes after my fix.
Before your PR is "Ready for review"
Summary by CodeRabbit
Bug Fixes
Tests