Skip to content

Commit aa19744

Browse files
authored
Merge pull request #625 from lucienwang1009/fix_rewriter
fix rewriter bugs
2 parents 29ab979 + 0f9933f commit aa19744

File tree

7 files changed

+63
-15
lines changed

7 files changed

+63
-15
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,22 @@ def test_leaky_relu_int(self):
925925
self._run_test_case([_OUTPUT], {_INPUT: x_val})
926926
tf.reset_default_graph()
927927

928+
@skip_caffe2_backend("fails on caffe2 with dim issue")
929+
@check_onnxruntime_incompatibility("Mul")
930+
def test_leaky_relu_with_dependency(self):
931+
x_val = 1000 * np.random.random_sample([1000, 100]).astype(np.float32)
932+
x = tf.placeholder(x_val.dtype, [None] * x_val.ndim, name=_TFINPUT)
933+
# simulate leaky_relu
934+
alpha = tf.constant(0.5)
935+
y = alpha * x
936+
x_ = tf.maximum(y, x)
937+
dependency = y - 1
938+
939+
_ = tf.identity(x_, name=_TFOUTPUT)
940+
_ = tf.identity(dependency, name=_TFOUTPUT1)
941+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
942+
tf.reset_default_graph()
943+
928944
@skip_caffe2_backend("fails on caffe2 with dim issue")
929945
@check_onnxruntime_incompatibility("Mul")
930946
def test_leaky_relu_float(self):

tests/test_loops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorflow as tf
1313

1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main, check_tf_min_version
15+
from common import unittest_main, check_tf_min_version, check_onnxruntime_min_version
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -160,6 +160,10 @@ def b(i, out_ta):
160160
output_names_with_port = ["i:0", "output_ta:0"]
161161
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
162162

163+
@check_onnxruntime_min_version(
164+
"0.5.0",
165+
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"
166+
)
163167
def test_while_loop_with_cond_init_false(self):
164168
i = tf.placeholder(tf.int32, (), name="input_1")
165169
inputs = tf.placeholder(tf.float32, (10,), name="input_2")

tf2onnx/graph.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,18 @@ def reset_nodes(self, ops):
575575
self._dtypes = remained_dtypes
576576
self._output_shapes = remained_shapes
577577

578+
def check_integrity(self):
579+
"""
580+
Check graph integrity. Every node's input needs to associate with a node.
581+
Return broken outputs.
582+
"""
583+
broken_outputs = set()
584+
for node in self.get_nodes():
585+
for inp in node.input:
586+
if self.get_node_by_output(inp) is None:
587+
broken_outputs.add(inp)
588+
return list(broken_outputs)
589+
578590
def update_node_shape_dtype(self, node, override=False):
579591
"""Try the best to infer shapes and dtypes for outputs of the node,
580592
by default, we respect TF shapes and dtypes.
@@ -1169,6 +1181,16 @@ def delete_unused_nodes(self, outputs_name):
11691181
body_graph.delete_unused_nodes(body_graph.outputs)
11701182
self.reset_nodes(related_nodes)
11711183

1184+
def safe_remove_nodes(self, to_delete):
1185+
"""Delete nodes in `to_delete` without third-party node consuming it."""
1186+
delete_set = set(to_delete)
1187+
for n in delete_set:
1188+
out_consumers = set()
1189+
for out in n.output:
1190+
out_consumers |= set(self.find_output_consumers(out))
1191+
if out_consumers.issubset(delete_set):
1192+
self.remove_node(n.name)
1193+
11721194

11731195
class GraphUtil(object):
11741196
"""Utilities for Graph manipulation."""

tf2onnx/rewriter/leakyrelu_rewriter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def rewrite_leakyrelu(g, ops):
4242
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
4343
ops.append(leakyrelu)
4444
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
45+
to_delete = [max_node, mul_node]
46+
g.safe_remove_nodes(to_delete)
4547

4648
return ops
4749

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def rewrite_random_uniform_fold_const(g, ops):
6363
to_delete = list(set(match.get_nodes()))
6464
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
6565
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
66-
for n in to_delete:
67-
g.remove_node(n.name)
66+
g.safe_remove_nodes(to_delete)
6867

6968
return ops
7069

tf2onnx/rewriter/thresholded_relu_rewriter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def rewrite_thresholded_relu(g, ops):
3434
greater_input_node = match.get_op('greater_input')
3535
mul_node = match.get_op("mul")
3636
mul_input_node = match.get_op('mul_input')
37+
cast_node = match.get_op('cast')
3738

3839
greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
3940
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
@@ -42,6 +43,7 @@ def rewrite_thresholded_relu(g, ops):
4243
thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},
4344
shapes=[g.get_shape(mul_node.output[0])],
4445
dtypes=[g.get_dtype(mul_node.output[0])])
45-
ops.append(thresholded_relu)
4646
g.replace_all_inputs(ops, mul_node.output[0], thresholded_relu.output[0])
47+
to_delete = [cast_node, mul_node]
48+
g.safe_remove_nodes(to_delete)
4749
return ops

tf2onnx/tfonnx.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,8 @@ def rewrite_transpose(g, ops):
141141
dims = [i for i in range(len(shape) - 1, -1, -1)]
142142
output.set_attr("perm", dims)
143143
g.remove_input(output, output.input[1])
144-
for n in set(match.get_nodes()):
145-
if n != output:
146-
g.remove_node(n.name)
144+
to_delete = [n for n in match.get_nodes() if n != output]
145+
g.safe_remove_nodes(to_delete)
147146
return ops
148147

149148

@@ -175,8 +174,7 @@ def rewrite_random_normal(g, ops):
175174
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype})
176175

177176
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
178-
for n in set(match.get_nodes()):
179-
g.remove_node(n.name)
177+
g.safe_remove_nodes(match.get_nodes())
180178
return ops
181179

182180

@@ -208,8 +206,7 @@ def rewrite_dropout(g, ops):
208206
dtypes=[g.get_dtype(inputs2.input[0])]
209207
)
210208
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
211-
for n in set(match.get_nodes()):
212-
g.remove_node(n.name)
209+
g.safe_remove_nodes(match.get_nodes())
213210

214211
# remove dropout if its ratio is 1.0
215212
for node in g.get_nodes():
@@ -294,10 +291,8 @@ def rewrite_flatten(g, ops):
294291

295292
g.set_shape(out_name, input_shape[:-2] + [new_dim])
296293
g.replace_all_inputs(ops, reshape_node.output[0], out_name)
297-
298-
for n in set(match.get_nodes()):
299-
if n != input_node:
300-
g.remove_node(n.name)
294+
to_delete = [n for n in match.get_nodes() if n != input_node]
295+
g.safe_remove_nodes(to_delete)
301296

302297
return ops
303298

@@ -654,6 +649,14 @@ def run_rewriters(g, funcs, continue_on_error):
654649
else:
655650
raise ex
656651

652+
if utils.is_debug_mode():
653+
broken_outputs = g.check_integrity()
654+
if broken_outputs:
655+
logging.error(
656+
"After rewriter %s, graph breaks at outputs %s",
657+
func.__name__, broken_outputs
658+
)
659+
657660
if g.contained_graphs:
658661
for dict_val in g.contained_graphs.values():
659662
for attr_name, b_g in dict_val.items():

0 commit comments

Comments
 (0)