Skip to content

Commit 143e2dd

Browse files
authored
Merge pull request #389 from onnx/gs/fix-ge
fix tf.greater_equal
2 parents be6b986 + 95a98fc commit 143e2dd

File tree

2 files changed

+29
-62
lines changed

2 files changed

+29
-62
lines changed

tests/test_backend.py

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -620,23 +620,27 @@ def test_logicaland(self):
620620

621621
@check_onnxruntime_incompatibility("Greater")
622622
def test_greater(self):
623-
x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
624-
x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
625-
x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
626-
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
627-
mi = tf.greater(x1, x2)
628-
_ = tf.identity(mi, name=_TFOUTPUT)
629-
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
623+
for op in [tf.greater, tf.greater_equal]:
624+
tf.reset_default_graph()
625+
x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
626+
x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
627+
x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
628+
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
629+
mi = op(x1, x2)
630+
_ = tf.identity(mi, name=_TFOUTPUT)
631+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
630632

631633
@check_onnxruntime_incompatibility("Greater")
632634
def test_greater_unsupport_type(self):
633-
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
634-
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
635-
x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
636-
x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
637-
mi = tf.greater(x1, x2)
638-
_ = tf.identity(mi, name=_TFOUTPUT)
639-
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
635+
for op in [tf.greater, tf.greater_equal]:
636+
tf.reset_default_graph()
637+
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
638+
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
639+
x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
640+
x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
641+
mi = op(x1, x2)
642+
_ = tf.identity(mi, name=_TFOUTPUT)
643+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
640644

641645
@check_onnxruntime_incompatibility("Less")
642646
def test_less(self):

tf2onnx/tfonnx.py

Lines changed: 11 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
from tf2onnx.rewriter.rnn import rewrite_single_direction_gru
3434
from tf2onnx.rewriter.rnn import rewrite_single_direction_grublock
3535
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm
36-
from tf2onnx.shape_inference import infer_shape_for_graph, set_shape_from_inputs_broadcast
36+
from tf2onnx.shape_inference import infer_shape_for_graph
3737
from tf2onnx.utils import port_name
3838

3939
logging.basicConfig(level=logging.INFO)
@@ -1633,6 +1633,13 @@ def logical_compare_op(ctx, node, name, args):
16331633
ctx.copy_shape(inp, inp_cast.output[0])
16341634
ctx.set_dtype(inp_cast.output[0], target_dtype)
16351635

1636+
def logical_compareeq_op(ctx, node, name, args):
1637+
logical_compare_op(ctx, node, name, [])
1638+
output_name = node.output[0]
1639+
new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(name))
1640+
ctx.copy_shape(output_name, new_node.output[0])
1641+
ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name))
1642+
16361643

16371644
def where_op(ctx, node, name, args):
16381645
# T_y output = Where(T_x condition), return indices of elements whose value are True
@@ -1777,6 +1784,8 @@ def where_op(ctx, node, name, args):
17771784
"FusedBatchNormV2": (fused_batchnorm_op7, []),
17781785
"Greater": (logical_compare_op, []),
17791786
"Less": (logical_compare_op, []),
1787+
"GreaterEqual": (logical_compareeq_op, ["Less"]),
1788+
"LessEqual": (logical_compareeq_op, ["Greater"]),
17801789
"LogicalAnd": (broadcast_op7, ["And"]),
17811790
"LogicalOr": (broadcast_op7, ["Or"]),
17821791
"MatrixBandPart": (matrixbandpart_op, []),
@@ -1812,9 +1821,7 @@ def where_op(ctx, node, name, args):
18121821
"Cosh": (direct_op, []),
18131822
"Erf": (direct_op, []),
18141823
"Fill": (fill_op, []),
1815-
"Greater": (logical_compare_op, []),
18161824
"IsNan": (direct_op, ["IsNaN"]),
1817-
"Less": (logical_compare_op, []),
18181825
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18191826
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
18201827
"Sign": (sign_op9, []),
@@ -2105,50 +2112,6 @@ def rewrite_constant_fold(g, ops):
21052112
return ops
21062113

21072114

2108-
def rewrite_logical_compare_with_equal(g, ops):
2109-
patterns = {"GreaterEqual": "Greater",
2110-
"LessEqual": "Less"}
2111-
for p in patterns:
2112-
pattern = OpTypePattern(p, name='compare_with_equal')
2113-
compare_name = patterns[p]
2114-
matcher = GraphMatcher(pattern)
2115-
match_results = list(matcher.match_ops(ops))
2116-
for match in match_results:
2117-
nodes_to_append = []
2118-
compare_e_op = match.get_op('compare_with_equal')
2119-
data_type = g.get_dtype(compare_e_op.input[0])
2120-
compare_input_ids = compare_e_op.input
2121-
need_cast = g.opset < 9 and data_type not in (onnx_pb.TensorProto.FLOAT16,
2122-
onnx_pb.TensorProto.FLOAT,
2123-
onnx_pb.TensorProto.DOUBLE)
2124-
if need_cast:
2125-
compare_input_ids = []
2126-
for input_id in compare_e_op.input:
2127-
cast_node = g.make_node("Cast", [input_id], op_name_scope=compare_e_op.name,
2128-
attr={"to": onnx_pb.TensorProto.FLOAT}, shapes=[g.get_shape(input_id)],
2129-
dtypes=[onnx_pb.TensorProto.FLOAT])
2130-
compare_input_ids.append(cast_node.output[0])
2131-
nodes_to_append.append(cast_node)
2132-
2133-
g_node = g.make_node(compare_name, compare_input_ids, op_name_scope=compare_e_op.name,
2134-
dtypes=[onnx_pb.TensorProto.BOOL])
2135-
set_shape_from_inputs_broadcast(g, compare_input_ids, g_node.output[0])
2136-
new_shape = g.get_shape(g_node.output[0])
2137-
nodes_to_append.append(g_node)
2138-
2139-
e_node = g.make_node("Equal", compare_e_op.input, op_name_scope=compare_e_op.name,
2140-
shapes=[new_shape], dtypes=[onnx_pb.TensorProto.BOOL])
2141-
nodes_to_append.append(e_node)
2142-
2143-
compare_e_op.type = "LogicalOr"
2144-
compare_e_op.input[0] = g_node.output[0]
2145-
compare_e_op.input[1] = e_node.output[0]
2146-
g.set_dtype(compare_e_op.output[0], onnx_pb.TensorProto.BOOL)
2147-
g.set_shape(compare_e_op.output[0], new_shape)
2148-
ops.extend(nodes_to_append)
2149-
return ops
2150-
2151-
21522115
def rewrite_incomplete_type_support(g, ops, impacted_ops):
21532116
"""
21542117
for ops that have inclomplete type support, insert casts.
@@ -2459,7 +2422,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24592422
rewrite_leakyrelu, rewrite_conv2d_with_pad,
24602423
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
24612424
rewrite_single_direction_gru, rewrite_single_direction_grublock,
2462-
rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,
2425+
rewrite_bi_direction_gru,
24632426
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond
24642427
]
24652428

0 commit comments

Comments (0)