22
22
from tf2onnx .handler import tf_op
23
23
24
24
from backend_test_base import Tf2OnnxBackendTestBase
25
- from common import unittest_main , check_tf_min_version , check_tf_max_version
25
+ from common import unittest_main
26
26
27
27
28
28
# pylint: disable=missing-docstring,unused-argument,unused-variable
@@ -139,7 +139,6 @@ def test_randomnormal(self):
139
139
'RandomNormal__2:0 -> output }'
140
140
self .assertEqual (expected , actual )
141
141
142
- @check_tf_max_version ("1.12" )
143
142
def test_dropout (self ):
144
143
with tf .Session () as sess :
145
144
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
@@ -150,35 +149,15 @@ def test_dropout(self):
150
149
x_ = tf .identity (x_ , name = "output1" )
151
150
x_ = tf .identity (x_ , name = "output2" )
152
151
_ = tf .identity (x_ , name = "output" )
153
- g = process_tf_graph (sess .graph , opset = self .config .opset )
152
+ # feed output_names in order to remove unused nodes.
153
+ g = process_tf_graph (sess .graph , opset = self .config .opset , output_names = ["output:0" ])
154
+ utils .save_protobuf ("./test.onnx" , g .make_model ("test" ))
154
155
actual = onnx_to_graphviz (g )
155
156
expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
156
157
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] output1 [op_type=Identity] ' \
157
- 'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> Add ' \
158
- 'Add:0 -> output1 output1:0 -> output2 output2:0 -> output }'
159
- self .assertEqual (expected , actual )
160
-
161
- @check_tf_min_version ("1.13" )
162
- def test_dropout_2 (self ):
163
- with tf .Session () as sess :
164
- x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
165
- x2 = tf .placeholder (tf .float32 , [1 , 3 ], name = "input2" )
166
- prop = tf .placeholder (tf .float32 , (), name = "prob" )
167
- x_ = tf .add (x1 , x2 )
168
- x_ = tf .nn .dropout (x_ , prop )
169
- x_ = tf .identity (x_ , name = "output1" )
170
- x_ = tf .identity (x_ , name = "output2" )
171
- _ = tf .identity (x_ , name = "output" )
172
- g = process_tf_graph (sess .graph , opset = self .config .opset )
173
- actual = onnx_to_graphviz (g )
174
- expected = 'digraph { "sub/x" [op_type=Const] prob [op_type=Placeholder shape="[]"] ' \
175
- 'sub [op_type=Sub] input2 [op_type=Placeholder shape="[1, 3]"] ' \
176
- 'input1 [op_type=Placeholder shape="[2, 3]"] "dropout/sub/x" [op_type=Const] ' \
177
- '"dropout/sub" [op_type=Sub] Add [op_type=Add] output1 [op_type=Identity] ' \
178
- 'output2 [op_type=Identity] output [op_type=Identity] "sub/x":0 -> sub ' \
179
- 'prob:0 -> sub "dropout/sub/x":0 -> "dropout/sub" sub:0 -> "dropout/sub" ' \
180
- 'input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
181
- 'output2:0 -> output }'
158
+ 'output2 [op_type=Identity] output [op_type=Identity] output_graph_outputs_Identity__3 ' \
159
+ '[op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
160
+ 'output2:0 -> output output_raw_output___2:0 -> output_graph_outputs_Identity__3 }'
182
161
self .assertEqual (expected , actual )
183
162
184
163
def test_add (self ):
0 commit comments