Skip to content

Commit a8b02aa

Browse files
Merge pull request #995 from onnx/tom/FixDropoutRewriter
Fixed dropout rewriter to properly read ratio
2 parents 38b1a6a + c4ad517 commit a8b02aa

File tree

2 files changed

+72
-19
lines changed

2 files changed

+72
-19
lines changed

tf2onnx/graph.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def output(self, val):
7676
utils.make_sure(o not in self.graph._output_to_node_name, "output %s already in output mapping", o)
7777
self.graph._output_to_node_name[o] = self.name
7878

79+
# TODO(tomwildenhain): Rename to "input_nodes"
7980
@property
8081
def inputs(self):
8182
"""Input node objects."""
@@ -151,6 +152,16 @@ def is_const(self):
151152
"""Return True if node is a constant."""
152153
return self.type in ["Const", "ConstV2"]
153154

155+
def is_scalar(self):
156+
"""Return True if node is a constant with a scalar value."""
157+
if not self.is_const():
158+
return False
159+
t = self.get_attr("value", default=None)
160+
if t is None:
161+
return False
162+
t = numpy_helper.to_array(helper.get_attribute_value(t))
163+
return t.shape == tuple()
164+
154165
def is_graph_input(self):
155166
return self.type in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
156167

@@ -1318,6 +1329,7 @@ def safe_to_remove_nodes(self, to_delete):
13181329
safe_to_remove.append(n)
13191330
return safe_to_remove
13201331

1332+
# TODO(tomwildenhain): Remove this function
13211333
def safe_remove_nodes(self, to_delete):
13221334
"""Delete nodes in `to_delete` without third-party node consuming it."""
13231335
delete_set = set(to_delete)
@@ -1328,6 +1340,20 @@ def safe_remove_nodes(self, to_delete):
13281340
if out_consumers.issubset(delete_set):
13291341
self.remove_node(n.name)
13301342

1343+
def is_safe_to_remove_nodes(self, to_delete, outputs_to_ignore=None):
1344+
"""Returns true if the outputs of all the nodes in to_delete have no third-party nodes consuming them"""
1345+
delete_set = set(to_delete)
1346+
outputs_to_ignore_set = set(outputs_to_ignore or [])
1347+
for n in delete_set:
1348+
out_consumers = set()
1349+
for out in n.output:
1350+
if out in outputs_to_ignore_set:
1351+
continue
1352+
out_consumers |= set(self.find_output_consumers(out))
1353+
if not out_consumers.issubset(delete_set):
1354+
return False
1355+
return True
1356+
13311357

13321358
class GraphUtil(object):
13331359
"""Utilities for Graph manipulation."""

tf2onnx/rewriter/dropout_rewriter.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx dropout op
66
"""
77

8+
import numpy as np
89
from tf2onnx import utils
910
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
11+
from tf2onnx import logging
12+
13+
logger = logging.getLogger(__name__)
1014

1115

1216
# pylint: disable=missing-docstring
@@ -18,7 +22,7 @@ def rewrite_dropout(g, ops):
1822
OpTypePattern('RealDiv', name="input2"),
1923
OpTypePattern('Floor', inputs=[
2024
OpTypePattern('Add', inputs=[
21-
OpTypePattern(None, name="input3"),
25+
OpTypePattern("*", name="input3"),
2226
OpTypePattern('RandomUniform|RandomUniformLike'),
2327
])
2428
]),
@@ -28,7 +32,7 @@ def rewrite_dropout(g, ops):
2832
OpTypePattern("Cast", inputs=[
2933
OpTypePattern("GreaterEqual", inputs=[
3034
OpTypePattern("RandomUniform|RandomUniformLike"),
31-
OpTypePattern(None, name="input3")
35+
OpTypePattern("*", name="input3")
3236
])
3337
])
3438
]),
@@ -37,7 +41,7 @@ def rewrite_dropout(g, ops):
3741
OpTypePattern("Cast", inputs=[
3842
OpTypePattern("GreaterEqual", inputs=[
3943
OpTypePattern("RandomUniform|RandomUniformLike"),
40-
OpTypePattern(None, name="input3")
44+
OpTypePattern("*", name="input3")
4145
])
4246
]),
4347
OpTypePattern("Mul", name="input2"),
@@ -47,30 +51,53 @@ def rewrite_dropout(g, ops):
4751
matcher = GraphMatcher(pattern, allow_reorder=True)
4852
match_results = list(matcher.match_ops(ops))
4953
for match in match_results:
50-
inputs2 = match.get_op('input2')
51-
if inputs2.inputs[0].type == "RealDiv":
52-
data = inputs2.input[1]
53-
else:
54-
data = inputs2.input[0]
54+
input2 = match.get_op('input2')
55+
input3 = match.get_op('input3')
5556
outputs = match.get_op('outputs')
57+
58+
if not input3.is_scalar():
59+
logger.warning("Dropout pattern rooted at %s does not have a "
60+
"constant ratio and cannot be replaced.", outputs.name)
61+
continue
62+
ratio = input3.get_tensor_value()
63+
64+
if input2.inputs[0].is_scalar():
65+
data = input2.inputs[1]
66+
scaling_constant = input2.inputs[0].get_tensor_value()
67+
elif input2.inputs[1].is_scalar():
68+
data = input2.inputs[0]
69+
scaling_constant = input2.inputs[1].get_tensor_value()
70+
else:
71+
logger.warning("Could not find scaling constant for dropout pattern rooted at %s. "
72+
"The pattern will not be replaced with an ONNX dropout node.", outputs.name)
73+
continue
74+
75+
#The scaling constant should be 1/(1-ratio), otherwise this isn't truly a dropout node
76+
if not np.allclose([1], [scaling_constant * (1 - ratio)]):
77+
logger.warning("Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
78+
"ratio %f. The pattern will not be replaced with an ONNX dropout node.",
79+
scaling_constant, outputs.name, ratio)
80+
continue
81+
82+
nodes_to_remove = [n for n in match.get_nodes() if n.name != input3.name]
83+
if not g.is_safe_to_remove_nodes(nodes_to_remove, [outputs.output[0]]):
84+
logger.warning("Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
85+
"of some nodes are referenced elsewhere in graph.", outputs.name)
86+
continue
87+
5688
op_name = utils.make_name("Dropout")
5789
out_name = utils.port_name(op_name)
5890
new_node = g.make_node(
5991
"Dropout",
60-
[data],
92+
inputs=[data.output[0]],
6193
outputs=[out_name],
6294
name=op_name,
63-
attr={"ratio": 1.0},
64-
shapes=[g.get_shape(inputs2.input[0])],
65-
dtypes=[g.get_dtype(inputs2.input[0])]
95+
attr={"ratio": ratio},
96+
shapes=[g.get_shape(data.output[0])],
97+
dtypes=[g.get_dtype(data.output[0])]
6698
)
6799
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
68-
g.safe_remove_nodes(match.get_nodes())
69-
70-
# remove dropout if its ratio is 1.0
71-
for node in g.get_nodes():
72-
if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
73-
g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
74-
g.remove_node(node.name)
100+
for n in nodes_to_remove:
101+
g.remove_node(n.name)
75102

76103
return ops

0 commit comments

Comments
 (0)