Skip to content

Commit ee8ecf4

Browse files
committed
fix issues in test_internals.py and onnx-experiments.py
1 parent b9f9578 commit ee8ecf4

File tree

3 files changed

+42
-74
lines changed

3 files changed

+42
-74
lines changed

tests/test_internals.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tf2onnx
2020
import tf2onnx.utils
2121
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
22-
from tf2onnx.graph import Graph
22+
from tf2onnx.graph import Graph, GraphUtil
2323

2424
# pylint: disable=missing-docstring
2525

@@ -43,8 +43,8 @@ def onnx_to_graphviz(g):
4343

4444
def onnx_pretty(g, args=None):
4545
"""Onnx graph pretty print."""
46-
model_proto = g.make_model("converted from {}".format(args.input))
47-
return helper.printable_graph(model_proto.graph)
46+
graph_proto = g.make_model("converted from {}".format(args.input))
47+
return helper.printable_graph(graph_proto.graph)
4848

4949

5050
class Tf2OnnxInternalTests(unittest.TestCase):
@@ -73,60 +73,63 @@ def sample_net():
7373
n5 = helper.make_node("Abs", ["n4:0"], ["n5:0"], name="n5")
7474
n6 = helper.make_node("Identity", ["n5:0"], ["n6:0"], name="n6")
7575

76-
model_proto = helper.make_graph(
76+
graph_proto = helper.make_graph(
7777
nodes=[n1, n2, n3, n4, n5, n6],
7878
name="test",
7979
inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 2])],
8080
outputs=[helper.make_tensor_value_info("n5:0", TensorProto.FLOAT, [2, 2])],
8181
initializer=[]
8282
)
83-
return model_proto
83+
return graph_proto
8484

8585
def test_insert_node1(self):
86-
model_proto = self.sample_net()
87-
nodes = model_proto.node
88-
g = Graph(nodes, output_shapes={}, dtypes={})
86+
graph_proto = self.sample_net()
87+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
8988
n2 = g.get_node_by_name("n2")
9089
n7 = g.insert_new_node_on_input(n2, "Abs", "n1:0", name="n7")
9190
ops = g.get_nodes()
9291
ops.append(n7)
9392
g.topological_sort(ops)
9493
result = onnx_to_graphviz(g)
95-
expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
96-
'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
97-
'input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
94+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
95+
'n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
96+
'n4 [op_type=Add] n5 [op_type=Abs] graph_outputs_Identity__3 [op_type=Identity] ' \
97+
'n6 [op_type=Identity] input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 ' \
98+
'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 ' \
99+
'raw_output___2:0 -> n6 }'
98100
self.assertEqual(expected, result)
99101

100102
def test_insert_node2(self):
101-
model_proto = self.sample_net()
102-
nodes = model_proto.node
103-
g = Graph(nodes, output_shapes={}, dtypes={})
103+
graph_proto = self.sample_net()
104+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
104105
n7 = g.insert_new_node_on_output("Abs", "n1:0", name="n7")
105106
ops = g.get_nodes()
106107
ops.append(n7)
107108
g.topological_sort(ops)
108109
result = onnx_to_graphviz(g)
109-
expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ' \
110-
'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
111-
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
110+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
111+
'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
112+
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
113+
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 ' \
114+
'n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 raw_output___2:0 -> n6 }'
112115
self.assertEqual(expected, result)
113116

114117
def test_remove_input(self):
115-
model_proto = self.sample_net()
116-
nodes = model_proto.node
117-
g = Graph(nodes, output_shapes={}, dtypes={})
118+
graph_proto = self.sample_net()
119+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
118120
n4 = g.get_node_by_name("n4")
119121
g.remove_input(n4, n4.input[1])
120122
result = onnx_to_graphviz(g)
121123
expected = 'digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] n4 [op_type=Add] ' \
122-
'n5 [op_type=Abs] n6 [op_type=Identity] input -> n1 n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 ' \
123-
'n4:0 -> n5 n5:0 -> n6 }'
124+
'n5 [op_type=Abs] n6 [op_type=Identity] graph_outputs_Identity__3 ' \
125+
'[op_type=Identity] Placeholder__4 [op_type=Placeholder] input -> n1 ' \
126+
'n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 n4:0 -> n5 raw_output___2:0 -> n6 ' \
127+
'raw_output___2:0 -> graph_outputs_Identity__3 }'
124128
self.assertEqual(expected, result)
125129

