Skip to content

Commit a585481

Browse files

Merge pull request #446 from lucienwang1009/dropout_fix
fix dropout bug

Authored — 2 parents 5243934 + 0b8ce73, commit a585481

File tree

6 files changed

+37
-6
lines changed

6 files changed

+37
-6
lines changed

tests/backend_test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
130130
graph_def = tf_optimize(input_names_with_port, output_names_with_port,
131131
sess.graph_def, constant_fold)
132132

133-
if self.config.is_debug_mode and constant_fold:
133+
if self.config.is_debug_mode:
134134
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
135135
utils.save_protobuf(model_path, graph_def)
136136
self.log.debug("created file %s", model_path)

tests/run_pretrained_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
243243
inputs[k] = v
244244

245245
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
246+
if debug:
247+
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
246248
shape_override = {}
247249
g = tf.import_graph_def(graph_def, name='')
248250
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:

tests/test_backend.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,23 @@ def test_dropout(self):
353353
output_names_with_port = ["output:0"]
354354
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
355355

356+
def test_nn_dropout(self):
357+
keep_prob = tf.placeholder_with_default(1., (), "keep_prob")
358+
x_val = np.ones([1, 24, 24, 3], dtype=np.float32)
359+
# Define a scope for reusing the variables
360+
x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_1")
361+
x_ = tf.identity(x)
362+
363+
fc1 = tf.nn.dropout(x_, keep_prob)
364+
365+
_ = tf.identity(fc1, name="output")
366+
feed_dict = {"input_1:0": x_val}
367+
input_names_with_port = ["input_1:0"]
368+
output_names_with_port = ["output:0"]
369+
# when constant_fold is enabled, PlaceholderWithDefault will be folded into either a const or a placeholder.
370+
# here we set it False to test PlaceholderWithDefault bug: https://github.com/onnx/tensorflow-onnx/pull/446
371+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, constant_fold=False)
372+
356373
def test_conv2d_with_input_transpose(self):
357374
x_shape = [2, 32, 32, 3]
358375
kernel_shape = [3, 3, 3, 3]

tests/test_graph.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,9 @@ def test_dropout(self):
165165
g = process_tf_graph(sess.graph, opset=self.config.opset)
166166
actual = onnx_to_graphviz(g)
167167
expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
168-
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
169-
'output1 [op_type=Identity] output2 [op_type=Identity] output [op_type=Identity] ' \
170-
'input1:0 -> Add input2:0 -> Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 ' \
171-
'output1:0 -> output2 output2:0 -> output }'
168+
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] output1 [op_type=Identity] ' \
169+
'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> Add ' \
170+
'Add:0 -> output1 output1:0 -> output2 output2:0 -> output }'
172171
self.assertEqual(expected, actual)
173172

174173
def test_add(self):

tf2onnx/graph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ def _get_unvisited_child(g, node, not_visited):
687687
all_input = list(filter(lambda a: a != '', all_input))
688688
for inp in all_input:
689689
j = self.get_node_by_output(inp)
690+
utils.make_sure(j is not None, "Cannot find node with output {}".format(inp))
690691
if self.parent_graph and j.name not in op_name_to_index:
691692
# there might be some outer-scoped inputs for an inner Graph.
692693
pass
@@ -754,6 +755,7 @@ def make_graph(self, doc, graph_name="tf2onnx"):
754755
placeholder_default_const_ops = []
755756
for op in placeholder_ops:
756757
if op.type == "PlaceholderWithDefault":
758+
utils.make_sure(op.inputs[0] is not None, "Cannot find node with output {}".format(op.input[0]))
757759
utils.make_sure(op.inputs[0].is_const(),
758760
"non-const default value for PlaceholderWithDefault is not supported.")
759761
# copy the tensor value, set its name to current node's output, add as initializer
@@ -1027,6 +1029,11 @@ def extract_sub_graph_nodes(self, outputs_name, input_checker=None, ignore_unuse
10271029
if node.is_graph_input():
10281030
if node not in res_set:
10291031
res_set.add(node)
1032+
if node.type == "PlaceholderWithDefault" and \
1033+
node.inputs[0].is_const() and \
1034+
node.inputs[0] not in res_set:
1035+
res_set.add(node.inputs[0])
1036+
10301037
return list(res_set)
10311038

10321039
def delete_unused_nodes(self, outputs_name):

tf2onnx/tfonnx.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def rewrite_dropout(g, ops):
192192
OpTypePattern('Floor', inputs=[
193193
OpTypePattern('Add', inputs=[
194194
OpTypePattern(None, name="input3"),
195-
OpTypePattern('RandomUniform'),
195+
OpTypePattern('RandomUniform|RandomUniformLike'),
196196
])
197197
]),
198198
])
@@ -216,6 +216,12 @@ def rewrite_dropout(g, ops):
216216
for n in set(match.get_nodes()):
217217
g.remove_node(n.name)
218218

219+
# remove dropout if its ratio is 1.0
220+
for node in g.get_nodes():
221+
if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
222+
g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
223+
g.remove_node(node.name)
224+
219225
return ops
220226

221227

0 commit comments

Comments (0)