Skip to content

Commit 261818a

Browse files
authored
Merge pull request #340 from pengwa/set_nodes_change
Graph node management in real time when add/delete
2 parents 87f5b0a + 38542b9 commit 261818a

23 files changed

+531
-744
lines changed

tests/run_pretrained_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
265265
# convert model to onnx
266266
onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
267267
input_names=inputs.keys())
268+
model_proto = onnx_graph.make_model("converted from tf2onnx")
268269
new_model_proto = GraphUtil.opt_transposes_with_graph(onnx_graph, "test", debug=debug)
269270
if new_model_proto:
270271
model_proto = new_model_proto

tests/run_pretrained_models.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ inception_v3_slim:
9191
outputs:
9292
- InceptionV3/Predictions/Softmax:0
9393
rtol: 0.02
94-
atol: 0.000001
94+
atol: 0.00001
9595

9696
googlenet_v1_nonslim:
9797
disabled: true

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def test_matmul3(self):
474474
y = tf.placeholder(tf.float32, x_shape, name=_TFINPUT1)
475475
x_ = tf.matmul(x, y, transpose_b=True)
476476
_ = tf.identity(x_, name=_TFOUTPUT)
477-
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: x_val}, rtol=1e-6)
477+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: x_val}, rtol=1e-5)
478478

479479
@check_onnxruntime_incompatibility("Sub")
480480
def test_sub(self):

tests/test_internals.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,8 @@ def test_insert_node1(self):
8686
graph_proto = self.sample_net()
8787
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
8888
n2 = g.get_node_by_name("n2")
89-
n7 = g.insert_new_node_on_input(n2, "Abs", "n1:0", name="n7")
89+
g.insert_new_node_on_input(n2, "Abs", "n1:0", name="n7")
9090
ops = g.get_nodes()
91-
ops.append(n7)
9291
g.topological_sort(ops)
9392
result = onnx_to_graphviz(g)
9493
expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
@@ -102,9 +101,8 @@ def test_insert_node1(self):
102101
def test_insert_node2(self):
103102
graph_proto = self.sample_net()
104103
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
105-
n7 = g.insert_new_node_on_output("Abs", "n1:0", name="n7")
104+
g.insert_new_node_on_output("Abs", "n1:0", name="n7")
106105
ops = g.get_nodes()
107-
ops.append(n7)
108106
g.topological_sort(ops)
109107
result = onnx_to_graphviz(g)
110108
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
@@ -119,12 +117,14 @@ def test_remove_input(self):
119117
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
120118
n4 = g.get_node_by_name("n4")
121119
g.remove_input(n4, n4.input[1])
120+
ops = g.get_nodes()
121+
g.topological_sort(ops)
122122
result = onnx_to_graphviz(g)
123-
expected = 'digraph { n1 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] n4 [op_type=Add] ' \
124-
'n5 [op_type=Abs] n6 [op_type=Identity] graph_outputs_Identity__3 ' \
125-
'[op_type=Identity] Placeholder__4 [op_type=Placeholder] input -> n1 ' \
126-
'n1:0 -> n2 n1:0 -> n3 n2:0 -> n4 n4:0 -> n5 raw_output___2:0 -> n6 ' \
127-
'raw_output___2:0 -> graph_outputs_Identity__3 }'
123+
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] ' \
124+
'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
125+
'graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
126+
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 ' \
127+
'raw_output___2:0 -> graph_outputs_Identity__3 raw_output___2:0 -> n6 }'
128128
self.assertEqual(expected, result)
129129

130130
def test_rewrite_subgraph(self):
@@ -143,7 +143,9 @@ def test_rewrite_subgraph(self):
143143
op_name = utils.make_name("ReplacedOp")
144144
out_name = utils.port_name(op_name)
145145
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
146-
ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
146+
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
147+
for n in set(match.get_nodes()):
148+
g.remove_node(n.name)
147149
g.topological_sort(ops)
148150
result = onnx_to_graphviz(g)
149151
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \

