Skip to content

Commit c4ad517

Browse files
Added error checking to dropout_rewriter to avoid incorrect matches
1 parent 0d3b82a commit c4ad517

File tree

2 files changed

+63
-24
lines changed

2 files changed

+63
-24
lines changed

tf2onnx/graph.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def output(self, val):
7676
utils.make_sure(o not in self.graph._output_to_node_name, "output %s already in output mapping", o)
7777
self.graph._output_to_node_name[o] = self.name
7878

79+
# TODO(tomwildenhain): Rename to "input_nodes"
7980
@property
8081
def inputs(self):
8182
"""Input node objects."""
@@ -151,6 +152,16 @@ def is_const(self):
151152
"""Return True if node is a constant."""
152153
return self.type in ["Const", "ConstV2"]
153154

155+
def is_scalar(self):
156+
"""Return True if node is a constant with a scalar value."""
157+
if not self.is_const():
158+
return False
159+
t = self.get_attr("value", default=None)
160+
if t is None:
161+
return False
162+
t = numpy_helper.to_array(helper.get_attribute_value(t))
163+
return t.shape == tuple()
164+
154165
def is_graph_input(self):
155166
return self.type in ["Placeholder", "PlaceholderWithDefault", "PlaceholderV2"]
156167

@@ -1318,6 +1329,7 @@ def safe_to_remove_nodes(self, to_delete):
13181329
safe_to_remove.append(n)
13191330
return safe_to_remove
13201331

1332+
# TODO(tomwildenhain): Remove this function
13211333
def safe_remove_nodes(self, to_delete):
13221334
"""Delete nodes in `to_delete` without third-party node consuming it."""
13231335
delete_set = set(to_delete)
@@ -1328,6 +1340,20 @@ def safe_remove_nodes(self, to_delete):
13281340
if out_consumers.issubset(delete_set):
13291341
self.remove_node(n.name)
13301342

1343+
def is_safe_to_remove_nodes(self, to_delete, outputs_to_ignore=None):
1344+
"""Returns true if the outputs of all the nodes in to_delete have no third-party nodes consuming them"""
1345+
delete_set = set(to_delete)
1346+
outputs_to_ignore_set = set(outputs_to_ignore or [])
1347+
for n in delete_set:
1348+
out_consumers = set()
1349+
for out in n.output:
1350+
if out in outputs_to_ignore_set:
1351+
continue
1352+
out_consumers |= set(self.find_output_consumers(out))
1353+
if not out_consumers.issubset(delete_set):
1354+
return False
1355+
return True
1356+
13311357

13321358
class GraphUtil(object):
13331359
"""Utilities for Graph manipulation."""

tf2onnx/rewriter/dropout_rewriter.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx dropout op
66
"""
77

8+
import numpy as np
89
from tf2onnx import utils
910
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
1011
from tf2onnx import logging
@@ -52,39 +53,51 @@ def rewrite_dropout(g, ops):
5253
for match in match_results:
5354
input2 = match.get_op('input2')
5455
input3 = match.get_op('input3')
55-
if input3.is_const():
56-
ratio = input3.get_tensor_value()
57-
else:
58-
# If the ratio isn't constant, set it to 0
59-
logger.warning("Dropout node has non-constant ratio. Using ratio=0.0")
60-
ratio = 0.0
61-
if input2.inputs[0].type == "RealDiv":
62-
data = input2.input[1]
63-
else:
64-
data = input2.input[0]
65-
# TODO(tomwildenhain): replace dropout node with identity if ratio is 0
6656
outputs = match.get_op('outputs')
57+
58+
if not input3.is_scalar():
59+
logger.warning("Dropout pattern rooted at %s does not have a "
60+
"constant ratio and cannot be replaced.", outputs.name)
61+
continue
62+
ratio = input3.get_tensor_value()
63+
64+
if input2.inputs[0].is_scalar():
65+
data = input2.inputs[1]
66+
scaling_constant = input2.inputs[0].get_tensor_value()
67+
elif input2.inputs[1].is_scalar():
68+
data = input2.inputs[0]
69+
scaling_constant = input2.inputs[1].get_tensor_value()
70+
else:
71+
logger.warning("Could not find scaling constant for dropout pattern rooted at %s. "
72+
"The pattern will not be replaced with an ONNX dropout node.", outputs.name)
73+
continue
74+
75+
#The scaling constant should be 1/(1-ratio), otherwise this isn't truly a dropout node
76+
if not np.allclose([1], [scaling_constant * (1 - ratio)]):
77+
logger.warning("Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
78+
"ratio %f. The pattern will not be replaced with an ONNX dropout node.",
79+
scaling_constant, outputs.name, ratio)
80+
continue
81+
82+
nodes_to_remove = [n for n in match.get_nodes() if n.name != input3.name]
83+
if not g.is_safe_to_remove_nodes(nodes_to_remove, [outputs.output[0]]):
84+
logger.warning("Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
85+
"of some nodes are referenced elsewhere in graph.", outputs.name)
86+
continue
87+
6788
op_name = utils.make_name("Dropout")
6889
out_name = utils.port_name(op_name)
6990
new_node = g.make_node(
7091
"Dropout",
71-
[data],
92+
inputs=[data.output[0]],
7293
outputs=[out_name],
7394
name=op_name,
7495
attr={"ratio": ratio},
75-
shapes=[g.get_shape(input2.input[0])],
76-
dtypes=[g.get_dtype(input2.input[0])]
96+
shapes=[g.get_shape(data.output[0])],
97+
dtypes=[g.get_dtype(data.output[0])]
7798
)
7899
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
79-
nodes_to_remove = []
80-
for node in match.get_nodes():
81-
if node.name != input3.name:
82-
nodes_to_remove.append(node)
83-
if g.safe_to_remove_nodes(nodes_to_remove):
84-
for n in nodes_to_remove:
85-
g.remove_node(n.name)
86-
else:
87-
logger.warning("Nodes replaced by dropout node cannot be removed because intermediate results are "
88-
"referenced elsewhere in graph")
100+
for n in nodes_to_remove:
101+
g.remove_node(n.name)
89102

90103
return ops

0 commit comments

Comments
 (0)