
Commit 4ba449b

Merge pull request #431 from lucienwang1009/equal_more_dtype
support more dtype for Equal
2 parents: d3f2d1e + 0599d24 · commit 4ba449b

3 files changed: +62 −20 lines

tests/common.py

Lines changed: 1 addition & 0 deletions
@@ -197,6 +197,7 @@ def check_onnxruntime_incompatibility(op):
         "AveragePool": 7, # AveragePool-1
         "Div": 7, # Div-1, Div-6
         "Elu": 6, # Elu-1
+        "Equal": 7, # Equal-1
         "Exp": 6, # Exp-1
         "Greater": 7, # Greater-1
         "Less": 7, # Less-1

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
@@ -672,6 +672,26 @@ def test_less_unsupport_type(self):
         _ = tf.identity(mi, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

+    @check_onnxruntime_incompatibility("Equal")
+    def test_equal(self):
+        x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
+        x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
+        x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
+        x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
+        mi = tf.equal(x1, x2)
+        _ = tf.identity(mi, name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
+
+        tf.reset_default_graph()
+        x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
+        x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
+        x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
+        x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
+        mi = tf.equal(x1, x2)
+        _ = tf.identity(mi, name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
+
+
     def test_sequeeze_no_axis_specified(self):
         x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
         x = tf.placeholder(tf.float32, [2, 2, 1], name=_TFINPUT)
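For reference, the element-wise result the new test expects for its integer inputs can be reproduced with plain NumPy; this is just a worked example of the comparison, not part of the commit:

import numpy as np

# Same test vectors as test_equal above, reshaped to 2x2.
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))

# Element-wise equality, i.e. what tf.equal / ONNX Equal should return.
print(np.equal(x_val1, x_val2))
# [[False False]
#  [ True  True]]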

tf2onnx/onnx_opset/logical.py

Lines changed: 41 additions & 20 deletions
@@ -11,7 +11,7 @@

 import logging

-from onnx import onnx_pb
+from onnx import TensorProto
 from tf2onnx import utils
 from tf2onnx.handler import tf_op
 from tf2onnx.onnx_opset import common
@@ -21,22 +21,17 @@

 # pylint: disable=unused-argument,missing-docstring

-def logical_compare_op(ctx, node, **kwargs):
-    # T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
-    # T2 output = Less(T1 x, T1 y), T2=tensor(bool)
-    # Great/Less in opset7 only supports limited types, insert Cast if needed
-    if ctx.opset < 9:
-        supported_dtypes = [
-            onnx_pb.TensorProto.FLOAT,
-            onnx_pb.TensorProto.FLOAT16,
-            onnx_pb.TensorProto.DOUBLE
-        ]
-        target_dtype = onnx_pb.TensorProto.FLOAT
+def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
+    is_support = True
+    for inp in node.input:
+        if graph.get_dtype(inp) not in supported_dtypes:
+            is_support = False
+            break
+    if not is_support:
         for inp in node.input:
-            if ctx.get_dtype(inp) not in supported_dtypes:
-                inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
-                ctx.copy_shape(inp, inp_cast.output[0])
-                ctx.set_dtype(inp_cast.output[0], target_dtype)
+            inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
+            graph.copy_shape(inp, inp_cast.output[0])
+            graph.set_dtype(inp_cast.output[0], target_dtype)


 @tf_op(["LogicalNot", "NotEqual"], onnx_op="Not")
@@ -46,30 +41,56 @@ def version_4(cls, ctx, node, **kwargs):
         pass


-@tf_op(["Equal", "Greater", "Less"])
 @tf_op("LogicalAnd", onnx_op="And")
 @tf_op("LogicalOr", onnx_op="Or")
 class BroadcastOp(common.BroadcastOp):
     pass


+@tf_op("Equal")
+class Equal:
+    @classmethod
+    def version_4(cls, ctx, node, **kwargs):
+        common.BroadcastOp.version_4(ctx, node, **kwargs)
+
+    @classmethod
+    def version_7(cls, ctx, node, **kwargs):
+        # T2 output = Equal(T1, x, T1 y), T1 \in {bool, int32, int64}
+        supported_dtypes = [
+            TensorProto.BOOL,
+            TensorProto.INT32,
+            TensorProto.INT64
+        ]
+        target_dtype = TensorProto.INT32
+        _add_cast_to_inputs(ctx, node, supported_dtypes, target_dtype)
+
+
 @tf_op(["Greater", "Less"])
-class Greater:
+class GreaterLess:
     @classmethod
     def version_4(cls, ctx, node, **kwargs):
         common.BroadcastOp.version_4(ctx, node, **kwargs)

     @classmethod
     def version_7(cls, ctx, node, **kwargs):
-        logical_compare_op(ctx, node, **kwargs)
+        # T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
+        # T2 output = Less(T1 x, T1 y), T2=tensor(bool)
+        # Great/Less in opset7 only supports limited types, insert Cast if needed
+        supported_dtypes = [
+            TensorProto.FLOAT,
+            TensorProto.FLOAT16,
+            TensorProto.DOUBLE
+        ]
+        target_dtype = TensorProto.FLOAT
+        _add_cast_to_inputs(ctx, node, supported_dtypes, target_dtype)


 @tf_op("GreaterEqual", onnx_op="Less")
 @tf_op("LessEqual", onnx_op="Greater")
 class GreaterLessEqual:
     @classmethod
     def version_7(cls, ctx, node, **kwargs):
-        logical_compare_op(ctx, node, **kwargs)
+        GreaterLess.version_7(ctx, node, **kwargs)
         output_name = node.output[0]
         new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
         ctx.copy_shape(output_name, new_node.output[0])
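The handler change factors the old logical_compare_op body into _add_cast_to_inputs and reuses it for Equal, which at opset 7 only accepts bool, int32, and int64 inputs, so other dtypes are cast to int32 first. A standalone sketch of that cast-then-compare pattern using plain onnx.helper nodes is shown below; build_equal_nodes and the "_cast"/"equal_out" output names are illustrative assumptions, not tf2onnx internals:

from onnx import TensorProto, helper

# Dtypes the opset-7 ONNX Equal accepts, mirroring the handler above.
SUPPORTED_EQUAL_7 = {TensorProto.BOOL, TensorProto.INT32, TensorProto.INT64}
TARGET_DTYPE = TensorProto.INT32

def build_equal_nodes(x_name, y_name, x_dtype, y_dtype):
    """Return ONNX nodes computing Equal(x, y), inserting Casts when a dtype is unsupported."""
    nodes = []
    inputs = [x_name, y_name]
    if x_dtype not in SUPPORTED_EQUAL_7 or y_dtype not in SUPPORTED_EQUAL_7:
        # Mirror _add_cast_to_inputs: once any input dtype is unsupported, cast every input.
        for i, name in enumerate(inputs):
            cast_out = name + "_cast"
            nodes.append(helper.make_node("Cast", [name], [cast_out], to=TARGET_DTYPE))
            inputs[i] = cast_out
    nodes.append(helper.make_node("Equal", inputs, ["equal_out"]))
    return nodes

# Example: float32 inputs fall outside the opset-7 Equal types, so Casts are inserted.
print([n.op_type for n in build_equal_nodes("x", "y", TensorProto.FLOAT, TensorProto.FLOAT)])
# ['Cast', 'Cast', 'Equal']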