tf2onnx/function/gathernd.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,8 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
1919
# for (int i = 0; i < size(index); i++)
2020
# gather_res = gather(gather_cur, index[i])
2121
scope_name = utils.make_name("gathernd_inner_loop")
22-
nodes = []
2322
trip_node = ctx.make_node("Size", [index.output[0]])
24-
nodes.append(trip_node)
2523
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
26-
nodes.append(cond_const)
2724
trip_name = utils.make_name("i")
2825
cond_name = utils.make_name("cond")
2926
cond_out_name = utils.make_name("cond_out")
@@ -34,9 +31,8 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
3431
g = ctx.create_new_graph_with_same_config()
3532
index_i = g.make_node("Gather", [index.output[0], trip_name], attr={"axis": 0})
3633
gather = g.make_node("Gather", [cur_name, index_i.output[0]], attr={"axis": 0})
37-
squeeze = g.make_node("Squeeze", [gather.output[0]], attr={"axes": [0]}, outputs=[result_name])
38-
cond = g.make_node("Identity", [cond_name], outputs=[cond_out_name])
39-
g.set_nodes([index_i, gather, squeeze, cond])
34+
g.make_node("Squeeze", [gather.output[0]], attr={"axes": [0]}, outputs=[result_name])
35+
g.make_node("Identity", [cond_name], outputs=[cond_out_name])
4036

4137
g.add_graph_input(trip_name, TensorProto.INT64, [])
4238
g.add_graph_input(cond_name, TensorProto.BOOL, [])
@@ -50,15 +46,13 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
5046
params],
5147
op_name_scope=scope_name, skip_conversion=False)
5248
inner_loop.set_body_graph_as_attr("body", g)
53-
nodes.append(inner_loop)
54-
return nodes, inner_loop
49+
return inner_loop
5550

5651

5752
def make_gathernd(ctx, params, indices, output, scope_name, t_params):
5853
"""make GatherNd op."""
5954
# Tparams output = GatherNd(Tparams params, Tidx indices)
6055
scope_name = utils.make_name(scope_name)
61-
nodes = []
6256
# reshape indices into [sum(indices[:-1]), indices[-1]]
6357
indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
6458
indices_size = ctx.make_node("Size", [indices])
@@ -74,12 +68,11 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
7468
attr={"axis": 0},
7569
dtypes=[TensorProto.INT64])
7670
flatten_indices = ctx.make_node("Reshape", [indices, flatten_shape.output[0]])
77-
nodes.extend([indices_shape, indices_size, inner_shape, outter_shape, flatten_shape, flatten_indices])
71+
7872
# outter loop for each index
7973
# for (int i=0; i<outter_shape; i++) inner_loop(params, flatten_indices[i])
8074
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
8175
dummy_const = ctx.make_const(utils.make_name("dummy"), np.ones((), dtype=np.int64))
82-
nodes.extend([cond_const, dummy_const])
8376

8477
# body graph creation
8578
g = ctx.create_new_graph_with_same_config()
@@ -93,12 +86,10 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
9386
index = g.make_node("Gather", [flatten_indices.output[0], trip_name], attr={"axis": 0})
9487
index_squeeze = g.make_node("Squeeze", [index.output[0]], attr={"axes": [0]})
9588
# inner loop to gather result
96-
nodes_to_append, inner_loop = _make_gathernd_inner_loop(g, params, index_squeeze, t_params)
97-
g.set_nodes(nodes_to_append +
98-
[index, index_squeeze,
99-
g.make_node("Identity", [cond_name], outputs=[cond_out_name]),
100-
g.make_node("Identity", [dummy_name], outputs=[dummy_out_name]),
101-
g.make_node("Identity", [inner_loop.output[0]], outputs=[result_name])])
89+
inner_loop = _make_gathernd_inner_loop(g, params, index_squeeze, t_params)
90+
g.make_node("Identity", [cond_name], outputs=[cond_out_name])
91+
g.make_node("Identity", [dummy_name], outputs=[dummy_out_name])
92+
g.make_node("Identity", [inner_loop.output[0]], outputs=[result_name])
10293

