Conversation

gcunhase
Contributor

@gcunhase gcunhase commented Sep 9, 2025

What does this PR do?

Type of change: Bug fix

Overview: Fixed an issue in the function that bypasses 'Cast' nodes when the cast is connected both to a consumer with multiple outputs and to the model's output.

Usage

Autocast precision converter

Testing

Added a unit test.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: N/A
  • Did you update Changelog?: No

Additional Information

Related: #305

Summary by CodeRabbit

  • Bug Fixes

    • Preserves correct tensor connections when bypassing redundant casts, preventing miswired outputs and conversion errors.
  • Refactor

    • Centralized downstream input-rewiring logic for clearer, more maintainable cast-bypass behavior without changing public APIs.
  • Tests

    • Added tests for models with multiple consumers and casted outputs to catch regressions (note: test additions were duplicated).

@gcunhase gcunhase self-assigned this Sep 9, 2025
@gcunhase gcunhase requested a review from a team as a code owner September 9, 2025 17:05
@gcunhase gcunhase requested a review from galagam September 9, 2025 17:05

coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds a private helper to rewrite consumer input tensor names and adjusts PrecisionConverter's cast-bypass logic to retarget multiple consumers when a producer output is also a graph output. Adds a fixture and a parameterized test (duplicated insertion) covering a Cast consumed by multiple outputs. No public API changes.

Changes

Cohort / File(s) Summary
Cast-bypass consumer retargeting
modelopt/onnx/autocast/precisionconverter.py
Added _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name) and updated _bypass_cast_node to detect when a cast's producer output is also a graph output and to rewrite consumer inputs from the original producer output to the cast output for multi-consumer cases. Preserves existing single-consumer reconnect behavior. No public signatures changed.
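The rewiring step described above can be sketched in isolation. This is a hypothetical stand-alone version using a minimal stand-in for onnx.NodeProto; the real method is an instance method on PrecisionConverter and the node names below are illustrative only:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    # Minimal stand-in for onnx.NodeProto; only the fields the sketch needs.
    name: str
    input: list = field(default_factory=list)
    output: list = field(default_factory=list)

def replace_tensor_name(consumers, original_tensor_name, new_tensor_name):
    """Point every consumer input that references the old tensor at the new one."""
    for consumer in consumers:
        for idx, inp in enumerate(consumer.input):
            if inp == original_tensor_name:
                consumer.input[idx] = new_tensor_name

# Two downstream nodes both read the producer output "concat_1_out";
# after the cast bypass they must read the cast output "Y1" instead.
concat_2 = Node("concat_2", input=["concat_1_out", "W"], output=["Y2"])
concat_3 = Node("concat_3", input=["concat_1_out"], output=["Y3"])
replace_tensor_name([concat_2, concat_3], "concat_1_out", "Y1")
print(concat_2.input, concat_3.input)  # ['Y1', 'W'] ['Y1']
```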
New tests: multiple-consumer cast scenario (duplicated)
tests/unit/onnx/autocast/test_precisionconverter.py
Added fixture model_with_multiple_output_node_casted_to_output() and test test_multiple_output_node_casted_to_output(...) (parameterized over ["fp16","bf16"]) that build and validate an ONNX model where a Cast output is consumed by multiple Concat outputs. The fixture and test were inserted twice (duplicate entries present).

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant PC as PrecisionConverter
  participant P as Producer
  participant C as Cast
  participant Cs as Consumers (1..n)

  note over PC: Identify cast-bypass scenario
  P->>C: produces prod_out
  C->>Cs: cast consumes prod_out -> cast_out

  PC->>PC: fetch consumers of prod_out
  alt prod_out is also a graph output and multiple consumers
    note right of PC: Rewrite each consumer input prod_out -> cast_out
    PC->>Cs: _replace_tensor_name(consumers, prod_out, cast_out)
  else single/zero consumer or not a graph output
    note right of PC: Reconnect consumers of cast_out to cast input (existing behavior)
  end

  Cs->>Cs: downstream nodes now reference updated tensor names
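The alt branch in the diagram condenses to one decision. A minimal sketch, assuming nodes are plain dicts with an "input" list (not the real PrecisionConverter code; the function name and return values are made up for illustration):

```python
def retarget_for_bypass(graph_outputs, prod_out, cast_out, consumers):
    """Decide how to rewire when bypassing a Cast, per the diagram above.

    Returns "rewrite" when consumer inputs were renamed prod_out -> cast_out,
    or "reconnect" to signal that the existing single-consumer path should run.
    """
    if prod_out in graph_outputs and len(consumers) > 1:
        # prod_out is also a graph output with multiple consumers:
        # rewrite each consumer input prod_out -> cast_out.
        for consumer in consumers:
            consumer["input"] = [
                cast_out if t == prod_out else t for t in consumer["input"]
            ]
        return "rewrite"
    # Otherwise fall back to the existing behavior: reconnect
    # consumers of cast_out to the cast input.
    return "reconnect"

# prod_out is also a graph output and has two consumers -> rewrite them.
a = {"name": "concat_2", "input": ["prod_out", "W"]}
b = {"name": "concat_3", "input": ["prod_out"]}
decision = retarget_for_bypass({"prod_out", "Z"}, "prod_out", "cast_out", [a, b])
print(decision, a["input"])  # rewrite ['cast_out', 'W']
```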

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

I nibble at a Cast, then hop—
I swap the names and clear the stop.
Where many mouths once looked the same,
I point them true and mend the name.
A rabbit's tweak — the graph sings on. 🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f8fde5 and 53d4a10.

📒 Files selected for processing (1)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/autocast/precisionconverter.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 33.33%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title clearly summarizes the primary change: it specifies that the fix targets the cast-bypass logic when a consumer has multiple outputs and the cast node also feeds the model output, which directly matches the pull request's core objective. It avoids vague terms and file lists, and the bug-number prefix does not detract from its clarity.


@gcunhase gcunhase marked this pull request as draft September 9, 2025 17:05

copy-pr-bot bot commented Sep 9, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

560-565: Nice extraction; add types and return count (minor).

Type hints + a small return value make this helper safer and easier to reuse/debug.

-    def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name):
-        for consumer in consumers:
-            for idx, inp in enumerate(consumer.input):
-                if inp == original_tensor_name:
-                    consumer.input[idx] = new_tensor_name
+    def _replace_tensor_name(
+        self,
+        consumers: list[onnx.NodeProto],
+        original_tensor_name: str,
+        new_tensor_name: str,
+    ) -> int:
+        """Replace occurrences of original_tensor_name in the given consumers' inputs with new_tensor_name.
+        Returns the number of replacements performed."""
+        replaced = 0
+        for consumer in consumers:
+            for idx, inp in enumerate(consumer.input):
+                if inp == original_tensor_name:
+                    consumer.input[idx] = new_tensor_name
+                    replaced += 1
+        return replaced

585-587: Retarget only non-cast consumers to avoid transient self-referencing.

Exclude the cast node itself when rewriting; prevents a temporary state where the cast’s input equals the network-output name (while it still exists).

Also consider the edge case when the cast’s input is a graph input (no producer nodes): today the branch does nothing, which can leave the graph output disconnected after removal. Either skip removal in that case or rewrite the graph output to the input tensor (or insert an Identity).

-                            consumers = utils.get_consumer_nodes(self.model, prod_out)
-                            if len(consumers) > 1:
-                                self._replace_tensor_name(consumers, prod_out, output_tensor)
+                            consumers = utils.get_consumer_nodes(self.model, prod_out)
+                            other_consumers = [c for c in consumers if c != node]
+                            if other_consumers:
+                                self._replace_tensor_name(other_consumers, prod_out, output_tensor)

To verify the no-producer case quickly, build a minimal ONNX with: graph input -> Cast -> graph output, plus a second consumer of the cast input; run through _remove_preexisting_casts and confirm: (a) graph outputs remain connected; (b) the second consumer’s input was retargeted.
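The edge case described above can also be modeled without ONNX at all. A library-free sketch with dict-based nodes and a hypothetical bypass function (this is not the real _remove_preexisting_casts; it only illustrates the suggested guard for a Cast whose input is a graph input):

```python
def bypass_cast(graph_inputs, graph_outputs, nodes, cast):
    """Sketch: remove a no-op Cast whose input may be a graph input (no producer).

    If the cast reads a graph input directly, there is no producer output to
    rename, so the graph output is rewired to the input tensor instead of
    dropping the node and leaving the output dangling.
    """
    cast_in, cast_out = cast["input"][0], cast["output"][0]
    if cast_in in graph_inputs:
        # No producer: point the graph output at the input tensor directly.
        graph_outputs[:] = [cast_in if o == cast_out else o for o in graph_outputs]
    else:
        # Normal path: reconnect downstream consumers to the cast input.
        for n in nodes:
            n["input"] = [cast_in if t == cast_out else t for t in n["input"]]
    nodes.remove(cast)
    return graph_outputs

# graph input -> Cast -> graph output, plus a second consumer of the cast input
cast = {"name": "cast_0", "input": ["X"], "output": ["Y"]}
relu = {"name": "relu_0", "input": ["X"], "output": ["Z"]}
nodes = [cast, relu]
outs = bypass_cast({"X"}, ["Y", "Z"], nodes, cast)
print(outs, [n["name"] for n in nodes])  # ['X', 'Z'] ['relu_0']
```

The assertions mirror the comment's two checks: the graph outputs remain connected, and the second consumer of the cast input is untouched.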

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a6fa34c and 434baf6.

📒 Files selected for processing (1)
  • modelopt/onnx/autocast/precisionconverter.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality


codecov bot commented Sep 9, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.88%. Comparing base (4716131) to head (53d4a10).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #309      +/-   ##
==========================================
+ Coverage   73.87%   73.88%   +0.01%     
==========================================
  Files         172      172              
  Lines       17439    17444       +5     
==========================================
+ Hits        12883    12889       +6     
+ Misses       4556     4555       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gcunhase gcunhase marked this pull request as ready for review September 9, 2025 18:06
@gcunhase gcunhase requested review from i-riyad and ajrasane September 9, 2025 18:07

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)

1084-1104: Strengthen regression: assert rewiring and removal of the bypassed Cast.

Right now the test only checks model validity. Add assertions to prove the fix: (a) cast_0 removed, (b) concat_1 now produces Y1, (c) concat_2 consumes Y1, (d) concat_1_out is no longer referenced.

-@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
-def test_multiple_output_node_casted_to_output(
-    model_with_multiple_output_node_casted_to_output, low_precision_type
-):
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_multiple_output_node_casted_to_output(
+    model_with_multiple_output_node_casted_to_output, low_precision_type
+):
@@
-    converted_model = converter.convert(
+    converted_model = converter.convert(
         high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
     )
+    # Assert fix behavior
+    assert not any(n.name == "cast_0" for n in converted_model.graph.node)
+    concat1 = next(n for n in converted_model.graph.node if n.name == "concat_1")
+    assert concat1.output[0] == "Y1"
+    concat2 = next(n for n in converted_model.graph.node if n.name == "concat_2")
+    assert concat2.input[0] == "Y1"
+    # No dangling references to the original tensor name
+    assert len(utils.get_consumer_nodes(converted_model, "concat_1_out")) == 0
+    assert all("concat_1_out" not in n.output for n in converted_model.graph.node)
     onnx.checker.check_model(converted_model)
🧹 Nitpick comments (3)
tests/unit/onnx/autocast/test_precisionconverter.py (3)

1029-1031: Clarify fixture docstring (producer vs. consumer).

It’s the producer output with multiple consumers (concat_1_out -> cast_0 and concat_2), not “a consumer with multiple outputs.”

-def model_with_multiple_output_node_casted_to_output():
-    """Create a model with a Cast node connecting a consumer with multiple outputs to a graph output."""
+def model_with_multiple_output_node_casted_to_output():
+    """Model where a producer output (concat_1_out) has multiple consumers (cast_0 -> Y1 and concat_2), and cast_0 feeds graph output Y1."""

1055-1061: Make the intent explicit: Cast is a no-op on purpose.

A short comment helps future readers understand why to=TensorProto.FLOAT is used.

     cast_node = helper.make_node(
         "Cast",
         ["concat_1_out"],
         ["Y1"],
         name="cast_0",
-        to=TensorProto.FLOAT,
+        to=TensorProto.FLOAT,  # no-op cast to exercise bypass logic when cast feeds a graph output
     )

1084-1099: Optional: cover both keep_io_types modes.

Parametrize keep_io_types to ensure bypass behavior holds regardless of I/O type policy.

-@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
-def test_multiple_output_node_casted_to_output(
-    model_with_multiple_output_node_casted_to_output, low_precision_type
-):
+@pytest.mark.parametrize("keep_io_types", [True, False])
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_multiple_output_node_casted_to_output(
+    model_with_multiple_output_node_casted_to_output, keep_io_types, low_precision_type
+):
@@
-    converter = PrecisionConverter(
+    converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
-        keep_io_types=True,
+        keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 434baf6 and 8c0286c.

📒 Files selected for processing (1)
  • tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
🔇 Additional comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)

1028-1104: No duplicate fixture or test definitions detected. The verification script confirms each function appears exactly once in test_precisionconverter.py.

@gcunhase gcunhase force-pushed the dev/gcunhasergio/fix_cast_output_5504719 branch from 8c0286c to b921811 Compare September 9, 2025 18:20

@galagam galagam left a comment


@gcunhase This looks good, but might conflict with #305 .
Should be careful with the merge.
Also, if you have the time, consider a refactor to simplify the _bypass_cast_node function.

@gcunhase
Copy link
Contributor Author

> @gcunhase This looks good, but might conflict with #305. Should be careful with the merge. Also, if you have the time, consider a refactor to simplify the _bypass_cast_node function.

Thank you for your reply!

I've checked this PR against the other one and there don't seem to be any issues, as both unit tests pass.

There's one issue with #305 when keep_io_types=False is selected, but since it's specific to that use case, I'll address it in that PR.

I've also simplified the _bypass_cast_node function.

@gcunhase gcunhase force-pushed the dev/gcunhasergio/fix_cast_output_5504719 branch from 36f4ca5 to 1f8fde5 Compare September 10, 2025 19:39
@gcunhase gcunhase enabled auto-merge (squash) September 10, 2025 19:40

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

569-574: Add type hints + return count; guard no-op renames

Minor polish that helps readability and future assertions. Also avoids work if names are equal.

-def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name):
-    for consumer in consumers:
-        for idx, inp in enumerate(consumer.input):
-            if inp == original_tensor_name:
-                consumer.input[idx] = new_tensor_name
+def _replace_tensor_name(
+    self,
+    consumers: list[onnx.NodeProto],
+    original_tensor_name: str,
+    new_tensor_name: str,
+) -> int:
+    """Replace occurrences of original_tensor_name in consumer inputs with new_tensor_name.
+    Returns the number of replacements performed."""
+    if original_tensor_name == new_tensor_name:
+        return 0
+    replaced = 0
+    for consumer in consumers:
+        for idx, inp in enumerate(consumer.input):
+            if inp == original_tensor_name:
+                consumer.input[idx] = new_tensor_name
+                replaced += 1
+    return replaced
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 36f4ca5 and 1f8fde5.

📒 Files selected for processing (2)
  • modelopt/onnx/autocast/precisionconverter.py (1 hunks)
  • tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/onnx/autocast/test_precisionconverter.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait

@gcunhase gcunhase disabled auto-merge September 10, 2025 20:05
@gcunhase gcunhase enabled auto-merge (squash) September 10, 2025 20:12
@gcunhase gcunhase merged commit 4ea72e3 into NVIDIA:main Sep 10, 2025
22 checks passed
benchislett pushed a commit that referenced this pull request Sep 15, 2025