File tree Expand file tree Collapse file tree 3 files changed +22
-9
lines changed Expand file tree Collapse file tree 3 files changed +22
-9
lines changed Original file line number Diff line number Diff line change @@ -925,6 +925,22 @@ def test_leaky_relu_int(self):
925
925
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
926
926
tf .reset_default_graph ()
927
927
928
+ @skip_caffe2_backend ("fails on caffe2 with dim issue" )
929
+ @check_onnxruntime_incompatibility ("Mul" )
930
+ def test_leaky_relu_with_dependency (self ):
931
+ x_val = 1000 * np .random .random_sample ([1000 , 100 ]).astype (np .float32 )
932
+ x = tf .placeholder (x_val .dtype , [None ] * x_val .ndim , name = _TFINPUT )
933
+ # simulate leaky_relu
934
+ alpha = tf .constant (0.5 )
935
+ y = alpha * x
936
+ x_ = tf .maximum (y , x )
937
+ dependency = y - 1
938
+
939
+ _ = tf .identity (x_ , name = _TFOUTPUT )
940
+ _ = tf .identity (dependency , name = _TFOUTPUT1 )
941
+ self ._run_test_case ([_OUTPUT , _OUTPUT1 ], {_INPUT : x_val })
942
+ tf .reset_default_graph ()
943
+
928
944
@skip_caffe2_backend ("fails on caffe2 with dim issue" )
929
945
@check_onnxruntime_incompatibility ("Mul" )
930
946
def test_leaky_relu_float (self ):
Original file line number Diff line number Diff line change 8
8
from __future__ import print_function
9
9
from __future__ import unicode_literals
10
10
11
+ import unittest
12
+
11
13
import numpy as np
12
14
import tensorflow as tf
13
- import unittest
14
15
15
16
from backend_test_base import Tf2OnnxBackendTestBase
16
17
from common import unittest_main , check_tf_min_version
Original file line number Diff line number Diff line change @@ -1183,16 +1183,12 @@ def delete_unused_nodes(self, outputs_name):
1183
1183
1184
1184
def delete_nodes_without_dependency (self , to_delete ):
1185
1185
"""Delete nodes in `to_delete` without third-party dependency."""
1186
+ delete_set = set (to_delete )
1186
1187
for n in to_delete :
1187
- can_delete = True
1188
+ out_consumers = set ()
1188
1189
for out in n .output :
1189
- if not can_delete :
1190
- break
1191
- for consumer in self .find_output_consumers (out ):
1192
- if consumer not in to_delete :
1193
- can_delete = False
1194
- break
1195
- if can_delete :
1190
+ out_consumers |= set (self .find_output_consumers (out ))
1191
+ if out_consumers .issubset (delete_set ):
1196
1192
self .remove_node (n .name )
1197
1193
1198
1194
You can’t perform that action at this time.
0 commit comments