10394
g.add_graph_input(trip_name, TensorProto.INT64, [])
10495
g.add_graph_input(cond_name, TensorProto.BOOL, [])
@@ -113,7 +104,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
113104
output_count=2,
114105
op_name_scope=scope_name, skip_conversion=False)
115106
gathernd_loop.set_body_graph_as_attr("body", g)
116-
nodes.append(gathernd_loop)
107+
117108
# reshape to target shape
118109
# output shape of gathernd: indices.shape[:-1] + gathernd_output.shape[1:]
119110
inner_loop_shape = ctx.make_node("Shape", [gathernd_loop.output[1]], dtypes=[TensorProto.INT64])
@@ -140,18 +131,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params):
140131
[output_shape_.output[0]],
141132
attr={"axes": [0], "ends": [-1], "starts": [0]},
142133
dtypes=[TensorProto.INT64])
143-
output_reshape = ctx.make_node("Reshape",
144-
[gathernd_loop.output[1], output_shape.output[0]],
145-
outputs=[output])
146-
nodes.extend([indices_outter_shape,
147-
inner_loop_shape,
148-
one_const,
149-
inner_loop_shape_,
150-
output_inner_shape,
151-
output_shape_,
152-
output_shape,
153-
output_reshape])
154-
return nodes
134+
ctx.make_node("Reshape", [gathernd_loop.output[1], output_shape.output[0]], outputs=[output])
155135

156136

157137
def gathernd_op(ctx, node, name, args):
@@ -163,4 +143,4 @@ def gathernd_op(ctx, node, name, args):
163143
# same as the attr Tparams
164144
t_params = ctx.get_dtype(params)
165145
utils.make_sure(t_params, "Dtype of {} is None".format(indices))
166-
return make_gathernd(ctx, params, indices, output, name, t_params)
146+
make_gathernd(ctx, params, indices, output, name, t_params)

tf2onnx/function/matrixbandpart.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,22 @@ def matrixbandpart_op(ctx, node, name, args):
2020
# methods to generate mask matrix: if lower triangular is needed, then generate column one by one
2121
# otherwise row is generated one by one.
2222
axis, counter_axis, squeeze_axis = (1, 0, 2) if bandpart == [-1, 0] else (0, 1, 1)
23-
nodes = []
2423
# 1: subgraph to implement tf.onelike(input[:, 0]),
2524
# no need to worry about the dtype, because bool type is needed as Xor only support bool
2625
node_name = utils.make_name("const_zero")
2726
const_zero = ctx.make_const(name=node_name, np_val=np.array([0]).astype(np.int32))
28-
nodes.append(const_zero)
2927
first_col_or_row = ctx.make_node(op_type="Gather", inputs=[node.input[0], const_zero.output[0]],
3028
attr={"axis": axis})
31-
nodes.append(first_col_or_row)
3229
first_col_or_row_casted = ctx.make_node(op_type="Cast", inputs=first_col_or_row.output,
3330
attr={"to": onnx_pb.TensorProto.BOOL})
34-
nodes.append(first_col_or_row_casted)
3531
# line means one col or one row
3632
zero_line = ctx.make_node(op_type="Xor", inputs=first_col_or_row_casted.output*2)
37-
nodes.append(zero_line)
3833
one_line = ctx.make_node(op_type="Not", inputs=zero_line.output)
39-
nodes.append(one_line)
4034

4135
# 2: "loop" to generate mask matrix: generate col or row of matrix one by one
4236
g = ctx.create_new_graph_with_same_config()
4337
node_name = utils.make_name("const_zero_bool")
4438
const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
45-
nodes.append(const_zero_bool)
4639
ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
4740

4841
# shift right the line and add zero at the left.
@@ -51,11 +44,10 @@ def matrixbandpart_op(ctx, node, name, args):
5144
slice_node = g.make_node(op_type="Slice", inputs=[new_line.output[0]],
5245
attr={"axes": [counter_axis], "starts": [0], "ends": [-1]})
5346

54-
body_nodes = [slice_node, new_line,
55-
g.make_node("Identity", ["cond"], outputs=["cond_out"]),
56-
g.make_node("Identity", ["line"], outputs=["res"]),
57-
g.make_node("Identity", [slice_node.output[0]], outputs=["line_out"])]
58-
g.set_nodes(body_nodes)
47+
g.make_node("Identity", ["cond"], outputs=["cond_out"])
48+
g.make_node("Identity", ["line"], outputs=["res"])
49+
g.make_node("Identity", [slice_node.output[0]], outputs=["line_out"])
50+
5951
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
6052
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
6153
g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
@@ -66,35 +58,28 @@ def matrixbandpart_op(ctx, node, name, args):
6658

