@@ -1023,3 +1023,98 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
1023
1023
assert utils .get_consumer_nodes (converted_model , "const_scalar" )[0 ].op_type == "Add"
1024
1024
assert len (utils .get_consumer_nodes (converted_model , "const_array" )) == 1
1025
1025
assert utils .get_consumer_nodes (converted_model , "const_array" )[0 ].op_type == "Add"
1026
+
1027
+
1028
@pytest.fixture
def model_with_casted_output():
    """
    Create a tiny ONNX model whose final outputs are produced by Cast nodes.

    The graph:
        - Input: "X" (float32, [2, 3]).
        - A Constant tensor is added twice through Add nodes ("add1", "add2").
        - Two Cast nodes ("cast1", "cast2") consume the Add outputs and produce
          the graph outputs "Y1" and "Y2" (cast to FLOAT).
        - Model uses opset 20 and has shapes inferred before being returned.

    Returns:
        tuple: (model, value_info_map, initializer_map, node_to_init_map)
            - model (onnx.ModelProto): The checked ONNX model with inferred shapes.
            - value_info_map (dict): Mapping from tensor name to ValueInfoProto.
            - initializer_map (dict): Mapping from initializer name to TensorProto.
            - node_to_init_map (dict): Mapping from node name to its related initializers.
    """
    # Graph input and the two Cast-produced graph outputs.
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # produced by cast1
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # produced by cast2

    # Constant tensor shared by both Add nodes.
    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    const_node = helper.make_node(
        "Constant",
        [],
        ["const"],
        name="const",
        value=numpy_helper.from_array(const, name="const_value"),
    )

    # Computation chain: X + const -> add1_out; add1_out + const -> add2_out.
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["add2_out"], name="add2")

    # Cast nodes feeding the graph outputs. Note the cast target is FLOAT, i.e.
    # an identity-precision cast — the tests exercise how the converter handles
    # Cast nodes that directly produce graph outputs.
    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["add2_out"], ["Y2"], name="cast2", to=TensorProto.FLOAT)

    graph = helper.make_graph(
        [const_node, add1, add2, cast1, cast2],
        "model_with_casted_output",
        [x],
        [y1, y2],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    # Removed leftover debug call `onnx.save(model, "/tmp/model_with_casted_output.onnx")`:
    # fixtures should not write to a hard-coded, platform-specific path as a side effect.

    return model, value_info_map, initializer_map, node_to_init_map
1089
+
1090
+
1091
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_output_model(model_with_casted_output, low_precision_type, keep_io_types):
    """Graph outputs fed by Cast nodes end up FLOAT when keep_io_types, else low precision."""
    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output

    converter = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=keep_io_types,
        low_precision_type=low_precision_type,
    )
    converted_model = converter.convert(
        high_precision_nodes=["cast1", "cast2"], low_precision_nodes=["add1", "add2"]
    )
    onnx.checker.check_model(converted_model)

    # Both graph outputs (Y1, Y2) must carry the precision implied by keep_io_types.
    expected_elem_type = (
        TensorProto.FLOAT if keep_io_types else low_precision_onnx_type(low_precision_type)
    )
    for graph_output in converted_model.graph.output:
        assert graph_output.type.tensor_type.elem_type == expected_elem_type
0 commit comments