Commit b9f9578

Keep Const/Placeholder as it is until we make_graph at the end

1 parent: d39427a

18 files changed: +263, -313 lines
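
Every file below follows the same pattern: Const (and Placeholder) nodes are no longer pruned or inlined while ops are being mapped; they stay in the graph untouched until make_graph assembles the final model. That means any const an op handler creates with make_const must now be carried explicitly in the node list the handler returns or appends to, or it is dropped from the rebuilt graph. A minimal sketch of the pattern, using the ctx.make_const / ctx.make_node helpers visible in the diff (the handler itself is hypothetical, not part of this commit):

    import numpy as np
    from tf2onnx import utils  # make_name helper, as used throughout the diff

    def example_op_handler(ctx, node, name, args):
        """Hypothetical handler illustrating the bookkeeping this commit adds."""
        nodes = []

        # Before: the const was created and forgotten; the graph kept it
        # alive implicitly.
        #     ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))

        # After: Const/Placeholder stay as-is until the final make_graph, so a
        # const missing from the returned node list disappears from the graph.
        cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
        nodes.append(cond_const)

        # Ordinary ops were already tracked this way; consts now join them.
        result = ctx.make_node("Identity", [node.input[0]], outputs=[node.output[0]])
        nodes.append(result)
        return nodes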

tests/test_backend.py

Lines changed: 1 addition & 1 deletion

@@ -1538,7 +1538,7 @@ def test_sparse_softmax_cross_entropy_with_logits_large_class(self):
         res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits)
         _ = tf.identity(res, name=_TFOUTPUT)

-        self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val})
+        self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val}, rtol=1e-6)

     @unittest.skipIf(BACKEND in ["onnxruntime"], "onnxruntime Slice did not supported BOOL.")
     def test_matrix_band_part(self):

tests/test_graph.py

Lines changed: 37 additions & 23 deletions

@@ -115,7 +115,8 @@ def test_abs(self):
         x_ = tf.abs(x)
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
-        self.assertEqual('digraph { Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
+        self.assertEqual('digraph { input [op_type=Placeholder shape="[2, 3]"]'\
+                         ' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
                          onnx_to_graphviz(g))

     def test_randomuniform(self):
@@ -154,9 +155,11 @@ def test_dropout(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         actual = onnx_to_graphviz(g)
-        expected = 'digraph { Add [op_type=Add] Dropout__3 [op_type=Dropout] output1 [op_type=Identity] ' \
-                   'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> ' \
-                   'Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 output1:0 -> output2 output2:0 -> output }'
+        expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
+                   'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
+                   'output1 [op_type=Identity] output2 [op_type=Identity] output [op_type=Identity] ' \
+                   'input1:0 -> Add input2:0 -> Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 ' \
+                   'output1:0 -> output2 output2:0 -> output }'
         self.assertEqual(expected, actual)

     def test_add(self):
@@ -167,8 +170,8 @@ def test_add(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> '
-            'Add Add:0 -> output }',
+            'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
+            'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }',
             onnx_to_graphviz(g))

     def test_squareddifference(self):
@@ -179,7 +182,8 @@ def test_squareddifference(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
+            'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
+            'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
             'output [op_type=Identity] input1:0 -> SquaredDifference input2:0 -> SquaredDifference '
             'SquaredDifference:0 -> SquaredDifference__2 SquaredDifference:0 -> SquaredDifference__2 '
             'SquaredDifference__2:0 -> output }',
@@ -192,7 +196,8 @@ def test_reducesum(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
+            'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
+            'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
             onnx_to_graphviz(g))

     def test_argminmax(self):
@@ -202,7 +207,8 @@ def test_argminmax(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
+            'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] ' \
+            'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
             onnx_to_graphviz(g))

     def test_rsqrt(self):
@@ -212,8 +218,9 @@ def test_rsqrt(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Rsqrt [op_type=Sqrt] Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] '
-            'input1:0 -> Rsqrt Rsqrt:0 -> Rsqrt__2 Rsqrt__2:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
+            'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
+            'Rsqrt:0 -> Rsqrt__2 Rsqrt__2:0 -> output }',
             onnx_to_graphviz(g))

     def test_relu6(self):
@@ -223,7 +230,9 @@ def test_relu6(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Relu6 [op_type=Max] Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
+            'digraph { Relu6__3 [op_type=Const] Relu6__2 [op_type=Const] '
+            'input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Max] '
+            'Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
            'Relu6__2 -> Relu6 Relu6:0 -> Relu6__4 Relu6__3 -> Relu6__4 Relu6__4:0 -> output }',
             onnx_to_graphviz(g))

@@ -251,10 +260,12 @@ def test_conv2d(self):

         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Conv2D__2 [op_type=Transpose] kernel [op_type=Reshape] Conv2D__3 [op_type=Transpose] '
-            'Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] output [op_type=Identity] '
-            'input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel kernel:0 -> Conv2D__3 '
-            'Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__2 [op_type=Transpose] '
+            '"kernel/shape" [op_type=Const] k [op_type=Const] kernel [op_type=Reshape] '
+            'Conv2D__3 [op_type=Transpose] Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] '
+            'output [op_type=Identity] input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel '
+            'kernel:0 -> Conv2D__3 Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D '
+            'Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }',
             onnx_to_graphviz(g))

     def test_squeeze(self):
@@ -264,8 +275,8 @@ def test_squeeze(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Squeeze [op_type=Squeeze] output [op_type=Identity] input1:0 -> Squeeze '
-            'Squeeze:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '\
+            'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }',
             onnx_to_graphviz(g))

     def test_cast(self):
@@ -275,7 +286,8 @@ def test_cast(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Cast [op_type=Cast] output [op_type=Identity] input1:0 -> Cast Cast:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '\
+            'input1:0 -> Cast Cast:0 -> output }',
             onnx_to_graphviz(g))

     def test_reshape(self):
@@ -285,7 +297,8 @@ def test_reshape(self):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph)
         self.assertEqual(
-            'digraph { Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
+            'digraph { "Reshape/shape" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
+            'Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
             '"Reshape/shape":0 -> Reshape Reshape:0 -> output }',
             onnx_to_graphviz(g))

@@ -308,8 +321,8 @@ def rewrite_test(g, ops):
         _ = tf.identity(x_, name="output")
         g = process_tf_graph(sess.graph, custom_rewriter=[rewrite_test])
         self.assertEqual(
-            'digraph { Add [op_type=Mul] output [op_type=Identity] input1:0 -> '
-            'Add input1:0 -> Add Add:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
+            'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
             onnx_to_graphviz(g))

     def test_custom_op(self):
@@ -333,7 +346,8 @@ def print_handler(ctx, node, name, args):  # pylint: disable=unused-argument
                              custom_op_handlers={"Print": print_handler},
                              extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
         self.assertEqual(
-            'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
+            'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
+            'output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
             onnx_to_graphviz(g))

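All of the expectation strings above change in one way: because Placeholder and Const nodes now survive into the final graph, onnx_to_graphviz lists them ahead of the op nodes, each with its op_type and, for placeholders, its shape. A sketch of what the updated assertions check, assuming the tf.Session / process_tf_graph / onnx_to_graphviz setup this test file already uses:

    import tensorflow as tf
    from tf2onnx.tfonnx import process_tf_graph
    # onnx_to_graphviz is assumed imported the same way the test module does it

    with tf.Session() as sess:
        input1 = tf.placeholder(tf.float32, [2, 3], name="input1")
        input2 = tf.placeholder(tf.float32, [1, 3], name="input2")
        _ = tf.identity(tf.add(input1, input2), name="output")
        g = process_tf_graph(sess.graph)

    dot = onnx_to_graphviz(g)
    # Placeholders are now declared in the digraph before the edges:
    assert 'input1 [op_type=Placeholder shape="[2, 3]"]' in dot
    assert 'input2 [op_type=Placeholder shape="[1, 3]"]' in dot
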
tf2onnx/function/gathernd.py

Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,7 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
     trip_node = ctx.make_node("Size", [index.output[0]])
     nodes.append(trip_node)
     cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
+    nodes.append(cond_const)
     trip_name = utils.make_name("i")
     cond_name = utils.make_name("cond")
     cond_out_name = utils.make_name("cond_out")
@@ -78,6 +79,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
     # for (int i=0; i<outter_shape; i++) inner_loop(params, flatten_indices[i])
     cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
     dummy_const = ctx.make_const(utils.make_name("dummy"), np.ones((), dtype=np.int64))
+    nodes.extend([cond_const, dummy_const])

     # body graph creation
     g = ctx.create_new_graph_with_same_config()
@@ -143,6 +145,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
                   outputs=[output])
     nodes.extend([indices_outter_shape,
                   inner_loop_shape,
+                  one_const,
                   inner_loop_shape_,
                   output_inner_shape,
                   output_shape_,

tf2onnx/function/matrixbandpart.py

Lines changed: 6 additions & 2 deletions

@@ -25,6 +25,7 @@ def matrixbandpart_op(ctx, node, name, args):
     # no need to worry about the dtype, because bool type is needed as Xor only support bool
     node_name = utils.make_name("const_zero")
     const_zero = ctx.make_const(name=node_name, np_val=np.array([0]).astype(np.int32))
+    nodes.append(const_zero)
     first_col_or_row = ctx.make_node(op_type="Gather", inputs=[node.input[0], const_zero.output[0]],
                                      attr={"axis": axis})
     nodes.append(first_col_or_row)
@@ -41,6 +42,7 @@ def matrixbandpart_op(ctx, node, name, args):
     g = ctx.create_new_graph_with_same_config()
     node_name = utils.make_name("const_zero_bool")
     const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
+    nodes.append(const_zero_bool)
     ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)

     # shift right the line and add zero at the left.
@@ -67,14 +69,16 @@ def matrixbandpart_op(ctx, node, name, args):
     nodes.append(shape)
     node_name = utils.make_name("line_num_index")
     col_or_row_num_index = ctx.make_const(name=node_name, np_val=np.array(axis).astype(np.int32))
+    nodes.append(col_or_row_num_index)
     line_num = ctx.make_node(op_type="Gather", inputs=[shape.output[0], col_or_row_num_index.output[0]])
     nodes.append(line_num)
     trip_cnt = line_num.output[0]
     node_name = utils.make_name("true")
-    cond = ctx.make_const(name=node_name, np_val=np.array(1).astype(np.bool)).output[0]
+    cond = ctx.make_const(name=node_name, np_val=np.array(1).astype(np.bool))
+    nodes.append(cond)
     col_init = one_line.output[0]

-    loop_node = ctx.make_node(op_type="Loop", inputs=[trip_cnt, cond, col_init], output_count=2)
+    loop_node = ctx.make_node(op_type="Loop", inputs=[trip_cnt, cond.output[0], col_init], output_count=2)
     loop_node.set_body_graph_as_attr("body", g)
     nodes.append(loop_node)
     # convert generated mask matrix from bool to right shape and data type

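One signature-level wrinkle above: cond used to be just the output name returned by make_const(...).output[0], which left no Node to track. Keeping the Node handle lets it be appended to nodes, while the Loop input takes cond.output[0]. Condensed from the diff (names as in matrixbandpart_op):

    # Keep the Node so it can be tracked; pass its output name to the Loop.
    cond = ctx.make_const(name=utils.make_name("true"),
                          np_val=np.array(1).astype(np.bool))
    nodes.append(cond)

    loop_node = ctx.make_node(op_type="Loop",
                              inputs=[trip_cnt, cond.output[0], col_init],
                              output_count=2)
    loop_node.set_body_graph_as_attr("body", g)
    nodes.append(loop_node)
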
tf2onnx/function/range.py

Lines changed: 3 additions & 2 deletions

@@ -21,7 +21,8 @@ def make_range_const(ctx, start, limit, delta, output, scope_name, dtype):
     delta = ctx.get_node_by_output(delta).get_tensor_value(as_list=False)
     val = np.arange(start, limit, delta, dtype=start.dtype)
     const_range = ctx.make_const(base_name, val)
-    return ctx.make_node("Identity", [const_range.output[0]], dtypes=[dtype], outputs=[output])
+    return [ctx.make_node("Identity", [const_range.output[0]], dtypes=[dtype], outputs=[output]),
+            const_range]


 def make_range_non_const(ctx, start, limit, delta, output, scope_name, dtype):
@@ -65,7 +66,7 @@ def make_range_non_const(ctx, start, limit, delta, output, scope_name, dtype):
     # cond
     # Use initializer here since Constant OP before opset 9 does not support bool type
     cond_name = "{}_cond".format(base_name)
-    ctx.make_const(cond_name, np.ones((), dtype=bool))
+    nodes.append(ctx.make_const(cond_name, np.ones((), dtype=bool)))

     # body
     g = ctx.create_new_graph_with_same_config()

tf2onnx/function/select.py

Lines changed: 3 additions & 4 deletions

@@ -92,11 +92,11 @@ def select_op8(ctx, node, name, args):
 def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_input_ids, rank):
     nodes = []
     cond_var_name = utils.make_name("cond_var")
-    g.make_const(cond_var_name, np.array(True, dtype=np.bool))
+    nodes.append(g.make_const(cond_var_name, np.array(True, dtype=np.bool)))

     # Loop requires at least a variable, add a useless fake variable.
     fake_val_name = utils.make_name("fake_var")
-    g.make_const(fake_val_name, np.array(0.0, dtype=np.float32))
+    nodes.append(g.make_const(fake_val_name, np.array(0.0, dtype=np.float32)))

     if rank < 1:
         raise ValueError("rank is < 1")
@@ -130,15 +130,14 @@ def get_inputs_for_current_iteration(g, input_id, iter_index):
 def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
                            rank, loop_name):
     g = parent_g.create_new_graph_with_same_config()
-    nodes = []
     iter_name = utils.make_name("i")
     cond_name = utils.make_name("cond")
     fake_var_name = utils.make_name("fake_var")

     g.add_graph_input(iter_name, TensorProto.INT64, (1,))  # iteration_num
     g.add_graph_input(cond_name, TensorProto.BOOL, ())  # condition
     g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency
-
+    nodes = g.get_nodes()
     # get the i'th value of condition
     cond_input_id = gather_input_ids[0]
     new_nodes, cond_input_id_for_current_iter = get_inputs_for_current_iteration(g, cond_input_id, iter_name)

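create_loop_body_graph shows the subgraph flavor of the same fix: the body's node list is now seeded from g.get_nodes() instead of an empty list, so whatever the subgraph already holds after add_graph_input (plus any consts made on g) is preserved when the accumulated list is written back. A sketch under that assumption; set_nodes is assumed here as the eventual write-back call:

    g = parent_g.create_new_graph_with_same_config()
    g.add_graph_input(iter_name, TensorProto.INT64, (1,))    # iteration_num
    g.add_graph_input(cond_name, TensorProto.BOOL, ())       # condition
    g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency

    # Seed from the graph rather than [] so input/Const nodes are not lost
    # when the body's node list is set at the end.
    nodes = g.get_nodes()
    # ... body ops are appended to nodes here ...
    g.set_nodes(nodes)  # assumed write-back, mirroring how nodes is consumed
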
tf2onnx/function/sparse_softmax_cross_entropy_with_logits.py

Lines changed: 4 additions & 2 deletions

@@ -42,7 +42,8 @@ def sparse_softmax_cross_entropy_with_logits_op(ctx, node, name, args):
     mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
     res = ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], outputs=[node.output[0]], attr={"axes": [1]})

-    return [onehot, log_softmax, mul1, reduce_sum, mul2, res]
+    return [const_eye, onehot, log_softmax, mul1, reduce_sum,
+            const_negative_one, mul2, res]


 def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, args):
@@ -95,5 +96,6 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, args):
                           inputs=[mul2.output[0]], outputs=[node.output[0]],
                           attr={"axes": [1]})

-    nodes.extend([indices_size, indices_unsqueeze, id_unsqueeze, indices_with_id, log_softmax, mul2, res])
+    nodes.extend([zero_const, one_const, indices_size, indices_unsqueeze, id_unsqueeze, indices_with_id,
+                  log_softmax, const_negative_one, mul2, res])
     return nodes
