We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 7d5faea + 42c8f83 commit 7b9f056Copy full SHA for 7b9f056
tests/test_backend.py
@@ -808,6 +808,16 @@ def test_less_unsupport_type(self):
808
_ = tf.identity(mi, name=_TFOUTPUT)
809
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
810
811
+ @check_opset_min_version(11, "Equal")
812
+ def test_equal_float(self):
813
+ x_val1 = np.array([0., 1., 2., 3., 4., -1., -2], dtype=np.float32)
814
+ x_val2 = np.array([0., 1., 2.1, 3.5, 4.6, -1.1, -2.9], dtype=np.float32)
815
+ x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
816
+ x2 = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
817
+ mi = tf.equal(x1, x2)
818
+ _ = tf.identity(mi, name=_TFOUTPUT)
819
+ self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
820
+
821
def test_equal(self):
822
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
823
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
tf2onnx/onnx_opset/logical.py
@@ -79,6 +79,17 @@ def version_7(cls, ctx, node, **kwargs):
79
ctx.copy_shape(output_name, not_node.output[0])
80
ctx.copy_dtype(output_name, not_node.output[0])
81
82
+ @classmethod
83
+ def version_11(cls, ctx, node, **kwargs):
84
+ # starting with opset-11, equal supports all types
85
+ need_not = node.type == "NotEqual"
86
+ if need_not:
87
+ node.type = "Equal"
88
+ output_name = node.output[0]
89
+ not_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
90
+ ctx.copy_shape(output_name, not_node.output[0])
91
+ ctx.copy_dtype(output_name, not_node.output[0])
92
93
94
@tf_op(["Greater", "Less"])
95
class GreaterLess:
0 commit comments