Skip to content

Commit fa18093

Browse files
committed
remove dup
1 parent 2f5852c commit fa18093

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

tf2onnx/onnx_opset/logical.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -110,28 +110,17 @@ def version_7(cls, ctx, node, **kwargs):
110110
target_dtype = TensorProto.FLOAT
111111
_add_cast_to_inputs(ctx, node, supported_dtypes, target_dtype)
112112

113-
114-
@tf_op("GreaterEqual", onnx_op="Less")
115-
@tf_op("LessEqual", onnx_op="Greater")
113+
@tf_op(["GreaterEqual", "LessEqual"])
116114
class GreaterLessEqual:
117115
@classmethod
118116
def version_7(cls, ctx, node, **kwargs):
119117
GreaterLess.version_7(ctx, node, **kwargs)
120118
output_name = node.output[0]
119+
node.op.op_type = "Less" if node.op.op_type == "GreaterEqual" else "Greater"
121120
new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
122121
ctx.copy_shape(output_name, new_node.output[0])
123122
ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name))
124123

125-
126-
@tf_op("GreaterEqual", onnx_op="GreaterOrEqual")
127-
class GreaterEqual:
128-
@classmethod
129-
def version_12(cls, ctx, node, **kwargs):
130-
pass
131-
132-
133-
@tf_op("LessEqual", onnx_op="LessOrEqual")
134-
class LessEqual:
135124
@classmethod
136125
def version_12(cls, ctx, node, **kwargs):
137-
pass
126+
node.op.op_type = "GreaterOrEqual" if node.op.op_type == "GreaterEqual" else "LessOrEqual"

0 commit comments

Comments
 (0)