@@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
1023
1023
assert utils .get_consumer_nodes (converted_model , "const_scalar" )[0 ].op_type == "Add"
1024
1024
assert len (utils .get_consumer_nodes (converted_model , "const_array" )) == 1
1025
1025
assert utils .get_consumer_nodes (converted_model , "const_array" )[0 ].op_type == "Add"
1026
+
1027
+
1028
+ @pytest .fixture
1029
+ def model_with_multiple_output_node_casted_to_output ():
1030
+ """Create a model with a Cast node connecting a consumer with multiple outputs to a graph output."""
1031
+ # Create inputs and outputs
1032
+ x1 = helper .make_tensor_value_info ("X1" , TensorProto .FLOAT , [1 , 2 , 16 , 16 ])
1033
+ x2 = helper .make_tensor_value_info ("X2" , TensorProto .FLOAT , [1 , 3 , 16 , 16 ])
1034
+ x3 = helper .make_tensor_value_info ("X3" , TensorProto .FLOAT , [1 , 4 , 16 , 16 ])
1035
+ y1 = helper .make_tensor_value_info ("Y1" , TensorProto .FLOAT , [1 , 5 , 16 , 16 ])
1036
+ y2 = helper .make_tensor_value_info ("Y2" , TensorProto .FLOAT , [1 , 9 , 16 , 16 ])
1037
+
1038
+ # Create computation nodes
1039
+ concat_1_node = helper .make_node (
1040
+ "Concat" ,
1041
+ ["X1" , "X2" ],
1042
+ ["concat_1_out" ],
1043
+ name = "concat_1" ,
1044
+ axis = 1 ,
1045
+ )
1046
+ concat_2_node = helper .make_node (
1047
+ "Concat" ,
1048
+ ["concat_1_out" , "X3" ],
1049
+ ["Y2" ],
1050
+ name = "concat_2" ,
1051
+ axis = 1 ,
1052
+ )
1053
+
1054
+ # Create a Cast node between 'concat_1' and the graph output
1055
+ cast_node = helper .make_node (
1056
+ "Cast" ,
1057
+ ["concat_1_out" ],
1058
+ ["Y1" ],
1059
+ name = "cast_0" ,
1060
+ to = TensorProto .FLOAT ,
1061
+ )
1062
+
1063
+ graph = helper .make_graph (
1064
+ [concat_1_node , concat_2_node , cast_node ],
1065
+ "model_with_multiple_output_node_casted_to_output" ,
1066
+ [x1 , x2 , x3 ],
1067
+ [y1 , y2 ],
1068
+ [],
1069
+ )
1070
+
1071
+ model = helper .make_model (
1072
+ graph , producer_name = "model_with_multiple_output_node_casted_to_output"
1073
+ )
1074
+ model .opset_import [0 ].version = 20
1075
+ model .ir_version = 10
1076
+ onnx .checker .check_model (model )
1077
+
1078
+ model = onnx_utils .infer_shapes (model )
1079
+ value_info_map , initializer_map , node_to_init_map = utils .setup_mappings (model )
1080
+
1081
+ return model , value_info_map , initializer_map , node_to_init_map
1082
+
1083
+
1084
+ @pytest .mark .parametrize ("low_precision_type" , ["fp16" , "bf16" ])
1085
+ def test_multiple_output_node_casted_to_output (
1086
+ model_with_multiple_output_node_casted_to_output , low_precision_type
1087
+ ):
1088
+ model , value_info_map , initializer_map , node_to_init_map = (
1089
+ model_with_multiple_output_node_casted_to_output
1090
+ )
1091
+
1092
+ converter = PrecisionConverter (
1093
+ model ,
1094
+ value_info_map ,
1095
+ initializer_map ,
1096
+ node_to_init_map ,
1097
+ keep_io_types = True ,
1098
+ low_precision_type = low_precision_type ,
1099
+ )
1100
+ converted_model = converter .convert (
1101
+ high_precision_nodes = [], low_precision_nodes = ["concat_1" , "concat_2" ]
1102
+ )
1103
+ onnx .checker .check_model (converted_model )
0 commit comments