@@ -340,29 +340,6 @@ def reshape_op5(ctx, node, name, args):
340
340
return [input_cast ] + nodes
341
341
342
342
343
def less_op7(ctx, node, name, args):
    """Map TF Less to ONNX Less-7, casting inputs to a supported dtype.

    Less-7 only accepts a limited set of dtypes, so each input whose dtype
    is not FLOAT is cast to FLOAT before the comparison.

    Returns the list of nodes making up the mapped op (casts first so the
    graph stays topologically ordered, then the comparison node itself).
    """
    nodes = [node]
    target_dtype = onnx_pb.TensorProto.FLOAT
    # Cast each input independently, and only when its dtype actually differs.
    # (Bug fixes vs. the previous version: the second input was cast
    # unconditionally, and ctx.set_shape was called with a dtype value
    # where ctx.set_dtype was intended.)
    for inp in node.input[:2]:
        if ctx.get_dtype(inp) != target_dtype:
            inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp)
            inp_cast.set_attr("to", target_dtype)
            ctx.copy_shape(node.output[0], inp_cast.output[0])
            ctx.set_dtype(inp_cast.output[0], target_dtype)
            nodes.insert(0, inp_cast)
    return nodes
364
-
365
-
366
343
NCHW_TO_NHWC = [0 , 2 , 3 , 1 ]
367
344
NHWC_TO_NCHW = [0 , 3 , 1 , 2 ]
368
345
HWCN_TO_NCHW = [3 , 2 , 0 , 1 ]
@@ -982,23 +959,6 @@ def expanddims_op(ctx, node, name, args):
982
959
raise ValueError ("non-const dim is not supported" )
983
960
984
961
985
def greater_op7(ctx, node, name, args):
    """Map TF Greater to ONNX.

    Greater-7 supports only float dtypes, so a Cast-to-FLOAT is inserted
    in front of every input whose dtype is unsupported; the comparison
    itself is then handled by broadcast_op7.
    """
    supported = [
        onnx_pb.TensorProto.FLOAT,
        onnx_pb.TensorProto.FLOAT16,
        onnx_pb.TensorProto.DOUBLE,
    ]
    result = []
    for inp in node.input:
        if ctx.get_dtype(inp) in supported:
            continue
        cast_node = ctx.insert_new_node_on_input(node, "Cast", inp, to=onnx_pb.TensorProto.FLOAT)
        ctx.copy_shape(inp, cast_node.output[0])
        ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
        result.append(cast_node)
    result.append(broadcast_op7(ctx, node, name, args))
    return result
1000
-
1001
-
1002
962
def expanddims_op7 (ctx , node , name , args ):
1003
963
# T output = ExpandDims(T input, Tdim dim, @type Tdim), dim is 0-D scalar.
1004
964
# T reshaped = Reshape-5(T data, int64 shape)
@@ -1717,6 +1677,27 @@ def softmax_op(ctx, node, name, args):
1717
1677
return node
1718
1678
1719
1679
1680
def logical_compare_op(ctx, node, name, args):
    """Map TF Greater/Less to ONNX.

    T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
    T2 output = Less(T1 x, T1 y), T2=tensor(bool)

    Before opset 9, Greater/Less only support a limited set of dtypes, so a
    Cast-to-FLOAT node is inserted for every input with an unsupported dtype.
    Returns the list of nodes making up the mapped op.
    """
    nodes = [node]
    if ctx.opset < 9:
        supported_dtypes = [
            onnx_pb.TensorProto.FLOAT,
            onnx_pb.TensorProto.FLOAT16,
            onnx_pb.TensorProto.DOUBLE,
        ]
        target_dtype = onnx_pb.TensorProto.FLOAT
        for inp in node.input:
            if ctx.get_dtype(inp) in supported_dtypes:
                continue
            inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
            ctx.copy_shape(inp, inp_cast.output[0])
            ctx.set_dtype(inp_cast.output[0], target_dtype)
            nodes.append(inp_cast)
    return nodes
1699
+
1700
+
1720
1701
# map tensorflow ops to onnx ops. The format below is
1721
1702
# "TFOP": func_to_map, ["OnnxOp", ...]
1722
1703
#
@@ -1845,8 +1826,8 @@ def softmax_op(ctx, node, name, args):
1845
1826
"FloorMod" : (floormod_op , []),
1846
1827
"FusedBatchNorm" : (fused_batchnorm_op7 , []),
1847
1828
"FusedBatchNormV2" : (fused_batchnorm_op7 , []),
1848
- "Greater" : (greater_op7 , []),
1849
- "Less" : (less_op7 , []),
1829
+ "Greater" : (logical_compare_op , []),
1830
+ "Less" : (logical_compare_op , []),
1850
1831
"LogicalAnd" : (broadcast_op7 , ["And" ]),
1851
1832
"LogicalOr" : (broadcast_op7 , ["Or" ]),
1852
1833
"MatrixBandPart" : (matrixbandpart_op , []),
@@ -1884,6 +1865,8 @@ def softmax_op(ctx, node, name, args):
1884
1865
"Asinh" : (direct_op , []),
1885
1866
"Acosh" : (direct_op , []),
1886
1867
"Atanh" : (direct_op , []),
1868
+ "Greater" : (logical_compare_op , []),
1869
+ "Less" : (logical_compare_op , []),
1887
1870
"ResizeBilinear" : (upsample_op9 , ["Upsample" , "linear" ]),
1888
1871
"ResizeNearestNeighbor" : (upsample_op9 , ["Upsample" , "nearest" ]),
1889
1872
}
@@ -2178,9 +2161,9 @@ def rewrite_logical_compare_with_equal(g, ops):
2178
2161
compare_e_op = match .get_op ('compare_with_equal' )
2179
2162
data_type = g .get_dtype (compare_e_op .input [0 ])
2180
2163
compare_input_ids = compare_e_op .input
2181
- need_cast = data_type not in (onnx_pb .TensorProto .FLOAT16 ,
2182
- onnx_pb .TensorProto .FLOAT ,
2183
- onnx_pb .TensorProto .DOUBLE )
2164
+ need_cast = g . opset < 9 and data_type not in (onnx_pb .TensorProto .FLOAT16 ,
2165
+ onnx_pb .TensorProto .FLOAT ,
2166
+ onnx_pb .TensorProto .DOUBLE )
2184
2167
if need_cast :
2185
2168
compare_input_ids = []
2186
2169
for input_id in compare_e_op .input :
0 commit comments