Skip to content

Commit 73afa70

Browse files
authored
Merge pull request #302 from pengwa/const_placeholder_lazy_into_initializer
remain Const/Placeholder as it is, util we make_graph at last
2 parents d39427a + 05d11d1 commit 73afa70

20 files changed

+306
-391
lines changed

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1538,7 +1538,7 @@ def test_sparse_softmax_cross_entropy_with_logits_large_class(self):
15381538
res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits)
15391539
_ = tf.identity(res, name=_TFOUTPUT)
15401540

1541-
self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val})
1541+
self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val}, rtol=1e-6)
15421542

15431543
@unittest.skipIf(BACKEND in ["onnxruntime"], "onnxruntime Slice did not supported BOOL.")
15441544
def test_matrix_band_part(self):

tests/test_graph.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def test_abs(self):
115115
x_ = tf.abs(x)
116116
_ = tf.identity(x_, name="output")
117117
g = process_tf_graph(sess.graph)
118-
self.assertEqual('digraph { Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
118+
self.assertEqual('digraph { input [op_type=Placeholder shape="[2, 3]"]'\
119+
' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
119120
onnx_to_graphviz(g))
120121

121122
def test_randomuniform(self):
@@ -154,9 +155,11 @@ def test_dropout(self):
154155
_ = tf.identity(x_, name="output")
155156
g = process_tf_graph(sess.graph)
156157
actual = onnx_to_graphviz(g)
157-
expected = 'digraph { Add [op_type=Add] Dropout__3 [op_type=Dropout] output1 [op_type=Identity] ' \
158-
'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> ' \
159-
'Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 output1:0 -> output2 output2:0 -> output }'
158+
expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
159+
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
160+
'output1 [op_type=Identity] output2 [op_type=Identity] output [op_type=Identity] ' \
161+
'input1:0 -> Add input2:0 -> Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 ' \
162+
'output1:0 -> output2 output2:0 -> output }'
160163
self.assertEqual(expected, actual)
161164

162165
def test_add(self):
@@ -167,8 +170,8 @@ def test_add(self):
167170
_ = tf.identity(x_, name="output")
168171
g = process_tf_graph(sess.graph)
169172
self.assertEqual(
170-
'digraph { Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> '
171-
'Add Add:0 -> output }',
173+
'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
174+
'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }',
172175
onnx_to_graphviz(g))
173176

174177
def test_squareddifference(self):
@@ -179,7 +182,8 @@ def test_squareddifference(self):
179182
_ = tf.identity(x_, name="output")
180183
g = process_tf_graph(sess.graph)
181184
self.assertEqual(
182-
'digraph { SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
185+
'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
186+
'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
183187
'output [op_type=Identity] input1:0 -> SquaredDifference input2:0 -> SquaredDifference '
184188
'SquaredDifference:0 -> SquaredDifference__2 SquaredDifference:0 -> SquaredDifference__2 '
185189
'SquaredDifference__2:0 -> output }',
@@ -192,7 +196,8 @@ def test_reducesum(self):
192196
_ = tf.identity(x_, name="output")
193197
g = process_tf_graph(sess.graph)
194198
self.assertEqual(
195-
'digraph { Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
199+
'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
200+
'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
196201
onnx_to_graphviz(g))
197202

198203
def test_argminmax(self):
@@ -202,7 +207,8 @@ def test_argminmax(self):
202207
_ = tf.identity(x_, name="output")
203208
g = process_tf_graph(sess.graph)
204209
self.assertEqual(
205-
'digraph { ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
210+
'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] ' \
211+
'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
206212
onnx_to_graphviz(g))
207213

208214
def test_rsqrt(self):
@@ -212,8 +218,9 @@ def test_rsqrt(self):
212218
_ = tf.identity(x_, name="output")
213219
g = process_tf_graph(sess.graph)
214220
self.assertEqual(
215-
'digraph { Rsqrt [op_type=Sqrt] Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] '
216-
'input1:0 -> Rsqrt Rsqrt:0 -> Rsqrt__2 Rsqrt__2:0 -> output }',
221+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
222+
'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
223+
'Rsqrt:0 -> Rsqrt__2 Rsqrt__2:0 -> output }',
217224
onnx_to_graphviz(g))
218225

219226
def test_relu6(self):
@@ -223,7 +230,9 @@ def test_relu6(self):
223230
_ = tf.identity(x_, name="output")
224231
g = process_tf_graph(sess.graph)
225232
self.assertEqual(
226-
'digraph { Relu6 [op_type=Max] Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
233+
'digraph { Relu6__3 [op_type=Const] Relu6__2 [op_type=Const] '
234+
'input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Max] '
235+
'Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
227236
'Relu6__2 -> Relu6 Relu6:0 -> Relu6__4 Relu6__3 -> Relu6__4 Relu6__4:0 -> output }',
228237
onnx_to_graphviz(g))
229238

@@ -251,10 +260,12 @@ def test_conv2d(self):
251260

252261
g = process_tf_graph(sess.graph)
253262
self.assertEqual(
254-
'digraph { Conv2D__2 [op_type=Transpose] kernel [op_type=Reshape] Conv2D__3 [op_type=Transpose] '
255-
'Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] output [op_type=Identity] '
256-
'input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel kernel:0 -> Conv2D__3 '
257-
'Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }',
263+
'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__2 [op_type=Transpose] '
264+
'"kernel/shape" [op_type=Const] k [op_type=Const] kernel [op_type=Reshape] '
265+
'Conv2D__3 [op_type=Transpose] Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] '
266+
'output [op_type=Identity] input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel '
267+
'kernel:0 -> Conv2D__3 Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D '
268+
'Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }',
258269
onnx_to_graphviz(g))
259270

260271
def test_squeeze(self):
@@ -264,8 +275,8 @@ def test_squeeze(self):
264275
_ = tf.identity(x_, name="output")
265276
g = process_tf_graph(sess.graph)
266277
self.assertEqual(
267-
'digraph { Squeeze [op_type=Squeeze] output [op_type=Identity] input1:0 -> Squeeze '
268-
'Squeeze:0 -> output }',
278+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '\
279+
'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }',
269280
onnx_to_graphviz(g))
270281

271282
def test_cast(self):
@@ -275,7 +286,8 @@ def test_cast(self):
275286
_ = tf.identity(x_, name="output")
276287
g = process_tf_graph(sess.graph)
277288
self.assertEqual(
278-
'digraph { Cast [op_type=Cast] output [op_type=Identity] input1:0 -> Cast Cast:0 -> output }',
289+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '\
290+
'input1:0 -> Cast Cast:0 -> output }',
279291
onnx_to_graphviz(g))
280292

281293
def test_reshape(self):
@@ -285,7 +297,8 @@ def test_reshape(self):
285297
_ = tf.identity(x_, name="output")
286298
g = process_tf_graph(sess.graph)
287299
self.assertEqual(
288-
'digraph { Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
300+
'digraph { "Reshape/shape" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
301+
'Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
289302
'"Reshape/shape":0 -> Reshape Reshape:0 -> output }',
290303
onnx_to_graphviz(g))
291304

@@ -308,8 +321,8 @@ def rewrite_test(g, ops):
308321
_ = tf.identity(x_, name="output")
309322
g = process_tf_graph(sess.graph, custom_rewriter=[rewrite_test])
310323
self.assertEqual(
311-
'digraph { Add [op_type=Mul] output [op_type=Identity] input1:0 -> '
312-
'Add input1:0 -> Add Add:0 -> output }',
324+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
325+
'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
313326
onnx_to_graphviz(g))
314327

315328
def test_custom_op(self):
@@ -333,7 +346,8 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
333346
custom_op_handlers={"Print": print_handler},
334347
extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
335348
self.assertEqual(
336-
'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
349+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
350+
'output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
337351
onnx_to_graphviz(g))
338352

339353

tests/test_internals.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tf2onnx
2020
import tf2onnx.utils
2121
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
22-
from tf2onnx.graph import Graph
22+
from tf2onnx.graph import GraphUtil
2323

2424
# pylint: disable=missing-docstring
2525

@@ -43,8 +43,8 @@ def onnx_to_graphviz(g):
4343

4444
def onnx_pretty(g, args=None):
4545
"""Onnx graph pretty print."""
46-
model_proto = g.make_model("converted from {}".format(args.input))
47-
return helper.printable_graph(model_proto.graph)
46+
graph_proto = g.make_model("converted from {}".format(args.input))
47+
return helper.printable_graph(graph_proto.graph)
4848

4949

5050
class Tf2OnnxInternalTests(unittest.TestCase):
@@ -73,60 +73,63 @@ def sample_net():
7373
n5 = helper.make_node("Abs", ["n4:0"], ["n5:0"], name="n5")
7474
n6 = helper.make_node("Identity", ["n5:0"], ["n6:0"], name="n6")
7575

76-
model_proto = helper.make_graph(
76+
graph_proto = helper.make_graph(
7777
nodes=[n1, n2, n3, n4, n5, n6],
7878
name="test",
7979
inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 2])],
8080
outputs=[helper.make_tensor_value_info("n5:0", TensorProto.FLOAT, [2, 2])],
8181
initializer=[]
8282
)
83-
return model_proto
83+
return graph_proto
8484

