From 42c122d1313f63b9cb1956b67f3bb4b1e792f729 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 9 Sep 2025 12:42:18 -0400 Subject: [PATCH 1/4] Fix bypassing of 'Cast' connecting a consumer with multiple outputs and the model's output Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 4d6917fd2..14a6b3726 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -566,6 +566,12 @@ def convert_initializer( to_type=self.high_precision_type, ) + def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name): + for consumer in consumers: + for idx, inp in enumerate(consumer.input): + if inp == original_tensor_name: + consumer.input[idx] = new_tensor_name + def _bypass_cast_node(self, node: onnx.NodeProto) -> None: # handling only a single input and output, as we only remove cast nodes assert len(node.input) == 1 @@ -585,6 +591,9 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None: for i, prod_out in enumerate(producer.output): if prod_out == input_tensor: producer.output[i] = output_tensor + consumers = utils.get_consumer_nodes(self.model, prod_out) + if len(consumers) > 1: + self._replace_tensor_name(consumers, prod_out, output_tensor) if ( not is_output_producer ): # Reconnect consumers of the cast output to use the cast input instead From 6323d220e7f2baa3e4635456dc383d2a724a7527 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:03:28 -0400 Subject: [PATCH 2/4] Added unittest Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- .../onnx/autocast/test_precisionconverter.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index 92614f915..a85bed287 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_ assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add" assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1 assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add" + + +@pytest.fixture +def model_with_multiple_output_node_casted_to_output(): + """Create a model with a Cast node connecting a consumer with multiple outputs to a graph output.""" + # Create inputs and outputs + x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 2, 16, 16]) + x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 16, 16]) + x3 = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 4, 16, 16]) + y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [1, 5, 16, 16]) + y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [1, 9, 16, 16]) + + # Create computation nodes + concat_1_node = helper.make_node( + "Concat", + ["X1", "X2"], + ["concat_1_out"], + name="concat_1", + axis=1, + ) + concat_2_node = helper.make_node( + "Concat", + ["concat_1_out", "X3"], + ["Y2"], + name="concat_2", + axis=1, + ) + + # Create a Cast node between 'concat_1' and the graph output + cast_node = helper.make_node( + "Cast", + ["concat_1_out"], + ["Y1"], + name="cast_0", + to=TensorProto.FLOAT, + ) + + graph = helper.make_graph( + [concat_1_node, concat_2_node, cast_node], + "model_with_multiple_output_node_casted_to_output", + [x1, x2, x3], + [y1, y2], + [], + ) + + model = helper.make_model( + graph, producer_name="model_with_multiple_output_node_casted_to_output" + ) + model.opset_import[0].version = 20 + model.ir_version = 10 + onnx.checker.check_model(model) + + model = onnx_utils.infer_shapes(model) + value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + + return model, value_info_map, initializer_map, node_to_init_map + + +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +def test_multiple_output_node_casted_to_output( + model_with_multiple_output_node_casted_to_output, low_precision_type +): + model, value_info_map, initializer_map, node_to_init_map = ( + model_with_multiple_output_node_casted_to_output + ) + + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + low_precision_type=low_precision_type, + ) + converted_model = converter.convert( + high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"] + ) + onnx.checker.check_model(converted_model) From 1f8fde51f00163f028960f8c9291915c3c8980e3 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:36:26 -0400 Subject: [PATCH 3/4] Refactor: simplified '_bypass_cast_node' function Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 33 ++++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 14a6b3726..907cfb11e 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -579,24 +579,23 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None: input_tensor = node.input[0] output_tensor = node.output[0] - is_output_producer = False - # If removed cast node is producing a network output, we need to update the node producing the cast - # Network output name should not be changed - for output in self.model.graph.output: - if output.name == output_tensor: - is_output_producer = True - producers = utils.get_producer_nodes(self.model, input_tensor) - for producer in producers: - for i, prod_out in enumerate(producer.output): - if prod_out == input_tensor: - producer.output[i] = output_tensor - consumers = utils.get_consumer_nodes(self.model, prod_out) - if len(consumers) > 1: - self._replace_tensor_name(consumers, prod_out, output_tensor) - if ( - not is_output_producer - ): # Reconnect consumers of the cast output to use the cast input instead + # Check if the cast output is also a graph output + is_output_producer = any(output.name == output_tensor for output in self.model.graph.output) + + # If the removed cast node is producing a network output, we need to update the node producing the cast, as + # the network output name should not be changed + if is_output_producer: + producers = utils.get_producer_nodes(self.model, input_tensor) + for producer in producers: + for i, prod_out in enumerate(producer.output): + if prod_out == input_tensor: + producer.output[i] = output_tensor + consumers = utils.get_consumer_nodes(self.model, prod_out) + if len(consumers) > 1: + self._replace_tensor_name(consumers, prod_out, output_tensor) + else: + # Reconnect consumers of the cast output to use the cast input instead consumers = utils.get_consumer_nodes(self.model, output_tensor) for consumer in consumers: for i, input_name in enumerate(consumer.input): From 53d4a10d0ca8ba90ab102d38a43be2513d0b78e4 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:11:46 -0400 Subject: [PATCH 4/4] nit: comment and function types Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 907cfb11e..210280067 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -566,7 +566,10 @@ def convert_initializer( to_type=self.high_precision_type, ) - def _replace_tensor_name(self, consumers, original_tensor_name, new_tensor_name): + def _replace_tensor_name( + self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str + ) -> None: + """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name.""" for consumer in consumers: for idx, inp in enumerate(consumer.input): if inp == original_tensor_name: @@ -583,8 +586,8 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None: # Check if the cast output is also a graph output is_output_producer = any(output.name == output_tensor for output in self.model.graph.output) - # If the removed cast node is producing a network output, we need to update the node producing the cast, as - # the network output name should not be changed + # If the removed cast node is producing a network output, update the producer of the cast input so + # the network output name is preserved. if is_output_producer: producers = utils.get_producer_nodes(self.model, input_tensor) for producer in producers: