File tree Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -165,10 +165,9 @@ def test_dropout(self):
165165 g = process_tf_graph (sess .graph , opset = self .config .opset )
166166 actual = onnx_to_graphviz (g )
167167 expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
168- 'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
169- 'output1 [op_type=Identity] output2 [op_type=Identity] output [op_type=Identity] ' \
170- 'input1:0 -> Add input2:0 -> Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 ' \
171- 'output1:0 -> output2 output2:0 -> output }'
168+ 'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] output1 [op_type=Identity] ' \
169+ 'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> Add ' \
170+ 'Add:0 -> output1 output1:0 -> output2 output2:0 -> output }'
172171 self .assertEqual (expected , actual )
173172
174173 def test_add (self ):
Original file line number Diff line number Diff line change @@ -216,6 +216,12 @@ def rewrite_dropout(g, ops):
216216 for n in set (match .get_nodes ()):
217217 g .remove_node (n .name )
218218
219+ # remove dropout if its ratio is 1.0
220+ for node in g .get_nodes ():
221+ if node .type == "Dropout" and node .get_attr ("ratio" ).f == 1.0 :
222+ g .replace_all_inputs (g .get_nodes (), node .output [0 ], node .input [0 ])
223+ g .remove_node (node .name )
224+
219225 return ops
220226
221227
You can’t perform that action at this time.
0 commit comments