8585
def test_insert_node1(self):
86-
model_proto = self.sample_net()
87-
nodes = model_proto.node
88-
g = Graph(nodes, output_shapes={}, dtypes={})
86+
graph_proto = self.sample_net()
87+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
8988
n2 = g.get_node_by_name("n2")
9089
n7 = g.insert_new_node_on_input(n2, "Abs", "n1:0", name="n7")
9190
ops = g.get_nodes()
9291
ops.append(n7)
9392
g.topological_sort(ops)
9493
result = onnx_to_graphviz(g)
95-
expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
96-
'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
97-
'input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
94+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
95+
'n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
96+
'n4 [op_type=Add] n5 [op_type=Abs] graph_outputs_Identity__3 [op_type=Identity] ' \
97+
'n6 [op_type=Identity] input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 ' \
98+
'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 ' \
99+
'raw_output___2:0 -> n6 }'
98100
self.assertEqual(expected, result)
99101

100102
def test_insert_node2(self):
101-
model_proto = self.sample_net()
102-
nodes = model_proto.node
103-
g = Graph(nodes, output_shapes={}, dtypes={})
103+
graph_proto = self.sample_net()
104+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
104105
n7 = g.insert_new_node_on_output("Abs", "n1:0", name="n7")
105106
ops = g.get_nodes()
106107
ops.append(n7)
107108
g.topological_sort(ops)
108109
result = onnx_to_graphviz(g)
109-
expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ' \
110-
'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
111-
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
110+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
111+
'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
112+
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
113+
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 ' \
114+
'n4:0 -> n5 raw_output___2:0 -> graph_outputs_Identity__3 raw_output___2:0 -> n6 }'
112115
self.assertEqual(expected, result)
113116

