Skip to content

Commit 73ee552

Browse files
Fixed dropout rewriter to properly read ratio
1 parent 011cf5e commit 73ee552

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

tf2onnx/rewriter/dropout_rewriter.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
from tf2onnx import utils
99
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
from tf2onnx import logging
11+
12+
logger = logging.getLogger(__name__)
1013

1114

1215
# pylint: disable=missing-docstring
@@ -18,7 +21,7 @@ def rewrite_dropout(g, ops):
1821
OpTypePattern('RealDiv', name="input2"),
1922
OpTypePattern('Floor', inputs=[
2023
OpTypePattern('Add', inputs=[
21-
OpTypePattern(None, name="input3"),
24+
OpTypePattern("*", name="input3"),
2225
OpTypePattern('RandomUniform|RandomUniformLike'),
2326
])
2427
]),
@@ -28,7 +31,7 @@ def rewrite_dropout(g, ops):
2831
OpTypePattern("Cast", inputs=[
2932
OpTypePattern("GreaterEqual", inputs=[
3033
OpTypePattern("RandomUniform|RandomUniformLike"),
31-
OpTypePattern(None, name="input3")
34+
OpTypePattern("*", name="input3")
3235
])
3336
])
3437
]),
@@ -37,7 +40,7 @@ def rewrite_dropout(g, ops):
3740
OpTypePattern("Cast", inputs=[
3841
OpTypePattern("GreaterEqual", inputs=[
3942
OpTypePattern("RandomUniform|RandomUniformLike"),
40-
OpTypePattern(None, name="input3")
43+
OpTypePattern("*", name="input3")
4144
])
4245
]),
4346
OpTypePattern("Mul", name="input2"),
@@ -48,10 +51,18 @@ def rewrite_dropout(g, ops):
4851
match_results = list(matcher.match_ops(ops))
4952
for match in match_results:
5053
inputs2 = match.get_op('input2')
54+
inputs3 = match.get_op('input3')
55+
if inputs3.type == "Const":
56+
ratio = inputs3.get_tensor_value()
57+
else:
58+
# If the ratio isn't constant, set it to 0
59+
logger.error("Dropout node has non-constant ratio. Using ratio=0.0")
60+
ratio = 0.0
5161
if inputs2.inputs[0].type == "RealDiv":
5262
data = inputs2.input[1]
5363
else:
5464
data = inputs2.input[0]
65+
# TODO(tomwildenhain): replace dropout node with identity if ratio is 0
5566
outputs = match.get_op('outputs')
5667
op_name = utils.make_name("Dropout")
5768
out_name = utils.port_name(op_name)
@@ -60,17 +71,11 @@ def rewrite_dropout(g, ops):
6071
[data],
6172
outputs=[out_name],
6273
name=op_name,
63-
attr={"ratio": 1.0},
74+
attr={"ratio": ratio},
6475
shapes=[g.get_shape(inputs2.input[0])],
6576
dtypes=[g.get_dtype(inputs2.input[0])]
6677
)
6778
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
6879
g.safe_remove_nodes(match.get_nodes())
6980

70-
# remove dropout if its ratio is 1.0
71-
for node in g.get_nodes():
72-
if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
73-
g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
74-
g.remove_node(node.name)
75-
7681
return ops

0 commit comments

Comments
 (0)