Skip to content

Commit fd569ed

Browse files
committed
fix tf.greater_equal
1 parent be6b986 commit fd569ed

File tree

2 files changed

+26
-61
lines changed

2 files changed

+26
-61
lines changed

tests/test_backend.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -620,23 +620,27 @@ def test_logicaland(self):
620620

621621
@check_onnxruntime_incompatibility("Greater")
622622
def test_greater(self):
623-
x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
624-
x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
625-
x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
626-
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
627-
mi = tf.greater(x1, x2)
628-
_ = tf.identity(mi, name=_TFOUTPUT)
629-
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
623+
for op in [tf.greater, tf.greater_equal]:
624+
tf.reset_default_graph()
625+
x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
626+
x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
627+
x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
628+
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
629+
mi = op(x1, x2)
630+
_ = tf.identity(mi, name=_TFOUTPUT)
631+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
630632

631633
@check_onnxruntime_incompatibility("Greater")
632634
def test_greater_unsupport_type(self):
633-
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
634-
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
635-
x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
636-
x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
637-
mi = tf.greater(x1, x2)
638-
_ = tf.identity(mi, name=_TFOUTPUT)
639-
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
635+
for op in [tf.greater, tf.greater_equal]:
636+
tf.reset_default_graph()
637+
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
638+
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
639+
x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
640+
x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
641+
mi = op(x1, x2)
642+
_ = tf.identity(mi, name=_TFOUTPUT)
643+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
640644

641645
@check_onnxruntime_incompatibility("Less")
642646
def test_less(self):

tf2onnx/tfonnx.py

Lines changed: 8 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,6 +1633,11 @@ def logical_compare_op(ctx, node, name, args):
16331633
ctx.copy_shape(inp, inp_cast.output[0])
16341634
ctx.set_dtype(inp_cast.output[0], target_dtype)
16351635

1636+
def logical_compareeq_op(ctx, node, name, args):
1637+
logical_compare_op(ctx, node, name, [])
1638+
ctx.insert_new_node_on_output("Not", node.output[0], name=utils.make_name(name),
1639+
shapes=ctx.get_shape(node.output[0]), dtypes=ctx.get_dtype(node.output[0]))
1640+
16361641

16371642
def where_op(ctx, node, name, args):
16381643
# T_y output = Where(T_x condition), return indices of elements whose value are True
@@ -1777,6 +1782,8 @@ def where_op(ctx, node, name, args):
17771782
"FusedBatchNormV2": (fused_batchnorm_op7, []),
17781783
"Greater": (logical_compare_op, []),
17791784
"Less": (logical_compare_op, []),
1785+
"GreaterEqual": (logical_compareeq_op, ["Less"]),
1786+
"LessEqual": (logical_compareeq_op, ["Greater"]),
17801787
"LogicalAnd": (broadcast_op7, ["And"]),
17811788
"LogicalOr": (broadcast_op7, ["Or"]),
17821789
"MatrixBandPart": (matrixbandpart_op, []),
@@ -1812,9 +1819,7 @@ def where_op(ctx, node, name, args):
18121819
"Cosh": (direct_op, []),
18131820
"Erf": (direct_op, []),
18141821
"Fill": (fill_op, []),
1815-
"Greater": (logical_compare_op, []),
18161822
"IsNan": (direct_op, ["IsNaN"]),
1817-
"Less": (logical_compare_op, []),
18181823
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18191824
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
18201825
"Sign": (sign_op9, []),
@@ -2105,50 +2110,6 @@ def rewrite_constant_fold(g, ops):
21052110
return ops
21062111

21072112

2108-
def rewrite_logical_compare_with_equal(g, ops):
2109-
patterns = {"GreaterEqual": "Greater",
2110-
"LessEqual": "Less"}
2111-
for p in patterns:
2112-
pattern = OpTypePattern(p, name='compare_with_equal')
2113-
compare_name = patterns[p]
2114-
matcher = GraphMatcher(pattern)
2115-
match_results = list(matcher.match_ops(ops))
2116-
for match in match_results:
2117-
nodes_to_append = []
2118-
compare_e_op = match.get_op('compare_with_equal')
2119-
data_type = g.get_dtype(compare_e_op.input[0])
2120-
compare_input_ids = compare_e_op.input
2121-
need_cast = g.opset < 9 and data_type not in (onnx_pb.TensorProto.FLOAT16,
2122-
onnx_pb.TensorProto.FLOAT,
2123-
onnx_pb.TensorProto.DOUBLE)
2124-
if need_cast:
2125-
compare_input_ids = []
2126-
for input_id in compare_e_op.input:
2127-
cast_node = g.make_node("Cast", [input_id], op_name_scope=compare_e_op.name,
2128-
attr={"to": onnx_pb.TensorProto.FLOAT}, shapes=[g.get_shape(input_id)],
2129-
dtypes=[onnx_pb.TensorProto.FLOAT])
2130-
compare_input_ids.append(cast_node.output[0])
2131-
nodes_to_append.append(cast_node)
2132-
2133-
g_node = g.make_node(compare_name, compare_input_ids, op_name_scope=compare_e_op.name,
2134-
dtypes=[onnx_pb.TensorProto.BOOL])
2135-
set_shape_from_inputs_broadcast(g, compare_input_ids, g_node.output[0])
2136-
new_shape = g.get_shape(g_node.output[0])
2137-
nodes_to_append.append(g_node)
2138-
2139-
e_node = g.make_node("Equal", compare_e_op.input, op_name_scope=compare_e_op.name,
2140-
shapes=[new_shape], dtypes=[onnx_pb.TensorProto.BOOL])
2141-
nodes_to_append.append(e_node)
2142-
2143-
compare_e_op.type = "LogicalOr"
2144-
compare_e_op.input[0] = g_node.output[0]
2145-
compare_e_op.input[1] = e_node.output[0]
2146-
g.set_dtype(compare_e_op.output[0], onnx_pb.TensorProto.BOOL)
2147-
g.set_shape(compare_e_op.output[0], new_shape)
2148-
ops.extend(nodes_to_append)
2149-
return ops
2150-
2151-
21522113
def rewrite_incomplete_type_support(g, ops, impacted_ops):
21532114
"""
21542115
for ops that have inclomplete type support, insert casts.
@@ -2459,7 +2420,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24592420
rewrite_leakyrelu, rewrite_conv2d_with_pad,
24602421
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
24612422
rewrite_single_direction_gru, rewrite_single_direction_grublock,
2462-
rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,
2423+
rewrite_bi_direction_gru,
24632424
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond
24642425
]
24652426

0 commit comments

Comments
 (0)