From b95458bcd4bb65ef7bc552483694454d9f0d02ec Mon Sep 17 00:00:00 2001
From: Ali Boubezari
Date: Mon, 8 Sep 2025 16:17:37 -0700
Subject: [PATCH] [Autocast] Fix edge case with cast producing network output

When a bypassed Cast node produces a network output, the cast input tensor is
renamed to the output name; any other consumers of that tensor must be
repointed to the renamed tensor, otherwise they are left reading a name that
no longer exists in the graph.

Signed-off-by: Ali Boubezari
---
 modelopt/onnx/autocast/precisionconverter.py |  6 ++
 .../onnx/autocast/test_precisionconverter.py | 79 ++++++++++++++++++++
 2 files changed, 85 insertions(+)

diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 8b6fd545d..686dfb494 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -576,6 +576,12 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
                 for i, prod_out in enumerate(producer.output):
                     if prod_out == input_tensor:
                         producer.output[i] = output_tensor
+            # The cast input was renamed above; repoint its other consumers too
+            consumers = utils.get_consumer_nodes(self.model, input_tensor)
+            for consumer in consumers:
+                for i, input_name in enumerate(consumer.input):
+                    if input_name == input_tensor:
+                        consumer.input[i] = output_tensor
         if (
             not is_output_producer
         ):  # Reconnect consumers of the cast output to use the cast input instead
diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py
index 92614f915..e1bdaa8ba 100644
--- a/tests/unit/onnx/autocast/test_precisionconverter.py
+++ b/tests/unit/onnx/autocast/test_precisionconverter.py
@@ -1023,3 +1023,82 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_casted_output():
+    """Create a model whose outputs are produced by Cast nodes."""
+    # Create the input and the two outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output from mid-chain
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output from end of chain
+
+    # Create constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["add2_out"], name="add2")
+
+    # Create the Cast nodes that produce the graph outputs (FLOAT)
+    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
+    cast2 = helper.make_node("Cast", ["add2_out"], ["Y2"], name="cast2", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast1, cast2],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_output_model(model_with_casted_output, low_precision_type, keep_io_types):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+    )
+
+    converted_model = converter.convert(
+        high_precision_nodes=["cast1", "cast2"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)
+
+    # Check that the outputs are cast to the expected precision
+    if keep_io_types:
+        assert converted_model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT
+        assert converted_model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT
+    else:
+        assert converted_model.graph.output[
+            0
+        ].type.tensor_type.elem_type == low_precision_onnx_type(low_precision_type)
+        assert converted_model.graph.output[
+            1
+        ].type.tensor_type.elem_type == low_precision_onnx_type(low_precision_type)
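
Illustrative sketch (not part of the patch): the snippet below is a
self-contained reproduction of the rewiring this fix performs, not modelopt
code. The helper name `bypass_output_cast` and the toy graph are made up for
the example. It shows the edge case directly: once the producer of the cast
input takes over the graph-output name, any sibling consumer of the old tensor
name (here `add2`) must be repointed as well, or the model fails the ONNX
checker.

    import onnx
    from onnx import TensorProto, helper


    def bypass_output_cast(model: onnx.ModelProto, cast_name: str) -> None:
        """Remove a Cast that produces a graph output, keeping the graph valid."""
        graph = model.graph
        cast = next(n for n in graph.node if n.name == cast_name)
        in_name, out_name = cast.input[0], cast.output[0]

        # The producer of the cast input takes over the graph-output name.
        for node in graph.node:
            for i, out in enumerate(node.output):
                if out == in_name:
                    node.output[i] = out_name

        # The fix: repoint every other consumer of the old tensor name, which
        # would otherwise reference a tensor that no longer exists.
        for node in graph.node:
            if node is cast:
                continue
            for i, inp in enumerate(node.input):
                if inp == in_name:
                    node.input[i] = out_name

        graph.node.remove(cast)


    # Toy graph mirroring the new fixture: add1 -> cast1 -> Y1, with add1_out
    # also feeding add2. Bypassing cast1 must repoint add2's input to "Y1".
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2])
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2])
    add1 = helper.make_node("Add", ["X", "X"], ["add1_out"], name="add1")
    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    add2 = helper.make_node("Add", ["add1_out", "X"], ["Y2"], name="add2")
    graph = helper.make_graph([add1, cast1, add2], "toy", [x], [y1, y2])
    model = helper.make_model(graph)
    bypass_output_cast(model, "cast1")
    onnx.checker.check_model(model)  # passes only because add2 now reads "Y1"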