126130
def test_rewrite_subgraph(self):
127-
model_proto = self.sample_net()
128-
nodes = model_proto.node
129-
g = tf2onnx.graph.Graph(nodes, output_shapes={}, dtypes={})
131+
graph_proto = self.sample_net()
132+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
130133
pattern = \
131134
OpTypePattern('Abs', name='output', inputs=[
132135
OpTypePattern('Add', name='input')
@@ -143,26 +146,28 @@ def test_rewrite_subgraph(self):
143146
ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
144147
g.topological_sort(ops)
145148
result = onnx_to_graphviz(g)
146-
expected = 'digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] ' \
147-
'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 ' \
148-
'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }'
149+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
150+
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
151+
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
152+
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 ' \
153+
'n3:0 -> ReplacedOp__5 ReplacedOp__5:0 -> graph_outputs_Identity__3 ' \
154+
'ReplacedOp__5:0 -> n6 }'
149155
self.assertEqual(expected, result)
150156

151157
def test_match_flipped(self):
152158
n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
153159
n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
154160
n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")
155161

156-
model_proto = helper.make_graph(
162+
graph_proto = helper.make_graph(
157163
nodes=[n1, n2, n3],
158164
name="test",
159165
inputs=[helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
160166
helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2])],
161167
outputs=[helper.make_tensor_value_info("n2:0", TensorProto.FLOAT, [2, 2])],
162168
initializer=[]
163169
)
164-
nodes = model_proto.node
165-
g = tf2onnx.graph.Graph(nodes, output_shapes={}, dtypes={})
170+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
166171
pattern = OpTypePattern('Mul', inputs=[
167172
OpTypePattern('Add'),
168173
OpTypePattern('Sub')

tf2onnx/graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,10 @@ def create_graph_from_onnx_graph(graph_proto):
10411041
for n in graph_proto.node:
10421042
if n.op_type == "Constant":
10431043
n.op_type = "Const"
1044+
1045+
# some pytorch model had empty names - make one up
1046+
if not n.name:
1047+
n.name = utils.make_name("was_empty")
10441048
nodes_to_append.append(n)
10451049

10461050
output_names = []

tools/onnx-experiments.py

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from onnx import numpy_helper
2323

2424
import tf2onnx.utils
25-
from tf2onnx.graph import Graph
25+
from tf2onnx.graph import Graph, GraphUtil
2626

2727
logging.basicConfig(level=logging.INFO)
2828
log = logging.getLogger("onnx-experiments")
@@ -41,48 +41,7 @@ def load_graph(fname):
4141
with open(fname, "rb") as f:
4242
data = f.read()
4343
model_proto = onnx.ModelProto()
44-
model_proto.ParseFromString(data)
45-
onnx_nodes = model_proto.graph.node
46-
output_names = []
47-
48-
# some pytorch model had empty names - make one up
49-
for node in onnx_nodes:
50-
if not node.name:
51-
node.name = tf2onnx.utils.make_name("was_empty")
52-
53-
g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)
54-
for i in model_proto.graph.initializer:
55-
v = numpy_helper.to_array(i)
56-
name = i.name
57-
g.initializers[name] = i
58-
dtype = i.data_type
59-
g.set_dtype(name, dtype)
60-
g.set_shape(name, v.shape)
61-
for i in model_proto.graph.input:
62-
name = i.name
63-
if name in g.initializers:
64-
# ignore if it is not a model input
65-
continue
66-
shape = [j.dim_value if hasattr(i.type.tensor_type, "dim_value") else -1
67-
for j in i.type.tensor_type.shape.dim]
68-
dtype = i.type.tensor_type.elem_type
69-
g.set_dtype(name, dtype)
70-
g.set_shape(name, shape)
71-
g.add_graph_input(name, dtype, shape)
72-
for i in model_proto.graph.output:
73-
name = i.name
74-
shape = [j.dim_value if hasattr(i.type.tensor_type, "dim_value") else -1
75-
for j in i.type.tensor_type.shape.dim]
76-
dtype = i.type.tensor_type.elem_type
77-
g.set_dtype(name, dtype)
78-
g.set_shape(name, shape)
79-
output_names.append(name)
80-
81-
# TODO: this is a hack in case a output name does not follow tensorflow convention
82-
for node in g.get_nodes():
83-
for name in node.output:
84-
g._nodes_by_name[name] = node # pylint: disable=protected-access
85-
44+
g = GraphUtil.create_graph_from_onnx_model(model_proto)
8645
return g, model_proto.producer_name
8746

8847

0 commit comments

Comments
 (0)