
Commit 53c924a

align dropout with tf 1.14

Parent: 0f69340

File tree

4 files changed: +53 -50 lines

    tests/common.py
    tests/test_backend.py
    tests/test_graph.py
    tf2onnx/tfonnx.py
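What "align dropout with tf 1.14" refers to: starting around TF 1.13, tf.nn.dropout is lowered to a different GraphDef subgraph, so the existing dropout rewriter no longer matched it. Below is a minimal sketch of the newer lowering that the added pattern in tf2onnx/tfonnx.py targets; it is an approximation of TF's behaviour, not the literal TF source.

    import tensorflow as tf

    def dropout_lowering_tf114(x, rate):
        # TF >= 1.13 scales the input and multiplies it by a boolean keep-mask:
        # Mul(Mul(x, scale), Cast(GreaterEqual(RandomUniform, rate))).
        scale = 1.0 / (1.0 - rate)
        keep_mask = tf.cast(tf.greater_equal(tf.random_uniform(tf.shape(x)), rate), x.dtype)
        return (x * scale) * keep_mask

The older RealDiv/Add/Floor/Mul lowering is still matched by the first pattern in the rewriter, so both TF generations end up converted to the same ONNX Dropout.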

tests/common.py (6 additions, 0 deletions)

@@ -297,6 +297,12 @@ def validate_const_node(node, expected_val):
 def group_nodes_by_type(graph):
     res = defaultdict(list)
     for node in graph.get_nodes():
+        attr_body_graphs = node.get_body_graphs()
+        if attr_body_graphs:
+            for _, body_graph in attr_body_graphs.items():
+                body_graph_res = group_nodes_by_type(body_graph)
+                for k, v in body_graph_res.items():
+                    res[k].extend(v)
         res[node.type].append(node)
     return res
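Recursing into body graphs matters because ops hidden in control-flow subgraphs (Loop/Scan/If bodies) should also count toward op-type checks. The check_op_count validator used in tests/test_backend.py below presumably builds on this helper; a minimal sketch, assuming that relationship (the actual helper in tests/common.py may differ in detail):

    def check_op_count(graph, op_type, expected_count):
        # group_nodes_by_type now also walks subgraph bodies, so ops inside
        # control-flow bodies are included in the count.
        return len(group_nodes_by_type(graph)[op_type]) == expected_count

    # typical use as a test-case validator:
    # graph_validator=lambda g: check_op_count(g, "RandomUniform", 0)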

tests/test_backend.py (10 additions, 3 deletions)

@@ -441,7 +441,9 @@ def test_dropout(self):
         feed_dict = {"input_1:0": x_val}
         input_names_with_port = ["input_1:0"]
         output_names_with_port = ["output:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port,
+                           graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
+                                                      check_op_count(g, "RandomUniformLike", 0)))

     def test_nn_dropout(self):
         keep_prob = tf.placeholder_with_default(1., (), "keep_prob")
@@ -458,7 +460,10 @@ def test_nn_dropout(self):
         output_names_with_port = ["output:0"]
         # when constant_fold is enabled, PlaceholderWithDefault will be folded into either a const or a placeholder.
         # here we set it False to test PlaceholderWithDefault bug: https://github.com/onnx/tensorflow-onnx/pull/446
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False)
+        # Dropout with ratio 1.0 will be optimized so that only one Identity is left
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False,
+                           graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
+                                                      check_op_count(g, "RandomUniformLike", 0)))

     @check_tf_min_version("1.13")
     def test_nn_dropout_with_rate(self):
@@ -474,7 +479,9 @@ def test_nn_dropout_with_rate(self):
         feed_dict = {"input_1:0": x_val}
         input_names_with_port = ["input_1:0"]
         output_names_with_port = ["output:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False)
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False,
+                           graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
+                                                      check_op_count(g, "RandomUniformLike", 0)))

     def test_conv2d_with_input_transpose(self):
         x_shape = [2, 32, 32, 3]
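For context on test_nn_dropout_with_rate: TF 1.13 introduced the rate argument to tf.nn.dropout (probability of dropping) in place of keep_prob, which is why the test is gated with @check_tf_min_version("1.13"). A minimal sketch of the graph such a test presumably builds; the exact body is not shown in this hunk, so the names and shapes here are illustrative:

    import numpy as np
    import tensorflow as tf

    x_val = np.random.random_sample([2, 3]).astype(np.float32)
    x = tf.placeholder(tf.float32, [2, 3], name="input_1")
    # A placeholder rate keeps the dropout subgraph in the GraphDef; after the
    # rewriter runs, the converted graph should contain no RandomUniform or
    # RandomUniformLike nodes, which is what the graph_validator checks.
    rate = tf.placeholder_with_default(0.0, (), name="rate")
    y = tf.nn.dropout(x, rate=rate)
    _ = tf.identity(y, name="output")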

tests/test_graph.py (7 additions, 28 deletions)

@@ -22,7 +22,7 @@
 from tf2onnx.handler import tf_op

 from backend_test_base import Tf2OnnxBackendTestBase
-from common import unittest_main, check_tf_min_version, check_tf_max_version
+from common import unittest_main


 # pylint: disable=missing-docstring,unused-argument,unused-variable
@@ -139,7 +139,6 @@ def test_randomnormal(self):
                    'RandomNormal__2:0 -> output }'
         self.assertEqual(expected, actual)

-    @check_tf_max_version("1.12")
     def test_dropout(self):
         with tf.Session() as sess:
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
@@ -150,35 +149,15 @@ def test_dropout(self):
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph, opset=self.config.opset)
+            # feed output_names in order to remove unused nodes.
+            g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=["output:0"])
+            utils.save_protobuf("./test.onnx", g.make_model("test"))
             actual = onnx_to_graphviz(g)
             expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
                        'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] output1 [op_type=Identity] ' \
-                       'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> Add ' \
-                       'Add:0 -> output1 output1:0 -> output2 output2:0 -> output }'
-            self.assertEqual(expected, actual)
-
-    @check_tf_min_version("1.13")
-    def test_dropout_2(self):
-        with tf.Session() as sess:
-            x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
-            x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
-            prop = tf.placeholder(tf.float32, (), name="prob")
-            x_ = tf.add(x1, x2)
-            x_ = tf.nn.dropout(x_, prop)
-            x_ = tf.identity(x_, name="output1")
-            x_ = tf.identity(x_, name="output2")
-            _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph, opset=self.config.opset)
-            actual = onnx_to_graphviz(g)
-            expected = 'digraph { "sub/x" [op_type=Const] prob [op_type=Placeholder shape="[]"] ' \
-                       'sub [op_type=Sub] input2 [op_type=Placeholder shape="[1, 3]"] ' \
-                       'input1 [op_type=Placeholder shape="[2, 3]"] "dropout/sub/x" [op_type=Const] ' \
-                       '"dropout/sub" [op_type=Sub] Add [op_type=Add] output1 [op_type=Identity] ' \
-                       'output2 [op_type=Identity] output [op_type=Identity] "sub/x":0 -> sub ' \
-                       'prob:0 -> sub "dropout/sub/x":0 -> "dropout/sub" sub:0 -> "dropout/sub" ' \
-                       'input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
-                       'output2:0 -> output }'
+                       'output2 [op_type=Identity] output [op_type=Identity] output_graph_outputs_Identity__3 ' \
+                       '[op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
+                       'output2:0 -> output output_raw_output___2:0 -> output_graph_outputs_Identity__3 }'
             self.assertEqual(expected, actual)

     def test_add(self):
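The updated expectation reflects two things visible in the new call: output_names prunes nodes that do not feed the requested output, and the graph output gets an extra Identity appended (hence output_graph_outputs_Identity__3 in the expected string). A minimal usage sketch of that conversion call, assuming the imports this test file already uses (process_tf_graph from tf2onnx.tfonnx, utils from tf2onnx):

    import tensorflow as tf
    from tf2onnx import utils
    from tf2onnx.tfonnx import process_tf_graph

    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, [2, 3], name="input1")
        _ = tf.identity(x + 1.0, name="output")
        # Only the subgraph reaching "output:0" is converted; unreachable nodes are dropped.
        g = process_tf_graph(sess.graph, output_names=["output:0"])
        utils.save_protobuf("./test.onnx", g.make_model("test"))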

tf2onnx/tfonnx.py (30 additions, 19 deletions)

@@ -179,7 +179,7 @@ def rewrite_random_normal(g, ops):


 def rewrite_dropout(g, ops):
-    pattern = \
+    patterns = [
         OpTypePattern('Mul', name='outputs', inputs=[
             OpTypePattern('RealDiv', name="input2"),
             OpTypePattern('Floor', inputs=[
@@ -188,25 +188,36 @@ def rewrite_dropout(g, ops):
                     OpTypePattern('RandomUniform|RandomUniformLike'),
                 ])
             ]),
+        ]),
+        OpTypePattern("Mul", name="outputs", inputs=[
+            OpTypePattern("Mul", name="input2"),
+            OpTypePattern("Cast", inputs=[
+                OpTypePattern("GreaterEqual", inputs=[
+                    OpTypePattern("RandomUniform|RandomUniformLike"),
+                    OpTypePattern(None, name="input3")
+                ])
+            ])
         ])
-    matcher = GraphMatcher(pattern)
-    match_results = list(matcher.match_ops(ops))
-    for match in match_results:
-        inputs2 = match.get_op('input2')
-        outputs = match.get_op('outputs')
-        op_name = utils.make_name("Dropout")
-        out_name = port_name(op_name)
-        new_node = g.make_node(
-            "Dropout",
-            [inputs2.input[0]],
-            outputs=[out_name],
-            name=op_name,
-            attr={"ratio": 1.0},
-            shapes=[g.get_shape(inputs2.input[0])],
-            dtypes=[g.get_dtype(inputs2.input[0])]
-        )
-        g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
-        g.safe_remove_nodes(match.get_nodes())
+    ]
+    for pattern in patterns:
+        matcher = GraphMatcher(pattern)
+        match_results = list(matcher.match_ops(ops))
+        for match in match_results:
+            inputs2 = match.get_op('input2')
+            outputs = match.get_op('outputs')
+            op_name = utils.make_name("Dropout")
+            out_name = port_name(op_name)
+            new_node = g.make_node(
+                "Dropout",
+                [inputs2.input[0]],
+                outputs=[out_name],
+                name=op_name,
+                attr={"ratio": 1.0},
+                shapes=[g.get_shape(inputs2.input[0])],
+                dtypes=[g.get_dtype(inputs2.input[0])]
+            )
+            g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
+            g.safe_remove_nodes(match.get_nodes())

     # remove dropout if its ratio is 1.0
     for node in g.get_nodes():
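How the two patterns relate to the TF lowerings they match, and why inputs2.input[0] is the right tensor to feed the new Dropout node (an informal summary of the code above, written as comments; the formulas are approximations of TF's dropout implementations):

    # pattern 1 (TF <= 1.12):  y = Mul(RealDiv(x, keep_prob), Floor(Add(keep_prob, RandomUniform)))
    # pattern 2 (TF >= 1.13):  y = Mul(Mul(x, 1 / (1 - rate)), Cast(GreaterEqual(RandomUniform, rate)))
    #
    # In both patterns the node named "input2" is the scaled copy of the original
    # tensor, so inputs2.input[0] recovers x itself. The rewriter wires x into an
    # ONNX Dropout marked with ratio 1.0; the pass opened by the trailing context
    # lines above ("remove dropout if its ratio is 1.0", body not shown in this hunk)
    # then removes that marker node again, leaving an identity-like pass-through.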
