Skip to content

Commit 42c8f83

Browse files
committed
opset-11 supports equal for all types
1 parent f7a95c5 commit 42c8f83

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,16 @@ def test_less_unsupport_type(self):
808808
_ = tf.identity(mi, name=_TFOUTPUT)
809809
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
810810

811+
@check_opset_min_version(11, "Equal")
812+
def test_equal_float(self):
813+
x_val1 = np.array([0., 1., 2., 3., 4., -1., -2], dtype=np.float32)
814+
x_val2 = np.array([0., 1., 2.1, 3.5, 4.6, -1.1, -2.9], dtype=np.float32)
815+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
816+
x2 = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
817+
mi = tf.equal(x1, x2)
818+
_ = tf.identity(mi, name=_TFOUTPUT)
819+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
820+
811821
def test_equal(self):
812822
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
813823
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))

tf2onnx/onnx_opset/logical.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,17 @@ def version_7(cls, ctx, node, **kwargs):
7979
ctx.copy_shape(output_name, not_node.output[0])
8080
ctx.copy_dtype(output_name, not_node.output[0])
8181

82+
@classmethod
83+
def version_11(cls, ctx, node, **kwargs):
84+
# starting with opset-11, equal supports all types
85+
need_not = node.type == "NotEqual"
86+
if need_not:
87+
node.type = "Equal"
88+
output_name = node.output[0]
89+
not_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
90+
ctx.copy_shape(output_name, not_node.output[0])
91+
ctx.copy_dtype(output_name, not_node.output[0])
92+
8293

8394
@tf_op(["Greater", "Less"])
8495
class GreaterLess:

0 commit comments

Comments
 (0)