Skip to content

Commit d0a632c

Browse files
committed
tf2onnx converted model cannot go to onnx-experimental.py because of naming conflicting
1 parent 261818a commit d0a632c

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

tests/test_internals.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def test_insert_node1(self):
9292
result = onnx_to_graphviz(g)
9393
expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
9494
'n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
95-
'n4 [op_type=Add] n5 [op_type=Abs] graph_outputs_Identity__3 [op_type=Identity] ' \
95+
'n4 [op_type=Add] n5 [op_type=Abs] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
9696
'n6 [op_type=Identity] input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 ' \
97-
'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 ' \
98-
'raw_output___2:0 -> n6 }'
97+
'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 ' \
98+
'n5_raw_output___2:0 -> n6 }'
9999
self.assertEqual(expected, result)
100100

101101
def test_insert_node2(self):
@@ -107,9 +107,9 @@ def test_insert_node2(self):
107107
result = onnx_to_graphviz(g)
108108
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
109109
'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
110-
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
110+
'n5_graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
111111
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 ' \
112-
'n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 raw_output___2:0 -> n6 }'
112+
'n4:0 -> n5 n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 n5_raw_output___2:0 -> n6 }'
113113
self.assertEqual(expected, result)
114114

115115
def test_remove_input(self):
@@ -122,9 +122,9 @@ def test_remove_input(self):
122122
result = onnx_to_graphviz(g)
123123
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] ' \
124124
'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
125-
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
125+
'n5_graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
126126
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 ' \
127-
'raw_output___2:0 -> graph_outputs_Identity__3 raw_output___2:0 -> n6 }'
127+
'n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 n5_raw_output___2:0 -> n6 }'
128128
self.assertEqual(expected, result)
129129

130130
def test_rewrite_subgraph(self):
@@ -150,9 +150,9 @@ def test_rewrite_subgraph(self):
150150
result = onnx_to_graphviz(g)
151151
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
152152
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
153-
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
153+
'n5_graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
154154
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 ' \
155-
'n3:0 -> ReplacedOp__5 ReplacedOp__5:0 -> graph_outputs_Identity__3 ' \
155+
'n3:0 -> ReplacedOp__5 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 ' \
156156
'ReplacedOp__5:0 -> n6 }'
157157
self.assertEqual(expected, result)
158158

tf2onnx/graph.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
340340
# add identity node after each output, in case it is renamed during conversion.
341341
for o in self.outputs:
342342
n = self.get_node_by_output_in_current_graph(o)
343-
new_output_name = port_name(utils.make_name("raw_output_"))
343+
new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
344344
n_shapes = n.output_shapes
345345
n_dtypes = n.output_dtypes
346346
body_graphs = n.graph.contained_graphs.pop(n.name, None)
@@ -356,7 +356,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
356356
new_node.set_body_graph_as_attr(attr_name, body_graph)
357357

358358
self.replace_all_inputs(self.get_nodes(), o, new_output_name)
359-
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope="graph_outputs")
359+
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
360360
self.copy_shape(new_output_name, o)
361361
self.copy_dtype(new_output_name, o)
362362

@@ -1112,8 +1112,8 @@ def create_graph_from_onnx_graph(graph_proto):
11121112
output_names.append(n.name)
11131113

11141114
g = Graph(nodes_to_append, output_shapes, output_dtypes, None, None, None, output_names)
1115-
GraphUtil._parse_graph_initializer(g, graph_proto)
1116-
GraphUtil._parse_graph_input(g, graph_proto)
1115+
const_nodes = GraphUtil._parse_graph_initializer(g, graph_proto)
1116+
GraphUtil._parse_graph_input(g, graph_proto, [n.name for n in const_nodes])
11171117

11181118
for n in g.get_nodes():
11191119
for attr_name, attr_val in n.attr.items():
@@ -1164,10 +1164,11 @@ def _parse_graph_initializer(g, graph_proto):
11641164
return const_nodes
11651165

11661166
@staticmethod
1167-
def _parse_graph_input(g, graph_proto):
1167+
def _parse_graph_input(g, graph_proto, const_node_names):
11681168
"""Get graph inputs not defined as initializers and put into Graph object."""
11691169
shapes, dtypes = GraphUtil._parse_shape_and_type_from_value_infos(graph_proto.input)
11701170
for name in shapes:
11711171
shape = shapes[name]
11721172
dtype = dtypes[name]
1173-
g.add_graph_input(name, dtype, shape)
1173+
if name not in const_node_names:
1174+
g.add_graph_input(name, dtype, shape)

tools/onnx-experiments.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def rewrite_constant_fold(g, ops):
8181
log.info("folding node type=%s, name=%s", op.type, op.name)
8282
if op.type == "Cast":
8383
dst = op.get_attr_int("to")
84-
np_type = dst
84+
np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)
8585
val = np.cast[np_type](*inputs)
8686
elif op.type == "Transpose":
8787
perm = op.get_attr("perm").ints
@@ -113,8 +113,9 @@ def rewrite_constant_fold(g, ops):
113113
if consumers:
114114
for consumer in consumers:
115115
g.replace_input(consumer, old_output_name, new_output_name)
116-
for node in op.inputs:
117-
g.remove_node(node.name)
116+
for i, node in zip(op.input, op.inputs):
117+
if len(g.find_output_consumers(i)) == 1:
118+
g.remove_node(node.name)
118119
keep_looking = True
119120
except Exception as ex: # pylint: disable=broad-except
120121
tb = traceback.format_exc()

0 commit comments

Comments
 (0)