6759
# initial value of body vars
6860
shape = ctx.make_node(op_type="Shape", inputs=[node.input[0]]) # dtype of result is int64
69-
nodes.append(shape)
7061
node_name = utils.make_name("line_num_index")
7162
col_or_row_num_index = ctx.make_const(name=node_name, np_val=np.array(axis).astype(np.int32))
72-
nodes.append(col_or_row_num_index)
7363
line_num = ctx.make_node(op_type="Gather", inputs=[shape.output[0], col_or_row_num_index.output[0]])
74-
nodes.append(line_num)
7564
trip_cnt = line_num.output[0]
7665
node_name = utils.make_name("true")
7766
cond = ctx.make_const(name=node_name, np_val=np.array(1).astype(np.bool))
78-
nodes.append(cond)
7967
col_init = one_line.output[0]
8068

8169
loop_node = ctx.make_node(op_type="Loop", inputs=[trip_cnt, cond.output[0], col_init], output_count=2)
8270
loop_node.set_body_graph_as_attr("body", g)
83-
nodes.append(loop_node)
8471
# convert generated mask matrix from bool to right shape and data type
8572
squeeze = ctx.make_node(op_type="Squeeze", inputs=[loop_node.output[1]], attr={"axes": [squeeze_axis]})
86-
nodes.append(squeeze)
8773
cast1 = ctx.make_node(op_type="Cast", inputs=squeeze.output, attr={"to": onnx_pb.TensorProto.FLOAT})
88-
nodes.append(cast1)
8974
if axis == 1:
9075
mask_matrix = ctx.make_node(op_type="Transpose", inputs=cast1.output)
91-
nodes.append(mask_matrix)
9276
else:
9377
mask_matrix = squeeze
9478
cast2 = ctx.make_node(op_type="Cast", inputs=mask_matrix.output,
9579
attr={"to": ctx.get_dtype(node.input[0])})
96-
nodes.append(cast2)
97-
res = ctx.make_node(op_type="Mul", inputs=[cast2.output[0], node.input[0]],
98-
name=node.name, outputs=node.output)
99-
nodes.append(res)
100-
return nodes
80+
shapes = node.output_shapes
81+
dtypes = node.output_dtypes
82+
ctx.remove_node(node.name)
83+
ctx.make_node(op_type="Mul", inputs=[cast2.output[0], node.input[0]],
84+
name=node.name, outputs=node.output, shapes=shapes,
85+
dtypes=dtypes)

tf2onnx/function/range.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# pylint: disable=unused-argument,missing-docstring
1212

1313

14-
def make_range_const(ctx, start, limit, delta, output, scope_name, dtype):
14+
def make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
1515
"""make Range subgraph if all inputs are const."""
1616
# T range = Range(T start, T limit, T delta)
1717
# V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
@@ -21,59 +21,47 @@ def make_range_const(ctx, start, limit, delta, output, scope_name, dtype):
2121
delta = ctx.get_node_by_output(delta).get_tensor_value(as_list=False)
2222
val = np.arange(start, limit, delta, dtype=start.dtype)
2323
const_range = ctx.make_const(base_name, val)
24-
return [ctx.make_node("Identity", [const_range.output[0]], dtypes=[dtype], outputs=[output]),
25-
const_range]
24+
ctx.make_node("Identity", [const_range.output[0]], shapes=[shape], dtypes=[dtype], outputs=[output])
2625

2726

28-
def make_range_non_const(ctx, start, limit, delta, output, scope_name, dtype):
27+
def make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
2928
"""make Range subgraph."""
3029
# T range = Range(T start, T limit, T delta)
3130
# V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
3231
base_name = utils.make_name(scope_name)
3332

34-
nodes = []
35-
3633
# trip_count
3734
diff_node = ctx.make_node("Sub",
3835
[limit, start],
3936
op_name_scope=base_name,
4037
name=utils.make_name("diff"))
4138
diff_output = diff_node.output[0]
42-
nodes.append(diff_node)
4339

