@@ -30,6 +30,9 @@ def low_precision_onnx_type(low_precision_type_str):
30
30
return TensorProto .FLOAT16 if low_precision_type_str == "fp16" else TensorProto .BFLOAT16
31
31
32
32
33
# Upper bound on the ONNX IR version used by the tests in this module; it is
# passed as `max_ir_version` to PrecisionConverter.
# NOTE(review): the name implies this tracks what ONNX Runtime can load --
# confirm the value whenever the ORT dependency is upgraded.
LATEST_IR_VERSION_SUPPORTED_BY_ORT = 10
34
+
35
+
33
36
####################################################################################################
34
37
# Testing with a basic GEMM->Add->Relu graph
35
38
####################################################################################################
@@ -1023,3 +1026,73 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
1023
1026
assert utils .get_consumer_nodes (converted_model , "const_scalar" )[0 ].op_type == "Add"
1024
1027
assert len (utils .get_consumer_nodes (converted_model , "const_array" )) == 1
1025
1028
assert utils .get_consumer_nodes (converted_model , "const_array" )[0 ].op_type == "Add"
1029
+
1030
+
1031
@pytest.fixture
def model_with_casted_input_to_output():
    """Create a model where a Cast node links a graph input directly to a graph output.

    Graph layout (all tensors are FLOAT with shape [2, 3]):

        X ----------------> Cast(to=FLOAT) ------------> Y1
        X, const ---------> Add (add1) ----> add1_out
        add1_out, const --> Add (add2) ----------------> Y2

    Returns:
        A tuple ``(model, value_info_map, initializer_map, node_to_init_map)``:
        the shape-inferred model plus the three mappings produced by
        ``utils.setup_mappings`` for it.
    """
    # Graph input and outputs
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Cast output
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Add-chain output

    # Constant operand shared by both Add nodes
    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    const_node = helper.make_node(
        "Constant",
        [],
        ["const"],
        name="const",
        value=numpy_helper.from_array(const, name="const_value"),
    )

    # Computation nodes
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["Y2"], name="add2")

    # Cast node that consumes the graph input and produces a graph output directly
    cast_input = helper.make_node("Cast", ["X"], ["Y1"], name="cast_input", to=TensorProto.FLOAT)

    graph = helper.make_graph(
        [const_node, add1, add2, cast_input],
        "model_with_casted_output",
        [x],
        [y1, y2],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    # Use the module-level constant rather than a literal so the fixture's IR
    # version cannot drift from the `max_ir_version` the test passes to the converter.
    model.ir_version = LATEST_IR_VERSION_SUPPORTED_BY_ORT
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

    return model, value_info_map, initializer_map, node_to_init_map
1075
+
1076
+
1077
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_input_to_output_model(
    model_with_casted_input_to_output, low_precision_type, keep_io_types
):
    """Convert the input->Cast->output model and verify the result is a valid ONNX model.

    The ``cast_input`` node is requested in high precision while ``add1`` and
    ``add2`` are requested in low precision; the converted model must pass
    ``onnx.checker.check_model``.
    """
    (
        source_model,
        value_infos,
        initializers,
        node_initializers,
    ) = model_with_casted_input_to_output

    # bf16 runs request a minimum opset of 22; fp16 runs request 13.
    if low_precision_type == "bf16":
        required_opset = 22
    else:
        required_opset = 13

    precision_converter = PrecisionConverter(
        source_model,
        value_infos,
        initializers,
        node_initializers,
        keep_io_types=keep_io_types,
        low_precision_type=low_precision_type,
        min_opset=required_opset,
        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
        trt_plugins=[],
    )

    result_model = precision_converter.convert(
        high_precision_nodes=["cast_input"],
        low_precision_nodes=["add1", "add2"],
    )
    onnx.checker.check_model(result_model)
0 commit comments