Skip to content

Commit c7ef119

Browse files
remove dropout with ratio 1.0
1 parent 70d94c2 commit c7ef119

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

tests/test_graph.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,9 @@ def test_dropout(self):
165165
g = process_tf_graph(sess.graph, opset=self.config.opset)
166166
actual = onnx_to_graphviz(g)
167167
expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
168-
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
169-
'output1 [op_type=Identity] output2 [op_type=Identity] output [op_type=Identity] ' \
170-
'input1:0 -> Add input2:0 -> Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 ' \
171-
'output1:0 -> output2 output2:0 -> output }'
168+
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] output1 [op_type=Identity] ' \
169+
'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> Add ' \
170+
'Add:0 -> output1 output1:0 -> output2 output2:0 -> output }'
172171
self.assertEqual(expected, actual)
173172

174173
def test_add(self):

tf2onnx/tfonnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,12 @@ def rewrite_dropout(g, ops):
216216
for n in set(match.get_nodes()):
217217
g.remove_node(n.name)
218218

219+
# remove dropout if its ratio is 1.0
220+
for node in g.get_nodes():
221+
if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
222+
g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
223+
g.remove_node(node.name)
224+
219225
return ops
220226

221227

0 commit comments

Comments
 (0)