4440
delta_cast = delta
4541
if dtype in [TensorProto.INT32, TensorProto.INT64]:
4642
cast_node = ctx.make_node("Cast", [diff_output], op_name_scope=base_name,
4743
name="cast_diff", attr={"to": TensorProto.FLOAT})
48-
nodes.append(cast_node)
4944
diff_output = cast_node.output[0]
5045

5146
cast_node = ctx.make_node("Cast", [delta], op_name_scope=base_name, name="cast_delta",
5247
attr={"to": TensorProto.FLOAT})
53-
nodes.append(cast_node)
5448
delta_cast = cast_node.output[0]
55-
5649
div_node = ctx.make_node("Div", [diff_output, delta_cast], op_name_scope=base_name, name="div")
57-
nodes.append(div_node)
58-
5950
ceil_node = ctx.make_node("Ceil", [div_node.output[0]], op_name_scope=base_name, name="ceil")
60-
nodes.append(ceil_node)
61-
6251
trip_count_node = ctx.make_node("Cast", [ceil_node.output[0]], op_name_scope=base_name, name="trip_cnt",
6352
attr={"to": TensorProto.INT64})
64-
nodes.append(trip_count_node)
6553

6654
# cond
6755
# Use initializer here since Constant OP before opset 9 does not support bool type
6856
cond_name = "{}_cond".format(base_name)
69-
nodes.append(ctx.make_const(cond_name, np.ones((), dtype=bool)))
57+
ctx.make_const(cond_name, np.ones((), dtype=bool))
7058

7159
# body
7260
g = ctx.create_new_graph_with_same_config()
73-
body_nodes = [g.make_node("Identity", ["cond"], outputs=["cond_out"]),
74-
g.make_node("Add", ["prev", delta], outputs=["current"], name=utils.make_name("add")),
75-
g.make_node("Identity", ["prev"], outputs=["range"])]
76-
g.set_nodes(body_nodes)
61+
g.make_node("Identity", ["cond"], outputs=["cond_out"])
62+
g.make_node("Add", ["prev", delta], outputs=["current"], name=utils.make_name("add"))
63+
g.make_node("Identity", ["prev"], outputs=["range"])
64+
7765
g.add_graph_input("i", TensorProto.INT64, [])
7866
g.add_graph_input("cond", TensorProto.BOOL, [])
7967
g.add_graph_input("prev", dtype, [])
@@ -86,25 +74,25 @@ def make_range_non_const(ctx, start, limit, delta, output, scope_name, dtype):
8674
loop_inputs = [trip_count_node.output[0], cond_name, start]
8775
loop_node = ctx.make_node("Loop", loop_inputs, output_count=2, op_name_scope=base_name, name="loop")
8876
loop_node.set_body_graph_as_attr("body", g)
89-
nodes.append(loop_node)
90-
91-
identity_node = ctx.make_node("Identity", [loop_node.output[1]], name=base_name, dtypes=[dtype], outputs=[output])
92-
nodes.append(identity_node)
9377

94-
return nodes
78+
ctx.make_node("Identity", [loop_node.output[1]], name=base_name, shapes=[shape],
79+
dtypes=[dtype], outputs=[output])
9580

9681

97-
def make_range(ctx, start, limit, delta, output, scope_name, dtype):
98-
if all(ctx.get_node_by_output(n).is_const() for n in [start, limit, delta]):
99-
return make_range_const(ctx, start, limit, delta, output, scope_name, dtype)
100-
return make_range_non_const(ctx, start, limit, delta, output, scope_name, dtype)
82+
def make_range(ctx, start, limit, delta, output, scope_name, shape, dtype):
83+
if all(ctx.get_node_by_output(n).is_const() for n in [start, limit, delta]) is True:
84+
make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
85+
else:
86+
make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
10187

10288

10389
def range_op7(ctx, node, name, args):
10490
"""Range."""
10591
# T range = Range(T start, T limit, T delta)
10692
# V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
10793
dtype = node.get_attr_int("Tidx")
94+
shape = node.output_shapes[0]
10895
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
109-
return make_range(ctx, node.input[0], node.input[1], node.input[2],
110-
node.output[0], name, dtype)
96+
ctx.remove_node(node.name)
97+
make_range(ctx, node.input[0], node.input[1], node.input[2],
98+
node.output[0], name, shape, dtype)

0 commit comments

Comments
 (0)