
Conversation

@aboubezari aboubezari commented Sep 9, 2025

What does this PR do?

Type of change: Bug fix

Overview: If there is a Cast node connecting an input directly to an output, the output ends up completely disconnected due to naming issues. This fix creates specialized Cast nodes for such edge cases and avoids removing them in the initial pass.
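
As a rough illustration of the approach (a minimal sketch, not the actual ModelOpt code; preserve_io_casts is a hypothetical helper standing in for the logic inside _remove_preexisting_casts), the idea is to detect a Cast whose input is a graph input and whose output is a graph output, keep a duplicate wired to the original I/O names, and rename and reroute the original so the later cast-removal pass can skip it:

from onnx import helper

def preserve_io_casts(graph):
    """Sketch: keep Cast nodes that bridge a graph input directly to a graph output."""
    input_names = {i.name for i in graph.input}
    output_names = {o.name for o in graph.output}
    casts_to_skip, casts_to_add = set(), []

    for node in graph.node:
        if node.op_type != "Cast":
            continue
        if node.input[0] in input_names and node.output[0] in output_names:
            to_attr = next(a for a in node.attribute if a.name == "to")
            # The duplicate keeps the original name and stays wired to the model I/O.
            casts_to_add.append(
                helper.make_node(
                    "Cast",
                    name=node.name,
                    inputs=[node.input[0]],
                    outputs=[node.output[0]],
                    to=helper.get_attribute_value(to_attr),
                )
            )
            casts_to_skip.add(node.name)
            # The original is renamed onto a suffixed tensor; any other consumers
            # of the old output are rewired to that suffixed name.
            old_out = node.output[0]
            new_out = f"{old_out}_io_special_case"
            for consumer in graph.node:
                if consumer is not node:
                    for idx, name in enumerate(consumer.input):
                        if name == old_out:
                            consumer.input[idx] = new_out
            node.name = f"{node.name}_io_special_case"
            node.output[0] = new_out

    graph.node.extend(casts_to_add)
    return casts_to_skip  # names for the removal pass to leave untouched

The generic cast-removal pass then excludes every name in the returned set, so the output keeps its Cast producer.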

Usage

Autocast precision converter

Testing

Added a unit test that fails before the change and passes after the fix.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Summary by CodeRabbit

  • Bug Fixes

    • Preserve input-to-output Cast operations during precision conversion when keeping I/O types by duplicating and re-routing IO casts to avoid accidental removal.
    • Ensure FP16/BF16 conversion stability by skipping critical IO casts while still removing redundant casts, keeping consumer connections intact and graph outputs correct.
  • Tests

    • Added unit test and fixture validating models with casted outputs and mixed-precision nodes for FP16 and BF16, asserting model validity after conversion.

@aboubezari aboubezari requested a review from a team as a code owner September 9, 2025 00:47
@aboubezari aboubezari requested a review from ajrasane September 9, 2025 00:47

copy-pr-bot bot commented Sep 9, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Sep 9, 2025

Walkthrough

Detects Cast nodes that directly connect a model input to a model output, duplicates and re-routes them so the IO-facing Cast is preserved, excludes those duplicates from generic cast-removal, and adds a unit test fixture (model_with_casted_input_to_output) plus test_casted_input_to_output_model validating conversion for fp16/bf16 with keep_io_types=True.

Changes

Cohort / File(s) and summary:

  • Autocast precision converter logic (modelopt/onnx/autocast/precisionconverter.py): Updates _remove_preexisting_casts to collect model input/output names, detect Cast nodes whose input is a model input and whose output is a model output, duplicate those Casts (new_cast) for IO, rename the original Cast outputs and rewire consumers to a suffixed output name, record the original Cast names to skip, append the duplicates to the graph, convert the skip list to a set, and then run the existing removal pass while excluding the recorded casts.
  • Unit tests for IO-cast handling (tests/unit/onnx/autocast/test_precisionconverter.py): Adds fixture model_with_casted_input_to_output() that builds an ONNX model where one output Y1 is produced by a Cast from input X and another output Y2 is produced via Add ops; adds test_casted_input_to_output_model (parameterized over low_precision_type = "fp16", "bf16"), which runs PrecisionConverter with keep_io_types=True and asserts the converted model is valid.
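
For orientation, a minimal sketch of what such a fixture can look like (an approximation, not the verbatim test code; the names X, Y1, Y2, cast_input, add1, and add2 come from the summary above, while the tensor shapes, the constant initializer, the opset version, and the fp32-to-fp32 Cast target are assumptions):

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

def model_with_casted_input_to_output():
    # Graph input X feeds a Cast that directly produces graph output Y1,
    # while an Add chain (add1 -> add2) produces the second output Y2.
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [1, 4])
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [1, 4])
    const = numpy_helper.from_array(np.ones((1, 4), dtype=np.float32), name="const")

    cast_node = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")

    graph = helper.make_graph([cast_node, add1, add2], "casted_io", [x], [y1, y2], [const])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    onnx.checker.check_model(model)
    return model

The test then runs PrecisionConverter over this model with keep_io_types=True, marking cast_input as high precision and add1/add2 as low precision, and checks that the converted model is still valid.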

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant PC as PrecisionConverter
  participant G as ONNX Graph
  participant N as Cast Node
  participant C as Consumers

  PC->>G: enumerate nodes, collect model input/output names
  G-->>PC: node list + model IO names

  alt Cast input ∈ model_inputs and output ∈ model_outputs
    PC->>N: create duplicate Cast (new_cast) with same inputs/outputs/dtype
    PC->>N: rename original Cast output (add suffix)
    PC->>C: rewire original consumers → suffixed output
    PC->>PC: record original cast name in casts_to_skip
  else other Cast nodes
    PC->>PC: mark removable Cast candidates (fp16/fp32 rules, ignore initializers)
  end

  PC->>G: append duplicated IO Casts
  PC->>G: remove removable Casts excluding casts_to_skip
  G-->>PC: return updated graph

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 40.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title Check (✅ Passed): The title succinctly and accurately describes the core bug fix regarding input-to-output casting in the Autocast precision converter, making it clear to reviewers what the change addresses without unnecessary detail.
  • Description Check (✅ Passed): The description directly relates to the changeset by outlining the bug fix scope, describing the edge case and the approach taken, and noting the addition of a unit test that verifies the fix, making it informative and on-topic.

Poem

I sniffed the graph where inputs meet the sky,
A Cast from X to Y1 gave me a hop and a wink.
I cloned and nudged the wires just right,
Kept IO types safe through day and night.
Hopping home with tests all green, I blink. 🐇✨

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 01308d6 and 9363b09.

📒 Files selected for processing (1)
  • tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/onnx/autocast/test_precisionconverter.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

624-634: Preexisting-cast removal misses BF16 targets

Comment says FP16/BF16/FP32 casts are removed, but is_fp_cast only matches to ∈ {FLOAT16, FLOAT}. Include BF16 as a target to honor the contract.

-                is_fp_cast = cast_to_type in [
-                    onnx.TensorProto.FLOAT16,
-                    onnx.TensorProto.FLOAT,
-                ] and cast_from_type in [
+                is_fp_cast = cast_to_type in [
+                    onnx.TensorProto.FLOAT16,
+                    onnx.TensorProto.FLOAT,
+                    onnx.TensorProto.BFLOAT16,
+                ] and cast_from_type in [
                     onnx.TensorProto.FLOAT16,
                     onnx.TensorProto.FLOAT,
                     onnx.TensorProto.BFLOAT16,
                 ]

641-644: Guard for output-producing casts is ineffective

This condition checks if BOTH the cast input and output are network outputs, which never happens. It should keep casts that produce a network output.

-                    # Keep cast nodes that are necessary producers of network outputs
-                    if any(node.input[0] == out.name for out in self.model.graph.output) and any(
-                        node.output[0] == out.name for out in self.model.graph.output
-                    ):
+                    # Keep casts that produce a network output
+                    if node.output[0] in model_output_names:
                         continue
🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (1)

618-621: Insert duplicate IO-bridge casts deterministically (top of graph)

Appending at the tail can shuffle topo order. Inserting at index 0 is more stable for input-driven casts.

-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        for cast in casts_to_add:
+            self.model.graph.node.insert(0, cast)
tests/unit/onnx/autocast/test_precisionconverter.py (2)

1068-1071: Don’t write artifacts to /tmp in unit tests

The saved model isn’t used. Remove to keep tests hermetic.

-    onnx.save(model, "/tmp/model_with_casted_output.onnx")

1076-1091: Strengthen assertions: verify Y1 connectivity and dtype

Only checking onnx.checker is weak. Assert that Y1 remains produced by a Cast and retains FP32 (keep_io_types=True).

     converted_model = converter.convert(
         high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"]
     )
-    onnx.checker.check_model(converted_model)
+    onnx.checker.check_model(converted_model)
+    # Y1 should remain connected and produced by a Cast
+    y1_producers = utils.get_producer_nodes(converted_model, "Y1")
+    assert len(y1_producers) == 1
+    assert y1_producers[0].op_type == "Cast"
+    # keep_io_types=True -> FP32 I/O preserved
+    y1_vi = next(y for y in converted_model.graph.output if y.name == "Y1")
+    assert y1_vi.type.tensor_type.elem_type == TensorProto.FLOAT
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 512dbb7 and 16d5875.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/unit/onnx/autocast/test_precisionconverter.py (2)
modelopt/onnx/utils.py (1)
  • check_model (557-569)
modelopt/onnx/autocast/precisionconverter.py (1)
  • convert (113-202)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

607-617: Confirm intent: skipping only the duplicated “new_cast”, not the renamed original

casts_to_skip holds the original name (now assigned to new_cast) and will not skip the renamed original (..._io_special_case). If that original gets removed, _bypass_cast_node will reconnect its consumers directly to the model input. Is that intended? If not, add the renamed name to the skip list as well right after renaming.

-                    casts_to_skip.append(node.name)
+                    casts_to_skip.append(node.name)
                     casts_to_add.append(new_cast)
                     # Now adjust the old cast's name, consumers and producers
                     node.name = f"{node.name}_io_special_case"
+                    casts_to_skip.append(node.name)  # keep the internal IO-special-case cast as well

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: aboubezari <[email protected]>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
modelopt/onnx/autocast/precisionconverter.py (3)

600-606: Also preserve optional Cast attributes (e.g., saturate) to avoid semantic drift

If the model uses newer opsets where Cast may carry optional attributes (like saturate), the duplicate should copy them.

Apply:

-                    new_cast = helper.make_node(
-                        "Cast",
-                        name=node.name,
-                        inputs=[node.input[0]],
-                        outputs=[node.output[0]],
-                        to=utils.get_cast_to_type(node),
-                    )
+                    # Copy optional attributes (e.g., 'saturate' in newer opsets)
+                    saturate = next((a.i for a in node.attribute if a.name == "saturate"), None)
+                    cast_attrs = {"to": utils.get_cast_to_type(node)}
+                    if saturate is not None:
+                        cast_attrs["saturate"] = saturate
+                    new_cast = helper.make_node(
+                        "Cast",
+                        name=node.name,
+                        inputs=[node.input[0]],
+                        outputs=[node.output[0]],
+                        **cast_attrs,
+                    )

618-621: Insert duplicate Cast adjacent to the original for better locality and readability

Appending at the end works but scatters IO nodes. Insert near the renamed source node to keep topology readable.

-        for cast in casts_to_add:
-            self.model.graph.node.append(cast)
+        # Preserve locality: insert duplicates next to their originals
+        for cast in casts_to_add:
+            target_idx = -1
+            for i, n in enumerate(self.model.graph.node):
+                if n.name == f"{cast.name}_io_special_case":
+                    target_idx = i
+                    break
+            if target_idx >= 0:
+                self.model.graph.node.insert(target_idx, cast)
+            else:
+                # Fallback to prepend to avoid end-append reordering
+                self.model.graph.node.insert(0, cast)

592-596: Use a set for casts_to_skip from the start

Minor nit for clarity and O(1) membership checks.

-        casts_to_skip = []
+        casts_to_skip: set[str] = set()
         # Add casts as a separate step to avoid modifying the graph while iterating over it
         casts_to_add = []
@@
-                    casts_to_skip.append(node.name)
+                    casts_to_skip.add(node.name)
@@
-        casts_to_skip = set(casts_to_skip)
+        # already a set

Also applies to: 620-621

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 16d5875 and 01308d6.

📒 Files selected for processing (1)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

588-617: Solid IO-cast preservation strategy

Duplicating the IO-facing Cast, renaming the original, and rewiring consumers avoids disconnecting outputs while still enabling generic cast cleanup. This addresses the edge case cleanly.

Signed-off-by: Ali Boubezari <[email protected]>
value_info_map,
initializer_map,
node_to_init_map,
keep_io_types=True,
Contributor


This test fails when keep_io_types=False because a graph input then becomes a graph output directly, which violates the assertion in ModelOpt that all original input and output names must be maintained in the quantized model.
