diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 4d6917fd2..210280067 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -566,6 +566,15 @@ def convert_initializer(
             to_type=self.high_precision_type,
         )
 
+    def _replace_tensor_name(
+        self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
+    ) -> None:
+        """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
+        for consumer in consumers:
+            for idx, inp in enumerate(consumer.input):
+                if inp == original_tensor_name:
+                    consumer.input[idx] = new_tensor_name
+
     def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
         # handling only a single input and output, as we only remove cast nodes
         assert len(node.input) == 1
@@ -573,21 +582,23 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
         input_tensor = node.input[0]
         output_tensor = node.output[0]
 
-        is_output_producer = False
-        # If removed cast node is producing a network output, we need to update the node producing the cast
-        # Network output name should not be changed
-        for output in self.model.graph.output:
-            if output.name == output_tensor:
-                is_output_producer = True
-                producers = utils.get_producer_nodes(self.model, input_tensor)
-                for producer in producers:
-                    for i, prod_out in enumerate(producer.output):
-                        if prod_out == input_tensor:
-                            producer.output[i] = output_tensor
-        if (
-            not is_output_producer
-        ):  # Reconnect consumers of the cast output to use the cast input instead
+        # Check if the cast output is also a graph output
+        is_output_producer = any(output.name == output_tensor for output in self.model.graph.output)
+
+        # If the removed cast node is producing a network output, update the producer of the cast input so
+        # the network output name is preserved.
+        if is_output_producer:
+            producers = utils.get_producer_nodes(self.model, input_tensor)
+            for producer in producers:
+                for i, prod_out in enumerate(producer.output):
+                    if prod_out == input_tensor:
+                        producer.output[i] = output_tensor
+                        consumers = utils.get_consumer_nodes(self.model, prod_out)
+                        if len(consumers) > 1:
+                            self._replace_tensor_name(consumers, prod_out, output_tensor)
+        else:
+            # Reconnect consumers of the cast output to use the cast input instead
             consumers = utils.get_consumer_nodes(self.model, output_tensor)
             for consumer in consumers:
                 for i, input_name in enumerate(consumer.input):
                     if input_name == output_tensor:
diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py
index 92614f915..a85bed287 100644
--- a/tests/unit/onnx/autocast/test_precisionconverter.py
+++ b/tests/unit/onnx/autocast/test_precisionconverter.py
@@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_multiple_output_node_casted_to_output():
+    """Create a model where a Cast node connects a tensor with multiple consumers to a graph output."""
+    # Create inputs and outputs
+    x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 2, 16, 16])
+    x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 16, 16])
+    x3 = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 4, 16, 16])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [1, 5, 16, 16])
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [1, 9, 16, 16])
+
+    # Create computation nodes
+    concat_1_node = helper.make_node(
+        "Concat",
+        ["X1", "X2"],
+        ["concat_1_out"],
+        name="concat_1",
+        axis=1,
+    )
+    concat_2_node = helper.make_node(
+        "Concat",
+        ["concat_1_out", "X3"],
+        ["Y2"],
+        name="concat_2",
+        axis=1,
+    )
+
+    # Create a Cast node between 'concat_1' and the graph output
+    cast_node = helper.make_node(
+        "Cast",
+        ["concat_1_out"],
+        ["Y1"],
+        name="cast_0",
+        to=TensorProto.FLOAT,
+    )
+
+    graph = helper.make_graph(
+        [concat_1_node, concat_2_node, cast_node],
+        "model_with_multiple_output_node_casted_to_output",
+        [x1, x2, x3],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(
+        graph, producer_name="model_with_multiple_output_node_casted_to_output"
+    )
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_multiple_output_node_casted_to_output(
+    model_with_multiple_output_node_casted_to_output, low_precision_type
+):
+    model, value_info_map, initializer_map, node_to_init_map = (
+        model_with_multiple_output_node_casted_to_output
+    )
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=True,
+        low_precision_type=low_precision_type,
+    )
+    converted_model = converter.convert(
+        high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
+    )
+    onnx.checker.check_model(converted_model)
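
For context on the `_bypass_cast_node` change: when the removed Cast feeds a graph output, the producer of the cast input takes over the graph-output name, so any sibling consumers of the old tensor name must be renamed too, which is what the new `_replace_tensor_name` helper handles. The sketch below is illustrative only and not part of the patch; it rebuilds the same graph as the new test fixture with stock `onnx.helper` calls and applies the rewiring by hand (the graph name "demo" is arbitrary).

import onnx
from onnx import TensorProto, helper

# Same topology as the fixture: concat_1_out feeds both concat_2 and a Cast
# whose output Y1 is a graph output.
x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 2, 16, 16])
x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 16, 16])
x3 = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 4, 16, 16])
y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [1, 5, 16, 16])
y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [1, 9, 16, 16])

concat_1 = helper.make_node("Concat", ["X1", "X2"], ["concat_1_out"], name="concat_1", axis=1)
concat_2 = helper.make_node("Concat", ["concat_1_out", "X3"], ["Y2"], name="concat_2", axis=1)
cast_0 = helper.make_node("Cast", ["concat_1_out"], ["Y1"], name="cast_0", to=TensorProto.FLOAT)

graph = helper.make_graph([concat_1, concat_2, cast_0], "demo", [x1, x2, x3], [y1, y2])
model = helper.make_model(graph)

# Bypass cast_0 while preserving the network output name Y1.
# (make_graph copies the node protos, so mutate model.graph.node, not the originals.)
nodes = model.graph.node
nodes[0].output[0] = "Y1"  # concat_1 takes over the graph-output name
nodes[1].input[0] = "Y1"   # sibling consumer concat_2 must follow the rename
del nodes[2]               # drop the bypassed Cast

onnx.checker.check_model(model)  # valid: both Y1 and Y2 are still produced

Skipping the `nodes[1].input[0] = "Y1"` step reproduces the failure the new test guards against: `concat_2` would still reference the now-dangling `concat_1_out`, and the model would fail the checker.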