Skip to content

Commit d9e18e4

Browse files
delete nodes without dependency in each rewriter. check graph integrity after rewriting in verbose mode.
1 parent 29ab979 commit d9e18e4

File tree

7 files changed

+54
-21
lines changed

7 files changed

+54
-21
lines changed

tf2onnx/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,6 @@
3737

3838
# Environment variables
3939
ENV_TF2ONNX_DEBUG_MODE = "TF2ONNX_DEBUG_MODE"
40+
41+
# Logging level
42+
VERBOSE = 15

tf2onnx/graph.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,18 @@ def reset_nodes(self, ops):
575575
self._dtypes = remained_dtypes
576576
self._output_shapes = remained_shapes
577577

578+
def check_integrity(self):
579+
"""
580+
Check graph integrity. Every node's input needs to associate with a node.
581+
Return broken outputs.
582+
"""
583+
broken_outputs = set()
584+
for node in self.get_nodes():
585+
for inp in node.input:
586+
if self.get_node_by_output(inp) is None:
587+
broken_outputs.add(inp)
588+
return list(broken_outputs)
589+
578590
def update_node_shape_dtype(self, node, override=False):
579591
"""Try the best to infer shapes and dtypes for outputs of the node,
580592
by default, we respect TF shapes and dtypes.
@@ -591,7 +603,7 @@ def update_node_shape_dtype(self, node, override=False):
591603
initializers = []
592604
for i, inp in enumerate(node.inputs):
593605
if inp is None:
594-
if logger.isEnabledFor(logging.INFO):
606+
if logger.isEnabledFor(constants.VERBOSE):
595607
logger.warning(
596608
"[%s] infer a inexistent node: [%s], please check the code",
597609
node.name, node.input[i]
@@ -1169,6 +1181,20 @@ def delete_unused_nodes(self, outputs_name):
11691181
body_graph.delete_unused_nodes(body_graph.outputs)
11701182
self.reset_nodes(related_nodes)
11711183

1184+
def delete_nodes_without_dependency(self, to_delete):
1185+
"""Delete nodes in `to_delete` without third-party dependency."""
1186+
for n in to_delete:
1187+
can_delete = True
1188+
for out in n.output:
1189+
if not can_delete:
1190+
break
1191+
for consumer in self.find_output_consumers(out):
1192+
if consumer not in to_delete:
1193+
can_delete = False
1194+
break
1195+
if can_delete:
1196+
self.remove_node(n.name)
1197+
11721198

11731199
class GraphUtil(object):
11741200
"""Utilities for Graph manipulation."""

tf2onnx/rewriter/leakyrelu_rewriter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def rewrite_leakyrelu(g, ops):
4242
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
4343
ops.append(leakyrelu)
4444
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
45+
to_delete = [max_node, mul_node]
46+
g.delete_nodes_without_dependency(to_delete)
4547

4648
return ops
4749

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def rewrite_random_uniform_fold_const(g, ops):
6363
to_delete = list(set(match.get_nodes()))
6464
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
6565
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
66-
for n in to_delete:
67-
g.remove_node(n.name)
66+
g.delete_nodes_without_dependency(to_delete)
6867

6968
return ops
7069

tf2onnx/rewriter/thresholded_relu_rewriter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def rewrite_thresholded_relu(g, ops):
3434
greater_input_node = match.get_op('greater_input')
3535
mul_node = match.get_op("mul")
3636
mul_input_node = match.get_op('mul_input')
37+
cast_node = match.get_op('cast')
3738

3839
greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
3940
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
@@ -42,6 +43,7 @@ def rewrite_thresholded_relu(g, ops):
4243
thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},
4344
shapes=[g.get_shape(mul_node.output[0])],
4445
dtypes=[g.get_dtype(mul_node.output[0])])
45-
ops.append(thresholded_relu)
4646
g.replace_all_inputs(ops, mul_node.output[0], thresholded_relu.output[0])
47+
to_delete = [cast_node, mul_node]
48+
g.delete_nodes_without_dependency(to_delete)
4749
return ops

tf2onnx/tfonnx.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,8 @@ def rewrite_transpose(g, ops):
141141
dims = [i for i in range(len(shape) - 1, -1, -1)]
142142
output.set_attr("perm", dims)
143143
g.remove_input(output, output.input[1])
144-
for n in set(match.get_nodes()):
145-
if n != output:
146-
g.remove_node(n.name)
144+
to_delete = [n for n in set(match.get_nodes()) if n != output]
145+
g.delete_nodes_without_dependency(to_delete)
147146
return ops
148147

149148

@@ -175,8 +174,7 @@ def rewrite_random_normal(g, ops):
175174
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype})
176175

177176
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
178-
for n in set(match.get_nodes()):
179-
g.remove_node(n.name)
177+
g.delete_nodes_without_dependency(set(match.get_nodes()))
180178
return ops
181179

182180

@@ -208,8 +206,7 @@ def rewrite_dropout(g, ops):
208206
dtypes=[g.get_dtype(inputs2.input[0])]
209207
)
210208
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
211-
for n in set(match.get_nodes()):
212-
g.remove_node(n.name)
209+
g.delete_nodes_without_dependency(set(match.get_nodes()))
213210

214211
# remove dropout if its ratio is 1.0
215212
for node in g.get_nodes():
@@ -294,10 +291,8 @@ def rewrite_flatten(g, ops):
294291

295292
g.set_shape(out_name, input_shape[:-2] + [new_dim])
296293
g.replace_all_inputs(ops, reshape_node.output[0], out_name)
297-
298-
for n in set(match.get_nodes()):
299-
if n != input_node:
300-
g.remove_node(n.name)
294+
to_delete = [n for n in set(match.get_nodes()) if n != input_node]
295+
g.delete_nodes_without_dependency(to_delete)
301296

302297
return ops
303298

@@ -654,6 +649,14 @@ def run_rewriters(g, funcs, continue_on_error):
654649
else:
655650
raise ex
656651

652+
if logger.isEnabledFor(constants.VERBOSE):
653+
broken_outputs = g.check_integrity()
654+
if broken_outputs:
655+
logging.error(
656+
"After rewriter %s, graph breaks at outputs %s",
657+
func.__name__, broken_outputs
658+
)
659+
657660
if g.contained_graphs:
658661
for dict_val in g.contained_graphs.values():
659662
for attr_name, b_g in dict_val.items():

tf2onnx/verbose_logging.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515

1616
from . import constants
1717

18-
VERBOSE = 15
19-
20-
_logging.addLevelName(VERBOSE, "VERBOSE")
18+
_logging.addLevelName(constants.VERBOSE, "VERBOSE")
2119

2220

2321
def _verbose(self, message, *args, **kwargs):
24-
if self.isEnabledFor(VERBOSE):
25-
self._log(VERBOSE, message, args, **kwargs) # pylint: disable=protected-access
22+
if self.isEnabledFor(constants.VERBOSE):
23+
self._log(constants.VERBOSE, message, args, **kwargs) # pylint: disable=protected-access
2624

2725

2826
def getLogger(name=None): # pylint: disable=invalid-name, function-redefined
@@ -47,7 +45,7 @@ def basicConfig(**kwargs): # pylint: disable=invalid-name, function-redefined
4745
set_tf_verbosity(_logging.getLogger().getEffectiveLevel())
4846

4947

50-
_LOG_LEVELS = [FATAL, ERROR, WARNING, INFO, VERBOSE, DEBUG]
48+
_LOG_LEVELS = [FATAL, ERROR, WARNING, INFO, constants.VERBOSE, DEBUG]
5149

5250

5351
def get_verbosity_level(verbosity, base_level=INFO):

0 commit comments

Comments
 (0)