Skip to content

Commit 09c7c4c

Browse files
committed
simplify GRE and LRE
1 parent ba1b0ad commit 09c7c4c

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

tests/test_backend.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3153,7 +3153,16 @@ def test_einsum(self):
31533153
def func(x, y):
31543154
ret = tf.einsum("i,j->ij", x, y)
31553155
return tf.identity(ret, name=_TFOUTPUT)
3156-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3156+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3157+
3158+
@check_opset_min_version(12)
3159+
def test_compare(self):
3160+
x_val = np.random.random([10, 20]).astype(np.float32)
3161+
y_val = np.random.random([10, 20]).astype(np.float32)
3162+
def func(x, y):
3163+
return tf.math.less_equal(x, y, name=_TFOUTPUT), \
3164+
tf.math.greater_equal(x, y, name=_TFOUTPUT1)
3165+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: y_val})
31573166

31583167

31593168
if __name__ == '__main__':

tf2onnx/onnx_opset/logical.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,17 @@ def version_7(cls, ctx, node, **kwargs):
121121
new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
122122
ctx.copy_shape(output_name, new_node.output[0])
123123
ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name))
124+
125+
126+
@tf_op("GreaterEqual", onnx_op="GreaterOrEqual")
127+
class GreaterEqual:
128+
@classmethod
129+
def version_12(cls, ctx, node, **kwargs):
130+
pass
131+
132+
133+
@tf_op("LessEqual", onnx_op="LessOrEqual")
134+
class LessEqual:
135+
@classmethod
136+
def version_12(cls, ctx, node, **kwargs):')
137+
pass

tf2onnx/onnx_opset/math.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ class SquaredDistance:
561561
def version_12(cls, ctx, node, **kwargs):
562562
node.attr["reduction"] = "none"
563563

564+
564565
@tf_op("Einsum")
565566
class Einsum:
566567
@classmethod
@@ -570,3 +571,4 @@ def version_12(cls, ctx, node, **kwargs):
570571

571572

572573

574+

0 commit comments

Comments
 (0)