Skip to content

Commit 2808529

Browse files
add unittest
1 parent 472d47d commit 2808529

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,22 @@ def test_leaky_relu_int(self):
925925
self._run_test_case([_OUTPUT], {_INPUT: x_val})
926926
tf.reset_default_graph()
927927

928+
@skip_caffe2_backend("fails on caffe2 with dim issue")
929+
@check_onnxruntime_incompatibility("Mul")
930+
def test_leaky_relu_with_dependency(self):
931+
x_val = 1000 * np.random.random_sample([1000, 100]).astype(np.float32)
932+
x = tf.placeholder(x_val.dtype, [None] * x_val.ndim, name=_TFINPUT)
933+
# simulate leaky_relu
934+
alpha = tf.constant(0.5)
935+
y = alpha * x
936+
x_ = tf.maximum(y, x)
937+
dependency = y - 1
938+
939+
_ = tf.identity(x_, name=_TFOUTPUT)
940+
_ = tf.identity(dependency, name=_TFOUTPUT1)
941+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
942+
tf.reset_default_graph()
943+
928944
@skip_caffe2_backend("fails on caffe2 with dim issue")
929945
@check_onnxruntime_incompatibility("Mul")
930946
def test_leaky_relu_float(self):

tests/test_loops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11+
import unittest
12+
1113
import numpy as np
1214
import tensorflow as tf
13-
import unittest
1415

1516
from backend_test_base import Tf2OnnxBackendTestBase
1617
from common import unittest_main, check_tf_min_version

tf2onnx/graph.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,16 +1183,12 @@ def delete_unused_nodes(self, outputs_name):
11831183

11841184
def delete_nodes_without_dependency(self, to_delete):
11851185
"""Delete nodes in `to_delete` without third-party dependency."""
1186+
delete_set = set(to_delete)
11861187
for n in to_delete:
1187-
can_delete = True
1188+
out_consumers = set()
11881189
for out in n.output:
1189-
if not can_delete:
1190-
break
1191-
for consumer in self.find_output_consumers(out):
1192-
if consumer not in to_delete:
1193-
can_delete = False
1194-
break
1195-
if can_delete:
1190+
out_consumers |= set(self.find_output_consumers(out))
1191+
if out_consumers.issubset(delete_set):
11961192
self.remove_node(n.name)
11971193

11981194

0 commit comments

Comments
 (0)