[5504719] Fix bypassing of 'Cast' connecting a consumer with multiple outputs a… #309
Conversation
Walkthrough

Adds a private helper to rewrite consumer input tensor names and adjusts PrecisionConverter's cast-bypass logic to retarget multiple consumers when a producer output is also a graph output. Adds a fixture and a parameterized test (duplicated insertion) covering a Cast consumed by multiple outputs. No public API changes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant PC as PrecisionConverter
    participant P as Producer
    participant C as Cast
    participant Cs as Consumers (1..n)
    note over PC: Identify cast-bypass scenario
    P->>C: produces prod_out
    C->>Cs: cast consumes prod_out -> cast_out
    PC->>PC: fetch consumers of prod_out
    alt prod_out is also a graph output and multiple consumers
        note right of PC: Rewrite each consumer input prod_out -> cast_out
        PC->>Cs: _replace_tensor_name(consumers, prod_out, cast_out)
    else single/zero consumer or not a graph output
        note right of PC: Reconnect consumers of cast_out to cast input (existing behavior)
    end
    Cs->>Cs: downstream nodes now reference updated tensor names
```
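The branch in the diagram can be sketched with a toy graph representation (plain dicts stand in for ONNX protos; `retarget_consumers`, `prod_out`, and `cast_out` are illustrative names, not the project's actual API):

```python
# Toy sketch of the cast-bypass retarget decision (an assumed simplification,
# not the real PrecisionConverter implementation).

def retarget_consumers(nodes, graph_outputs, prod_out, cast_out, cast_node):
    """If prod_out is also a graph output and has other consumers,
    rewrite those consumers to read cast_out instead."""
    consumers = [n for n in nodes if prod_out in n["inputs"] and n is not cast_node]
    if prod_out in graph_outputs and consumers:
        for c in consumers:
            c["inputs"] = [cast_out if i == prod_out else i for i in c["inputs"]]
        return len(consumers)
    return 0

# Minimal graph: producer output prod_out feeds both the cast and a second consumer.
cast = {"name": "cast_0", "inputs": ["prod_out"], "outputs": ["Y1"]}
other = {"name": "concat_2", "inputs": ["prod_out"], "outputs": ["Z"]}
nodes = [cast, other]

n = retarget_consumers(nodes, graph_outputs={"prod_out"}, prod_out="prod_out",
                       cast_out="Y1", cast_node=cast)
print(n, other["inputs"])  # 1 ['Y1']
```

The cast node itself is excluded from the rewrite, so its own input stays untouched.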
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
560-565: Nice extraction; add types and return count (minor).

Type hints + a small return value make this helper safer and easier to reuse/debug.
```diff
-    def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name):
-        for consumer in consumers:
-            for idx, inp in enumerate(consumer.input):
-                if inp == original_tensor_name:
-                    consumer.input[idx] = new_tensor_name
+    def _replace_tensor_name(
+        self,
+        consumers: list[onnx.NodeProto],
+        original_tensor_name: str,
+        new_tensor_name: str,
+    ) -> int:
+        """Replace occurrences of original_tensor_name in the given consumers' inputs with new_tensor_name.
+
+        Returns the number of replacements performed.
+        """
+        replaced = 0
+        for consumer in consumers:
+            for idx, inp in enumerate(consumer.input):
+                if inp == original_tensor_name:
+                    consumer.input[idx] = new_tensor_name
+                    replaced += 1
+        return replaced
```
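The suggested helper's behavior can be exercised with stand-in node objects (a hypothetical sketch; real consumers would be `onnx.NodeProto` instances):

```python
# Stand-in for onnx.NodeProto; only the .input list matters for this helper.
class FakeNode:
    def __init__(self, inputs):
        self.input = list(inputs)

def replace_tensor_name(consumers, original_tensor_name, new_tensor_name):
    """Free-function version of the suggested helper: rewrite matching
    inputs in place and return how many replacements were made."""
    replaced = 0
    for consumer in consumers:
        for idx, inp in enumerate(consumer.input):
            if inp == original_tensor_name:
                consumer.input[idx] = new_tensor_name
                replaced += 1
    return replaced

a = FakeNode(["prod_out", "w"])
b = FakeNode(["prod_out", "prod_out"])
count = replace_tensor_name([a, b], "prod_out", "cast_out")
print(count, a.input, b.input)  # 3 ['cast_out', 'w'] ['cast_out', 'cast_out']
```

The return count is what makes assertions like "exactly N inputs were rewired" cheap to write in tests.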
585-587: Retarget only non-cast consumers to avoid transient self-referencing.

Exclude the cast node itself when rewriting; this prevents a temporary state where the cast's input equals the network-output name (while the cast still exists).
Also consider the edge case when the cast’s input is a graph input (no producer nodes): today the branch does nothing, which can leave the graph output disconnected after removal. Either skip removal in that case or rewrite the graph output to the input tensor (or insert an Identity).
```diff
-            consumers = utils.get_consumer_nodes(self.model, prod_out)
-            if len(consumers) > 1:
-                self._replace_tensor_name(consumers, prod_out, output_tensor)
+            consumers = utils.get_consumer_nodes(self.model, prod_out)
+            other_consumers = [c for c in consumers if c != node]
+            if other_consumers:
+                self._replace_tensor_name(other_consumers, prod_out, output_tensor)
```

To verify the no-producer case quickly, build a minimal ONNX model with: graph input -> Cast -> graph output, plus a second consumer of the cast input; run it through `_remove_preexisting_casts` and confirm that (a) graph outputs remain connected and (b) the second consumer's input was retargeted.
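The graph-input edge case flagged above (the cast's input has no producer node) can be simulated with the same kind of toy graph. Here the fallback rewrites the graph output to the cast's input rather than leaving it dangling — one of the options suggested, chosen for illustration only:

```python
def bypass_cast(graph, cast):
    """Remove a no-op cast. If its input has no producer (i.e. it is a
    graph input), rewrite the graph output to that input instead of
    leaving the output disconnected."""
    src, dst = cast["inputs"][0], cast["outputs"][0]
    has_producer = any(src in n["outputs"] for n in graph["nodes"] if n is not cast)
    if not has_producer and dst in graph["outputs"]:
        # No producer to reconnect: point the graph output at the cast input.
        graph["outputs"] = [src if o == dst else o for o in graph["outputs"]]
    else:
        # Existing behavior: reconnect consumers of the cast output to its input.
        for n in graph["nodes"]:
            n["inputs"] = [src if i == dst else i for i in n["inputs"]]
    graph["nodes"].remove(cast)

# Minimal case: graph input X -> Cast -> graph output Y, no producer for X.
cast = {"name": "cast_0", "inputs": ["X"], "outputs": ["Y"]}
graph = {"nodes": [cast], "inputs": ["X"], "outputs": ["Y"]}
bypass_cast(graph, cast)
print(graph["outputs"], graph["nodes"])  # ['X'] []
```

After removal the graph output stays connected (now named `X`) instead of referencing the deleted tensor `Y`.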
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/onnx/autocast/precisionconverter.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main     #309    +/-  ##
=======================================
+ Coverage   73.87%   73.88%   +0.01%
=======================================
  Files         172      172
  Lines       17439    17444       +5
=======================================
+ Hits        12883    12889       +6
+ Misses       4556     4555       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
1084-1104: Strengthen regression: assert rewiring and removal of the bypassed Cast.

Right now the test only checks model validity. Add assertions to prove the fix: (a) cast_0 removed, (b) concat_1 now produces Y1, (c) concat_2 consumes Y1, (d) concat_1_out is no longer referenced.
```diff
-@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
-def test_multiple_output_node_casted_to_output(
-    model_with_multiple_output_node_casted_to_output, low_precision_type
-):
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_multiple_output_node_casted_to_output(
+    model_with_multiple_output_node_casted_to_output, low_precision_type
+):
@@
-    converted_model = converter.convert(
+    converted_model = converter.convert(
         high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
     )
+    # Assert fix behavior
+    assert not any(n.name == "cast_0" for n in converted_model.graph.node)
+    concat1 = next(n for n in converted_model.graph.node if n.name == "concat_1")
+    assert concat1.output[0] == "Y1"
+    concat2 = next(n for n in converted_model.graph.node if n.name == "concat_2")
+    assert concat2.input[0] == "Y1"
+    # No dangling references to the original tensor name
+    assert len(utils.get_consumer_nodes(converted_model, "concat_1_out")) == 0
+    assert all("concat_1_out" not in n.output for n in converted_model.graph.node)
     onnx.checker.check_model(converted_model)
```
🧹 Nitpick comments (3)
tests/unit/onnx/autocast/test_precisionconverter.py (3)
1029-1031: Clarify fixture docstring (producer vs. consumer).

It's the producer output with multiple consumers (concat_1_out -> cast_0 and concat_2), not "a consumer with multiple outputs."

```diff
-def model_with_multiple_output_node_casted_to_output():
-    """Create a model with a Cast node connecting a consumer with multiple outputs to a graph output."""
+def model_with_multiple_output_node_casted_to_output():
+    """Model where a producer output (concat_1_out) has multiple consumers (cast_0 -> Y1 and concat_2), and cast_0 feeds graph output Y1."""
```
1055-1061: Make the intent explicit: Cast is a no-op on purpose.

A short comment helps future readers understand why `to=TensorProto.FLOAT` is used.

```diff
     cast_node = helper.make_node(
         "Cast",
         ["concat_1_out"],
         ["Y1"],
         name="cast_0",
-        to=TensorProto.FLOAT,
+        to=TensorProto.FLOAT,  # no-op cast to exercise bypass logic when cast feeds a graph output
     )
```
1084-1099: Optional: cover both keep_io_types modes.

Parametrize keep_io_types to ensure bypass behavior holds regardless of I/O type policy.

```diff
-@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
-def test_multiple_output_node_casted_to_output(
-    model_with_multiple_output_node_casted_to_output, low_precision_type
-):
+@pytest.mark.parametrize("keep_io_types", [True, False])
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_multiple_output_node_casted_to_output(
+    model_with_multiple_output_node_casted_to_output, keep_io_types, low_precision_type
+):
@@
-    converter = PrecisionConverter(
+    converter = PrecisionConverter(
         model,
         value_info_map,
         initializer_map,
         node_to_init_map,
-        keep_io_types=True,
+        keep_io_types=keep_io_types,
         low_precision_type=low_precision_type,
     )
```
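For reference, stacked `pytest.mark.parametrize` decorators run the test once per element of the cartesian product of their argument lists, so the suggested change yields four cases:

```python
from itertools import product

keep_io_types = [True, False]
low_precision_type = ["fp16", "bf16"]

# pytest invokes the test once per (keep_io_types, low_precision_type) pair.
cases = list(product(keep_io_types, low_precision_type))
print(len(cases))  # 4
print(cases[0])    # (True, 'fp16')
```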
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
🔇 Additional comments (1)
tests/unit/onnx/autocast/test_precisionconverter.py (1)
1028-1104: No duplicate fixture or test definitions detected. The verification script confirms each function appears exactly once in test_precisionconverter.py.
Force-pushed from 8c0286c to b921811 (Compare)
Thank you for your reply! I've checked this PR against the other one and there doesn't seem to be any issue, as both unit tests pass. There's one issue with #305 when I've also simplified the …
…nd the model's output Signed-off-by: gcunhase <[email protected]>
Signed-off-by: gcunhase <[email protected]>
Signed-off-by: gcunhase <[email protected]>
Force-pushed from 36f4ca5 to 1f8fde5 (Compare)
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
569-574: Add type hints + return count; guard no-op renames.

Minor polish that helps readability and future assertions. Also avoids work if the names are equal.
```diff
-def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name):
-    for consumer in consumers:
-        for idx, inp in enumerate(consumer.input):
-            if inp == original_tensor_name:
-                consumer.input[idx] = new_tensor_name
+def _replace_tensor_name(
+    self,
+    consumers: list[onnx.NodeProto],
+    original_tensor_name: str,
+    new_tensor_name: str,
+) -> int:
+    """Replace occurrences of original_tensor_name in consumer inputs with new_tensor_name.
+
+    Returns the number of replacements performed.
+    """
+    if original_tensor_name == new_tensor_name:
+        return 0
+    replaced = 0
+    for consumer in consumers:
+        for idx, inp in enumerate(consumer.input):
+            if inp == original_tensor_name:
+                consumer.input[idx] = new_tensor_name
+                replaced += 1
+    return replaced
```
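The added no-op guard can be checked with the same kind of stand-in nodes (a hypothetical sketch, not the project's actual classes):

```python
# Minimal stand-in for onnx.NodeProto; only .input is needed here.
class FakeNode:
    def __init__(self, inputs):
        self.input = list(inputs)

def replace_tensor_name(consumers, original_tensor_name, new_tensor_name):
    """Suggested helper with the no-op guard: equal names return 0
    without touching any consumer."""
    if original_tensor_name == new_tensor_name:
        return 0
    replaced = 0
    for consumer in consumers:
        for idx, inp in enumerate(consumer.input):
            if inp == original_tensor_name:
                consumer.input[idx] = new_tensor_name
                replaced += 1
    return replaced

n = FakeNode(["t"])
print(replace_tensor_name([n], "t", "t"), n.input)  # 0 ['t']
print(replace_tensor_name([n], "t", "u"), n.input)  # 1 ['u']
```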
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/onnx/autocast/precisionconverter.py (1 hunks)
tests/unit/onnx/autocast/test_precisionconverter.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/onnx/autocast/test_precisionconverter.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
Signed-off-by: gcunhase <[email protected]>
… outputs a… (#309) Signed-off-by: gcunhase <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview: Fixed an issue in the function that bypasses 'Cast' nodes when they are connected to a consumer with multiple outputs and to the model's output.
Usage
Autocast precision converter
Testing
Added unittest.
Before your PR is "Ready for review"
Additional Information
Related: #305
Summary by CodeRabbit
Bug Fixes
Refactor
Tests