11
11
12
12
import logging
13
13
14
- from onnx import onnx_pb
14
+ from onnx import TensorProto
15
15
from tf2onnx import utils
16
16
from tf2onnx .handler import tf_op
17
17
from tf2onnx .onnx_opset import common
21
21
22
22
# pylint: disable=unused-argument,missing-docstring
23
23
24
- def logical_compare_op (ctx , node , ** kwargs ):
25
- # T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
26
- # T2 output = Less(T1 x, T1 y), T2=tensor(bool)
27
- # Great/Less in opset7 only supports limited types, insert Cast if needed
28
- if ctx .opset < 9 :
29
- supported_dtypes = [
30
- onnx_pb .TensorProto .FLOAT ,
31
- onnx_pb .TensorProto .FLOAT16 ,
32
- onnx_pb .TensorProto .DOUBLE
33
- ]
34
- target_dtype = onnx_pb .TensorProto .FLOAT
24
+ def _add_cast_to_inputs (graph , node , supported_dtypes , target_dtype ):
25
+ is_support = True
26
+ for inp in node .input :
27
+ if graph .get_dtype (inp ) not in supported_dtypes :
28
+ is_support = False
29
+ break
30
+ if not is_support :
35
31
for inp in node .input :
36
- if ctx .get_dtype (inp ) not in supported_dtypes :
37
- inp_cast = ctx .insert_new_node_on_input (node , "Cast" , inp , to = target_dtype )
38
- ctx .copy_shape (inp , inp_cast .output [0 ])
39
- ctx .set_dtype (inp_cast .output [0 ], target_dtype )
32
+ inp_cast = graph .insert_new_node_on_input (node , "Cast" , inp , to = target_dtype )
33
+ graph .copy_shape (inp , inp_cast .output [0 ])
34
+ graph .set_dtype (inp_cast .output [0 ], target_dtype )
40
35
41
36
42
37
@tf_op (["LogicalNot" , "NotEqual" ], onnx_op = "Not" )
@@ -46,30 +41,56 @@ def version_4(cls, ctx, node, **kwargs):
46
41
pass
47
42
48
43
49
- @tf_op (["Equal" , "Greater" , "Less" ])
50
44
@tf_op ("LogicalAnd" , onnx_op = "And" )
51
45
@tf_op ("LogicalOr" , onnx_op = "Or" )
52
46
class BroadcastOp (common .BroadcastOp ):
53
47
pass
54
48
55
49
50
+ @tf_op ("Equal" )
51
+ class Equal :
52
+ @classmethod
53
+ def version_4 (cls , ctx , node , ** kwargs ):
54
+ common .BroadcastOp .version_4 (ctx , node , ** kwargs )
55
+
56
+ @classmethod
57
+ def version_7 (cls , ctx , node , ** kwargs ):
58
+ # T2 output = Equal(T1, x, T1 y), T1 \in {bool, int32, int64}
59
+ supported_dtypes = [
60
+ TensorProto .BOOL ,
61
+ TensorProto .INT32 ,
62
+ TensorProto .INT64
63
+ ]
64
+ target_dtype = TensorProto .INT32
65
+ _add_cast_to_inputs (ctx , node , supported_dtypes , target_dtype )
66
+
67
+
56
68
@tf_op (["Greater" , "Less" ])
57
- class Greater :
69
+ class GreaterLess :
58
70
@classmethod
59
71
def version_4 (cls , ctx , node , ** kwargs ):
60
72
common .BroadcastOp .version_4 (ctx , node , ** kwargs )
61
73
62
74
@classmethod
63
75
def version_7 (cls , ctx , node , ** kwargs ):
64
- logical_compare_op (ctx , node , ** kwargs )
76
+ # T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
77
+ # T2 output = Less(T1 x, T1 y), T2=tensor(bool)
78
+ # Great/Less in opset7 only supports limited types, insert Cast if needed
79
+ supported_dtypes = [
80
+ TensorProto .FLOAT ,
81
+ TensorProto .FLOAT16 ,
82
+ TensorProto .DOUBLE
83
+ ]
84
+ target_dtype = TensorProto .FLOAT
85
+ _add_cast_to_inputs (ctx , node , supported_dtypes , target_dtype )
65
86
66
87
67
88
@tf_op ("GreaterEqual" , onnx_op = "Less" )
68
89
@tf_op ("LessEqual" , onnx_op = "Greater" )
69
90
class GreaterLessEqual :
70
91
@classmethod
71
92
def version_7 (cls , ctx , node , ** kwargs ):
72
- logical_compare_op (ctx , node , ** kwargs )
93
+ GreaterLess . version_7 (ctx , node , ** kwargs )
73
94
output_name = node .output [0 ]
74
95
new_node = ctx .insert_new_node_on_output ("Not" , output_name , name = utils .make_name (node .name ))
75
96
ctx .copy_shape (output_name , new_node .output [0 ])
0 commit comments