19
19
import tf2onnx
20
20
import tf2onnx .utils
21
21
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
22
- from tf2onnx .graph import Graph
22
+ from tf2onnx .graph import Graph , GraphUtil
23
23
24
24
# pylint: disable=missing-docstring
25
25
@@ -43,8 +43,8 @@ def onnx_to_graphviz(g):
43
43
44
44
def onnx_pretty (g , args = None ):
45
45
"""Onnx graph pretty print."""
46
- model_proto = g .make_model ("converted from {}" .format (args .input ))
47
- return helper .printable_graph (model_proto .graph )
46
+ graph_proto = g .make_model ("converted from {}" .format (args .input ))
47
+ return helper .printable_graph (graph_proto .graph )
48
48
49
49
50
50
class Tf2OnnxInternalTests (unittest .TestCase ):
@@ -73,60 +73,63 @@ def sample_net():
73
73
n5 = helper .make_node ("Abs" , ["n4:0" ], ["n5:0" ], name = "n5" )
74
74
n6 = helper .make_node ("Identity" , ["n5:0" ], ["n6:0" ], name = "n6" )
75
75
76
- model_proto = helper .make_graph (
76
+ graph_proto = helper .make_graph (
77
77
nodes = [n1 , n2 , n3 , n4 , n5 , n6 ],
78
78
name = "test" ,
79
79
inputs = [helper .make_tensor_value_info ("input" , TensorProto .FLOAT , [2 , 2 ])],
80
80
outputs = [helper .make_tensor_value_info ("n5:0" , TensorProto .FLOAT , [2 , 2 ])],
81
81
initializer = []
82
82
)
83
- return model_proto
83
+ return graph_proto
84
84
85
85
def test_insert_node1 (self ):
86
- model_proto = self .sample_net ()
87
- nodes = model_proto .node
88
- g = Graph (nodes , output_shapes = {}, dtypes = {})
86
+ graph_proto = self .sample_net ()
87
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
89
88
n2 = g .get_node_by_name ("n2" )
90
89
n7 = g .insert_new_node_on_input (n2 , "Abs" , "n1:0" , name = "n7" )
91
90
ops = g .get_nodes ()
92
91
ops .append (n7 )
93
92
g .topological_sort (ops )
94
93
result = onnx_to_graphviz (g )
95
- expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
96
- 'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
97
- 'input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
94
+ expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
95
+ 'n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
96
+ 'n4 [op_type=Add] n5 [op_type=Abs] graph_outputs_Identity__3 [op_type=Identity] ' \
97
+ 'n6 [op_type=Identity] input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 ' \
98
+ 'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 ' \
99
+ 'raw_output___2:0 -> n6 }'
98
100
self .assertEqual (expected , result )
99
101
100
102
def test_insert_node2 (self ):
101
- model_proto = self .sample_net ()
102
- nodes = model_proto .node
103
- g = Graph (nodes , output_shapes = {}, dtypes = {})
103
+ graph_proto = self .sample_net ()
104
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
104
105
n7 = g .insert_new_node_on_output ("Abs" , "n1:0" , name = "n7" )
105
106
ops = g .get_nodes ()
106
107
ops .append (n7 )
107
108
g .topological_sort (ops )
108
109
result = onnx_to_graphviz (g )
109
- expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ' \
110
- 'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
111
- 'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
110
+ expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
111
+ 'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
112
+ 'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
113
+ 'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 ' \
114
+ 'n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 raw_output___2:0 -> n6 }'
112
115
self .assertEqual (expected , result )
113
116
114
117
def test_remove_input (self ):
115
- model_proto = self .sample_net ()
116
- nodes = model_proto .node
117
- g = Graph (nodes , output_shapes = {}, dtypes = {})
118
+ graph_proto = self .sample_net ()
119
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
118
120
n4 = g .get_node_by_name ("n4" )
119
121
g .remove_input (n4 , n4 .input [1 ])
120
122
result = onnx_to_graphviz (g )
121
123
expected = 'digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] n4 [op_type=Add] ' \
122
- 'n5 [op_type=Abs] n6 [op_type=Identity] input -> n1 n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 ' \
123
- 'n4:0 -> n5 n5:0 -> n6 }'
124
+ 'n5 [op_type=Abs] n6 [op_type=Identity] graph_outputs_Identity__3 ' \
125
+ '[op_type=Identity] Placeholder__4 [op_type=Placeholder] input -> n1 ' \
126
+ 'n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 n4:0 -> n5 raw_output___2:0 -> n6 ' \
127
+ 'raw_output___2:0 -> graph_outputs_Identity__3 }'
124
128
self .assertEqual (expected , result )
125
129
126
130
def test_rewrite_subgraph (self ):
127
- model_proto = self .sample_net ()
128
- nodes = model_proto .node
129
- g = tf2onnx .graph .Graph (nodes , output_shapes = {}, dtypes = {})
131
+ graph_proto = self .sample_net ()
132
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
130
133
pattern = \
131
134
OpTypePattern ('Abs' , name = 'output' , inputs = [
132
135
OpTypePattern ('Add' , name = 'input' )
@@ -143,26 +146,28 @@ def test_rewrite_subgraph(self):
143
146
ops = g .replace_subgraph (ops , match , [], [output_node ], [], [new_node ])
144
147
g .topological_sort (ops )
145
148
result = onnx_to_graphviz (g )
146
- expected = 'digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] ' \
147
- 'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 ' \
148
- 'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }'
149
+ expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
150
+ 'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
151
+ 'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
152
+ 'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 ' \
153
+ 'n3:0 -> ReplacedOp__5 ReplacedOp__5:0 -> graph_outputs_Identity__3 ' \
154
+ 'ReplacedOp__5:0 -> n6 }'
149
155
self .assertEqual (expected , result )
150
156
151
157
def test_match_flipped (self ):
152
158
n1 = helper .make_node ("Sub" , ["i1" , "i1" ], ["n1:0" ], name = "n1" )
153
159
n2 = helper .make_node ("Add" , ["i2" , "i2" ], ["n2:0" ], name = "n2" )
154
160
n3 = helper .make_node ("Mul" , ["n1:0" , "n2:0" ], ["n3:0" ], name = "n3" )
155
161
156
- model_proto = helper .make_graph (
162
+ graph_proto = helper .make_graph (
157
163
nodes = [n1 , n2 , n3 ],
158
164
name = "test" ,
159
165
inputs = [helper .make_tensor_value_info ("i1" , TensorProto .FLOAT , [2 , 2 ]),
160
166
helper .make_tensor_value_info ("i2" , TensorProto .FLOAT , [2 , 2 ])],
161
167
outputs = [helper .make_tensor_value_info ("n2:0" , TensorProto .FLOAT , [2 , 2 ])],
162
168
initializer = []
163
169
)
164
- nodes = model_proto .node
165
- g = tf2onnx .graph .Graph (nodes , output_shapes = {}, dtypes = {})
170
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
166
171
pattern = OpTypePattern ('Mul' , inputs = [
167
172
OpTypePattern ('Add' ),
168
173
OpTypePattern ('Sub' )
0 commit comments