@@ -1633,6 +1633,11 @@ def logical_compare_op(ctx, node, name, args):
1633
1633
ctx .copy_shape (inp , inp_cast .output [0 ])
1634
1634
ctx .set_dtype (inp_cast .output [0 ], target_dtype )
1635
1635
1636
+ def logical_compareeq_op (ctx , node , name , args ):
1637
+ logical_compare_op (ctx , node , name , [])
1638
+ ctx .insert_new_node_on_output ("Not" , node .output [0 ], name = utils .make_name (name ),
1639
+ shapes = ctx .get_shape (node .output [0 ]), dtypes = ctx .get_dtype (node .output [0 ]))
1640
+
1636
1641
1637
1642
def where_op (ctx , node , name , args ):
1638
1643
# T_y output = Where(T_x condition), return indices of elements whose value are True
@@ -1777,6 +1782,8 @@ def where_op(ctx, node, name, args):
1777
1782
"FusedBatchNormV2" : (fused_batchnorm_op7 , []),
1778
1783
"Greater" : (logical_compare_op , []),
1779
1784
"Less" : (logical_compare_op , []),
1785
+ "GreaterEqual" : (logical_compareeq_op , ["Less" ]),
1786
+ "LessEqual" : (logical_compareeq_op , ["Greater" ]),
1780
1787
"LogicalAnd" : (broadcast_op7 , ["And" ]),
1781
1788
"LogicalOr" : (broadcast_op7 , ["Or" ]),
1782
1789
"MatrixBandPart" : (matrixbandpart_op , []),
@@ -1812,9 +1819,7 @@ def where_op(ctx, node, name, args):
1812
1819
"Cosh" : (direct_op , []),
1813
1820
"Erf" : (direct_op , []),
1814
1821
"Fill" : (fill_op , []),
1815
- "Greater" : (logical_compare_op , []),
1816
1822
"IsNan" : (direct_op , ["IsNaN" ]),
1817
- "Less" : (logical_compare_op , []),
1818
1823
"ResizeBilinear" : (upsample_op9 , ["Upsample" , "linear" ]),
1819
1824
"ResizeNearestNeighbor" : (upsample_op9 , ["Upsample" , "nearest" ]),
1820
1825
"Sign" : (sign_op9 , []),
@@ -2105,50 +2110,6 @@ def rewrite_constant_fold(g, ops):
2105
2110
return ops
2106
2111
2107
2112
2108
- def rewrite_logical_compare_with_equal (g , ops ):
2109
- patterns = {"GreaterEqual" : "Greater" ,
2110
- "LessEqual" : "Less" }
2111
- for p in patterns :
2112
- pattern = OpTypePattern (p , name = 'compare_with_equal' )
2113
- compare_name = patterns [p ]
2114
- matcher = GraphMatcher (pattern )
2115
- match_results = list (matcher .match_ops (ops ))
2116
- for match in match_results :
2117
- nodes_to_append = []
2118
- compare_e_op = match .get_op ('compare_with_equal' )
2119
- data_type = g .get_dtype (compare_e_op .input [0 ])
2120
- compare_input_ids = compare_e_op .input
2121
- need_cast = g .opset < 9 and data_type not in (onnx_pb .TensorProto .FLOAT16 ,
2122
- onnx_pb .TensorProto .FLOAT ,
2123
- onnx_pb .TensorProto .DOUBLE )
2124
- if need_cast :
2125
- compare_input_ids = []
2126
- for input_id in compare_e_op .input :
2127
- cast_node = g .make_node ("Cast" , [input_id ], op_name_scope = compare_e_op .name ,
2128
- attr = {"to" : onnx_pb .TensorProto .FLOAT }, shapes = [g .get_shape (input_id )],
2129
- dtypes = [onnx_pb .TensorProto .FLOAT ])
2130
- compare_input_ids .append (cast_node .output [0 ])
2131
- nodes_to_append .append (cast_node )
2132
-
2133
- g_node = g .make_node (compare_name , compare_input_ids , op_name_scope = compare_e_op .name ,
2134
- dtypes = [onnx_pb .TensorProto .BOOL ])
2135
- set_shape_from_inputs_broadcast (g , compare_input_ids , g_node .output [0 ])
2136
- new_shape = g .get_shape (g_node .output [0 ])
2137
- nodes_to_append .append (g_node )
2138
-
2139
- e_node = g .make_node ("Equal" , compare_e_op .input , op_name_scope = compare_e_op .name ,
2140
- shapes = [new_shape ], dtypes = [onnx_pb .TensorProto .BOOL ])
2141
- nodes_to_append .append (e_node )
2142
-
2143
- compare_e_op .type = "LogicalOr"
2144
- compare_e_op .input [0 ] = g_node .output [0 ]
2145
- compare_e_op .input [1 ] = e_node .output [0 ]
2146
- g .set_dtype (compare_e_op .output [0 ], onnx_pb .TensorProto .BOOL )
2147
- g .set_shape (compare_e_op .output [0 ], new_shape )
2148
- ops .extend (nodes_to_append )
2149
- return ops
2150
-
2151
-
2152
2113
def rewrite_incomplete_type_support (g , ops , impacted_ops ):
2153
2114
"""
2154
2115
for ops that have inclomplete type support, insert casts.
@@ -2459,7 +2420,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
2459
2420
rewrite_leakyrelu , rewrite_conv2d_with_pad ,
2460
2421
rewrite_single_direction_lstm , rewrite_bi_direction_lstm ,
2461
2422
rewrite_single_direction_gru , rewrite_single_direction_grublock ,
2462
- rewrite_bi_direction_gru , rewrite_logical_compare_with_equal ,
2423
+ rewrite_bi_direction_gru ,
2463
2424
rewrite_custom_rnn_cell , rewrite_generic_loop , rewrite_cond
2464
2425
]
2465
2426
0 commit comments