Skip to content

Commit 78a09b5

Browse files
authored
update greater and less for opset9 (#295)
1 parent 5081424 commit 78a09b5

File tree

2 files changed

+50
-46
lines changed

2 files changed

+50
-46
lines changed

tests/test_backend.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def onnxruntime_check(op):
117117
"Div": 7, # Div-1, Div-6
118118
"Elu": 6, # Elu-1
119119
"Exp": 6, # Exp-1
120-
"Greater": 7, # Greater-1
120+
"Greater": 7, # Greater-7
121+
"Less": 7, # Less-7
121122
"Log": 6, # Log-1
122123
"Max": 6, # Max-1
123124
"Min": 6, # Min-1
@@ -659,6 +660,26 @@ def test_greater_unsupport_type(self):
659660
_ = tf.identity(mi, name=_TFOUTPUT)
660661
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
661662

663+
@unittest.skipIf(*onnxruntime_check("Less"))
664+
def test_less(self):
665+
x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
666+
x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
667+
x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
668+
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
669+
mi = tf.less(x1, x2)
670+
_ = tf.identity(mi, name=_TFOUTPUT)
671+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
672+
673+
@unittest.skipIf(*onnxruntime_check("Less"))
674+
def test_less_unsupport_type(self):
675+
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
676+
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
677+
x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
678+
x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
679+
mi = tf.less(x1, x2)
680+
_ = tf.identity(mi, name=_TFOUTPUT)
681+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
682+
662683
def test_sequeeze_no_axis_specified(self):
663684
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
664685
x = tf.placeholder(tf.float32, [2, 2, 1], name=_TFINPUT)

tf2onnx/tfonnx.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -340,29 +340,6 @@ def reshape_op5(ctx, node, name, args):
340340
return [input_cast] + nodes
341341

342342

343-
def less_op7(ctx, node, name, args):
344-
"""Elementwise Ops with Less-7 flag."""
345-
nodes = [node]
346-
input1_dtype = ctx.get_dtype(node.input[0])
347-
input2_dtype = ctx.get_dtype(node.input[1])
348-
target_dtype = onnx_pb.TensorProto.FLOAT
349-
need_case_1 = input1_dtype != target_dtype
350-
if need_case_1:
351-
input1_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
352-
input1_cast.set_attr("to", target_dtype)
353-
ctx.copy_shape(node.output[0], input1_cast.output[0])
354-
ctx.set_shape(input1_cast.output[0], target_dtype)
355-
nodes.insert(0, input1_cast)
356-
357-
input2_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
358-
input2_cast.set_attr("to", target_dtype)
359-
ctx.copy_shape(node.output[0], input2_cast.output[0])
360-
ctx.set_shape(input2_cast.output[0], target_dtype)
361-
nodes.insert(0, input2_cast)
362-
363-
return nodes
364-
365-
366343
NCHW_TO_NHWC = [0, 2, 3, 1]
367344
NHWC_TO_NCHW = [0, 3, 1, 2]
368345
HWCN_TO_NCHW = [3, 2, 0, 1]
@@ -982,23 +959,6 @@ def expanddims_op(ctx, node, name, args):
982959
raise ValueError("non-const dim is not supported")
983960

984961

985-
def greater_op7(ctx, node, name, args):
986-
nodes = []
987-
supported_types = [
988-
onnx_pb.TensorProto.FLOAT,
989-
onnx_pb.TensorProto.FLOAT16,
990-
onnx_pb.TensorProto.DOUBLE
991-
]
992-
for inp in node.input:
993-
if ctx.get_dtype(inp) not in supported_types:
994-
inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=onnx_pb.TensorProto.FLOAT)
995-
ctx.copy_shape(inp, inp_cast.output[0])
996-
ctx.set_dtype(inp_cast.output[0], onnx_pb.TensorProto.FLOAT)
997-
nodes.append(inp_cast)
998-
nodes.append(broadcast_op7(ctx, node, name, args))
999-
return nodes
1000-
1001-
1002962
def expanddims_op7(ctx, node, name, args):
1003963
# T output = ExpandDims(T input, Tdim dim, @type Tdim), dim is 0-D scalar.
1004964
# T reshaped = Reshape-5(T data, int64 shape)
@@ -1717,6 +1677,27 @@ def softmax_op(ctx, node, name, args):
17171677
return node
17181678

17191679

1680+
def logical_compare_op(ctx, node, name, args):
1681+
# T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
1682+
# T2 output = Less(T1 x, T1 y), T2=tensor(bool)
1683+
nodes = [node]
1684+
    # Greater/Less in opset7 only supports limited types, insert Cast if needed
1685+
if ctx.opset < 9:
1686+
supported_dtypes = [
1687+
onnx_pb.TensorProto.FLOAT,
1688+
onnx_pb.TensorProto.FLOAT16,
1689+
onnx_pb.TensorProto.DOUBLE
1690+
]
1691+
target_dtype = onnx_pb.TensorProto.FLOAT
1692+
for inp in node.input:
1693+
if ctx.get_dtype(inp) not in supported_dtypes:
1694+
inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
1695+
ctx.copy_shape(inp, inp_cast.output[0])
1696+
ctx.set_dtype(inp_cast.output[0], target_dtype)
1697+
nodes.append(inp_cast)
1698+
return nodes
1699+
1700+
17201701
# map tensorflow ops to onnx ops. The format below is
17211702
# "TFOP": func_to_map, ["OnnxOp", ...]
17221703
#
@@ -1845,8 +1826,8 @@ def softmax_op(ctx, node, name, args):
18451826
"FloorMod": (floormod_op, []),
18461827
"FusedBatchNorm": (fused_batchnorm_op7, []),
18471828
"FusedBatchNormV2": (fused_batchnorm_op7, []),
1848-
"Greater": (greater_op7, []),
1849-
"Less": (less_op7, []),
1829+
"Greater": (logical_compare_op, []),
1830+
"Less": (logical_compare_op, []),
18501831
"LogicalAnd": (broadcast_op7, ["And"]),
18511832
"LogicalOr": (broadcast_op7, ["Or"]),
18521833
"MatrixBandPart": (matrixbandpart_op, []),
@@ -1884,6 +1865,8 @@ def softmax_op(ctx, node, name, args):
18841865
"Asinh": (direct_op, []),
18851866
"Acosh": (direct_op, []),
18861867
"Atanh": (direct_op, []),
1868+
"Greater": (logical_compare_op, []),
1869+
"Less": (logical_compare_op, []),
18871870
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18881871
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
18891872
}
@@ -2178,9 +2161,9 @@ def rewrite_logical_compare_with_equal(g, ops):
21782161
compare_e_op = match.get_op('compare_with_equal')
21792162
data_type = g.get_dtype(compare_e_op.input[0])
21802163
compare_input_ids = compare_e_op.input
2181-
need_cast = data_type not in (onnx_pb.TensorProto.FLOAT16,
2182-
onnx_pb.TensorProto.FLOAT,
2183-
onnx_pb.TensorProto.DOUBLE)
2164+
need_cast = g.opset < 9 and data_type not in (onnx_pb.TensorProto.FLOAT16,
2165+
onnx_pb.TensorProto.FLOAT,
2166+
onnx_pb.TensorProto.DOUBLE)
21842167
if need_cast:
21852168
compare_input_ids = []
21862169
for input_id in compare_e_op.input:

0 commit comments

Comments
 (0)