114117
def test_remove_input(self):
115-
model_proto = self.sample_net()
116-
nodes = model_proto.node
117-
g = Graph(nodes, output_shapes={}, dtypes={})
118+
graph_proto = self.sample_net()
119+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
118120
n4 = g.get_node_by_name("n4")
119121
g.remove_input(n4, n4.input[1])
120122
result = onnx_to_graphviz(g)
121123
expected = 'digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] n4 [op_type=Add] ' \
122-
'n5 [op_type=Abs] n6 [op_type=Identity] input -> n1 n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 ' \
123-
'n4:0 -> n5 n5:0 -> n6 }'
124+
'n5 [op_type=Abs] n6 [op_type=Identity] graph_outputs_Identity__3 ' \
125+
'[op_type=Identity] Placeholder__4 [op_type=Placeholder] input -> n1 ' \
126+
'n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 n4:0 -> n5 raw_output___2:0 -> n6 ' \
127+
'raw_output___2:0 -> graph_outputs_Identity__3 }'
124128
self.assertEqual(expected, result)
125129

126130
def test_rewrite_subgraph(self):
127-
model_proto = self.sample_net()
128-
nodes = model_proto.node
129-
g = tf2onnx.graph.Graph(nodes, output_shapes={}, dtypes={})
131+
graph_proto = self.sample_net()
132+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
130133
pattern = \
131134
OpTypePattern('Abs', name='output', inputs=[
132135
OpTypePattern('Add', name='input')
@@ -143,26 +146,28 @@ def test_rewrite_subgraph(self):
143146
ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
144147
g.topological_sort(ops)
145148
result = onnx_to_graphviz(g)
146-
expected = 'digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] ' \
147-
'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 ' \
148-
'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }'
149+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
150+
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
151+
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
152+
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 ' \
153+
'n3:0 -> ReplacedOp__5 ReplacedOp__5:0 -> graph_outputs_Identity__3 ' \
154+
'ReplacedOp__5:0 -> n6 }'
149155
self.assertEqual(expected, result)
150156

151157
def test_match_flipped(self):
152158
n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
153159
n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
154160
n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")
155161

156-
model_proto = helper.make_graph(
162+
graph_proto = helper.make_graph(
157163
nodes=[n1, n2, n3],
158164
name="test",
159165
inputs=[helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
160166
helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2])],
161167
outputs=[helper.make_tensor_value_info("n2:0", TensorProto.FLOAT, [2, 2])],
162168
initializer=[]
163169
)
164-
nodes = model_proto.node
165-
g = tf2onnx.graph.Graph(nodes, output_shapes={}, dtypes={})
170+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
166171
pattern = OpTypePattern('Mul', inputs=[
167172
OpTypePattern('Add'),
168173
OpTypePattern('Sub')

tf2onnx/function/gathernd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
2323
trip_node = ctx.make_node("Size", [index.output[0]])
2424
nodes.append(trip_node)
2525
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
26+
nodes.append(cond_const)
2627
trip_name = utils.make_name("i")
2728
cond_name = utils.make_name("cond")
2829
cond_out_name = utils.make_name("cond_out")
@@ -78,6 +79,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
7879
# for (int i=0; i<outter_shape; i++) inner_loop(params, flatten_indices[i])
7980
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
8081
dummy_const = ctx.make_const(utils.make_name("dummy"), np.ones((), dtype=np.int64))
82+
nodes.extend([cond_const, dummy_const])
8183

8284
# body graph creation
8385
g = ctx.create_new_graph_with_same_config()
@@ -143,6 +145,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
143145
outputs=[output])
144146
nodes.extend([indices_outter_shape,
145147
inner_loop_shape,
148+
one_const,
146149
inner_loop_shape_,
147150
output_inner_shape,
148151
output_shape_,

0 commit comments

Comments
 (0)