diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 8b6fd545d..fe7c941be 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -559,6 +559,15 @@ def convert_initializer(
     def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
         # handling only a single input and output, as we only remove cast nodes
+        """
+        Bypass (remove) a Cast node by rewiring its producers and consumers in the model graph.
+
+        The node must be a Cast with exactly one input and one output (asserted). If the Cast's
+        output is a graph output, that name must be preserved: producers that wrote the Cast's
+        input, and any other consumers of that input, are rewired to the graph output name
+        instead. Otherwise, consumers of the Cast's output are rewired to consume the Cast's
+        input directly. Modifies self.model.graph in-place.
+        """
         assert len(node.input) == 1
         assert len(node.output) == 1
@@ -576,6 +585,11 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
                 for i, prod_out in enumerate(producer.output):
                     if prod_out == input_tensor:
                         producer.output[i] = output_tensor
+            consumers = utils.get_consumer_nodes(self.model, input_tensor)
+            for consumer in consumers:
+                for i, input_name in enumerate(consumer.input):
+                    if input_name == input_tensor:
+                        consumer.input[i] = output_tensor
         if (
             not is_output_producer
         ):  # Reconnect consumers of the cast output to use the cast input instead
diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py
index 92614f915..e5fd8504a 100644
--- a/tests/unit/onnx/autocast/test_precisionconverter.py
+++ b/tests/unit/onnx/autocast/test_precisionconverter.py
@@ -1023,3 +1023,97 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
     assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
     assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
     assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
+
+
+@pytest.fixture
+def model_with_casted_output():
+    """
+    Create a tiny ONNX model whose graph outputs are produced by Cast nodes.
+
+    The graph:
+    - Input: "X" (float32, [2, 3]).
+    - A Constant tensor is added twice through Add nodes ("add1", "add2").
+    - Two Cast nodes ("cast1", "cast2") consume the Add outputs and produce the graph outputs "Y1" and "Y2" (cast to FLOAT).
+    - The model uses opset 20 and has shapes inferred before being returned.
+
+    Returns:
+        tuple: (model, value_info_map, initializer_map, node_to_init_map)
+            - model (onnx.ModelProto): The checked ONNX model with inferred shapes.
+            - value_info_map (dict): Mapping from tensor name to ValueInfoProto.
+            - initializer_map (dict): Mapping from initializer name to TensorProto.
+            - node_to_init_map (dict): Mapping from node name to its related initializers.
+    """
+    # Create input and outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Graph output fed by the intermediate Add
+    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Graph output fed by the final Add
+
+    # Create constant value
+    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+
+    # Create constant node
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["const"],
+        name="const",
+        value=numpy_helper.from_array(const, name="const_value"),
+    )
+
+    # Create computation nodes
+    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
+    add2 = helper.make_node("Add", ["add1_out", "const"], ["add2_out"], name="add2")
+
+    # Create Cast nodes for the graph outputs (up-casts once the Adds run in low precision)
+    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
+    cast2 = helper.make_node("Cast", ["add2_out"], ["Y2"], name="cast2", to=TensorProto.FLOAT)
+
+    graph = helper.make_graph(
+        [const_node, add1, add2, cast1, cast2],
+        "model_with_casted_output",
+        [x],
+        [y1, y2],
+        [],
+    )
+
+    model = helper.make_model(graph, producer_name="model_with_casted_output")
+    model.opset_import[0].version = 20
+    model.ir_version = 10
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_casted_output_model(model_with_casted_output, low_precision_type, keep_io_types):
+    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=keep_io_types,
+        low_precision_type=low_precision_type,
+    )
+
+    converted_model = converter.convert(
+        high_precision_nodes=["cast1", "cast2"], low_precision_nodes=["add1", "add2"]
+    )
+    onnx.checker.check_model(converted_model)
+
+    # Check that the outputs are cast to the correct precision
+    if keep_io_types:
+        assert converted_model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT
+        assert converted_model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT
+    else:
+        assert converted_model.graph.output[
+            0
+        ].type.tensor_type.elem_type == low_precision_onnx_type(low_precision_type)
+        assert converted_model.graph.output[
+            1
+        ].type.tensor_type.elem_type == low_precision_onnx_type(low_precision_type)
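
Note: the producer-side rewiring in _bypass_cast_node renames the tensor that fed the Cast, so every remaining consumer of that tensor must follow the rename, which is what the added consumer loop does. Below is a minimal standalone sketch of this graph-output-preserving bypass, assuming only the public onnx package; the toy graph and the tensor names ("add_out", "Y") are illustrative, not the module's internals.

import onnx
from onnx import TensorProto, helper

# Build X -> Add(C) -> add_out -> Cast -> Y, where "Y" is a graph output.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])
c = helper.make_tensor("C", TensorProto.FLOAT, [2], [1.0, 2.0])
add = helper.make_node("Add", ["X", "C"], ["add_out"], name="add")
cast = helper.make_node("Cast", ["add_out"], ["Y"], name="cast", to=TensorProto.FLOAT)
model = helper.make_model(helper.make_graph([add, cast], "g", [x], [y], [c]))

# 1. Producers of the Cast's input are rewired to write the graph output name.
for node in model.graph.node:
    for i, out in enumerate(node.output):
        if out == "add_out":
            node.output[i] = "Y"

# 2. Every remaining consumer of the old tensor name must follow the rename as
#    well -- the loop this patch adds (here the Cast is the only such consumer).
for node in model.graph.node:
    for i, inp in enumerate(node.input):
        if inp == "add_out":
            node.input[i] = "Y"

# 3. Drop the Cast node and validate the rewired graph.
del model.graph.node[[n.name for n in model.graph.node].index("cast")]
onnx.checker.check_model(model)

Without step 2, any consumer of "add_out" other than the Cast itself would be left referencing a tensor that no longer has a producer; that dangling-reference case is what the consumer loop in this patch fixes.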