@@ -110,28 +110,17 @@ def version_7(cls, ctx, node, **kwargs):
110
110
target_dtype = TensorProto .FLOAT
111
111
_add_cast_to_inputs (ctx , node , supported_dtypes , target_dtype )
112
112
113
-
114
- @tf_op ("GreaterEqual" , onnx_op = "Less" )
115
- @tf_op ("LessEqual" , onnx_op = "Greater" )
113
+ @tf_op (["GreaterEqual" , "LessEqual" ])
116
114
class GreaterLessEqual :
117
115
@classmethod
118
116
def version_7 (cls , ctx , node , ** kwargs ):
119
117
GreaterLess .version_7 (ctx , node , ** kwargs )
120
118
output_name = node .output [0 ]
119
+ node .op .op_type = "Less" if node .op .op_type == "GreaterEqual" else "Greater"
121
120
new_node = ctx .insert_new_node_on_output ("Not" , output_name , name = utils .make_name (node .name ))
122
121
ctx .copy_shape (output_name , new_node .output [0 ])
123
122
ctx .set_dtype (new_node .output [0 ], ctx .get_dtype (output_name ))
124
123
125
-
126
- @tf_op ("GreaterEqual" , onnx_op = "GreaterOrEqual" )
127
- class GreaterEqual :
128
- @classmethod
129
- def version_12 (cls , ctx , node , ** kwargs ):
130
- pass
131
-
132
-
133
- @tf_op ("LessEqual" , onnx_op = "LessOrEqual" )
134
- class LessEqual :
135
124
@classmethod
136
125
def version_12 (cls , ctx , node , ** kwargs ):
137
- pass
126
+ node . op . op_type = "GreaterOrEqual" if node . op . op_type == "GreaterEqual" else "LessOrEqual"
0 commit comments