Skip to content

Commit 0d3b82a

Browse files
Fixed bug with removing too many nodes in dropout_rewriter
1 parent 73ee552 commit 0d3b82a

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

tf2onnx/rewriter/dropout_rewriter.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,18 @@ def rewrite_dropout(g, ops):
5050
matcher = GraphMatcher(pattern, allow_reorder=True)
5151
match_results = list(matcher.match_ops(ops))
5252
for match in match_results:
53-
inputs2 = match.get_op('input2')
54-
inputs3 = match.get_op('input3')
55-
if inputs3.type == "Const":
56-
ratio = inputs3.get_tensor_value()
53+
input2 = match.get_op('input2')
54+
input3 = match.get_op('input3')
55+
if input3.is_const():
56+
ratio = input3.get_tensor_value()
5757
else:
5858
# If the ratio isn't constant, set it to 0
59-
logger.error("Dropout node has non-constant ratio. Using ratio=0.0")
59+
logger.warning("Dropout node has non-constant ratio. Using ratio=0.0")
6060
ratio = 0.0
61-
if inputs2.inputs[0].type == "RealDiv":
62-
data = inputs2.input[1]
61+
if input2.inputs[0].type == "RealDiv":
62+
data = input2.input[1]
6363
else:
64-
data = inputs2.input[0]
64+
data = input2.input[0]
6565
# TODO(tomwildenhain): replace dropout node with identity if ratio is 0
6666
outputs = match.get_op('outputs')
6767
op_name = utils.make_name("Dropout")
@@ -72,10 +72,19 @@ def rewrite_dropout(g, ops):
7272
outputs=[out_name],
7373
name=op_name,
7474
attr={"ratio": ratio},
75-
shapes=[g.get_shape(inputs2.input[0])],
76-
dtypes=[g.get_dtype(inputs2.input[0])]
75+
shapes=[g.get_shape(input2.input[0])],
76+
dtypes=[g.get_dtype(input2.input[0])]
7777
)
7878
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
79-
g.safe_remove_nodes(match.get_nodes())
79+
nodes_to_remove = []
80+
for node in match.get_nodes():
81+
if node.name != input3.name:
82+
nodes_to_remove.append(node)
83+
if g.safe_to_remove_nodes(nodes_to_remove):
84+
for n in nodes_to_remove:
85+
g.remove_node(n.name)
86+
else:
87+
logger.warning("Nodes replaced by dropout node cannot be removed because intermediate results are "
88+
"referenced elsewhere in graph")
8089

8190
return ops

0 commit comments

Comments
 (0)