Skip to content

Commit c7d199e

Browse files
author
Wenbing Li
authored
Add a small tool to remove cast node from onnx model (#227)
* add a decast tool in the onnxtk. * More fixes. * add unit test. * update the ci build script. * fall back the old onnxruntime for the test.
1 parent 1ea5bfd commit c7d199e

File tree

8 files changed

+162
-145
lines changed

8 files changed

+162
-145
lines changed

.appveyor.yml

Lines changed: 0 additions & 53 deletions
This file was deleted.

.azure-pipelines/linux-conda-CI.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,10 @@ jobs:
3636
conda install -c conda-forge protobuf
3737
conda install -c conda-forge numpy
3838
conda install -c conda-forge cmake
39-
conda install -c conda-forge openmpi
40-
conda install -c conda-forge tensorflow
4139
pip install $(ONNX_PATH)
4240
pip install -r requirements.txt
4341
pip install -r requirements-dev.txt
44-
test '$(python.version)' != '2.7' && pip install onnxruntime
42+
test '$(python.version)' != '2.7' && pip install onnxruntime==0.1.4
4543
pip install pytest
4644
git clone --recursive https://github.com/cjlin1/libsvm libsvm
4745
cd libsvm

.travis.yml

Lines changed: 0 additions & 61 deletions
This file was deleted.

onnxutils/README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
<p align="center"><img width="40%" src="../docs/ONNXMLTools_logo_main.png" /></p>
2-
3-
41
# Introduction
52
ONNXTK package enables you check and optimize the [ONNX](https://onnx.ai) models for ONNX inference engine.
63

onnxutils/onnxtk/decast.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import sys
2+
import onnx
3+
from .optimizer import LinkedNode, Solution
4+
5+
6+
def remove_cast(lnodes, op_set):
7+
8+
while True:
9+
sln = []
10+
for n_ in lnodes:
11+
if n_.op_type in op_set and n_.in_single_path:
12+
if n_.precedence[0].op_type == 'Cast' and n_.successor[0].op_type == 'Cast':
13+
sln.append(Solution(None, n_.precedence[0], n_.precedence[0], n_))
14+
sln.append(Solution(n_, n_.successor[0], n_.successor[0], None))
15+
break
16+
17+
if len(sln) == 0:
18+
break
19+
20+
for s_ in sln:
21+
lnodes = s_.apply(lnodes)
22+
23+
return lnodes
24+
25+
26+
def decast(origin_model, oplist):
27+
"""
28+
remove the ONNX cast op from the specified operator.
29+
:param origin_model:these
30+
:param oplist:
31+
:return:
32+
"""
33+
graph = origin_model.graph
34+
nodelist = list(graph.node)
35+
del graph.node[:]
36+
37+
all_nodes = LinkedNode.build_from_onnx(nodelist,
38+
[],
39+
[i_.name for i_ in graph.input],
40+
[o_.name for o_ in graph.output])
41+
42+
nodes = remove_cast(all_nodes, set(oplist))
43+
for n_ in nodes:
44+
graph.node.extend(n_.generate())
45+
46+
return origin_model
47+
48+
49+
def main():
50+
if len(sys.argv) < 4:
51+
print('decast.py model_in model_out <op1, ...>')
52+
return
53+
54+
input = sys.argv[1]
55+
output = sys.argv[2]
56+
op_list = sys.argv[3:]
57+
58+
oxml = onnx.load_model(input)
59+
oxml = decast(oxml, op_list)
60+
onnx.save_model(oxml, output)
61+
62+
63+
if __name__ == "__main__":
64+
main()

onnxutils/onnxtk/optimizer.py

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def in_redirect(self, old_name, name):
129129
self.input[key] = name
130130

131131
def out_redirect(self, old_name, name):
132+
assert self.in_or_out
132133
if old_name in self.output:
133134
self.output[old_name] = name
134135
else:
@@ -191,10 +192,10 @@ def build_from_onnx(onnx_nodes, nchw_inputs, inputs, outputs):
191192
if var_ in nchw_inputs:
192193
nnode = LinkedNode(
193194
helper.make_node(
194-
'Transpose',
195-
[var_],
196-
[new_output],
197-
perm=[0, 2, 3, 1]))
195+
'Transpose',
196+
[var_],
197+
[new_output],
198+
perm=[0, 2, 3, 1]))
198199
var_map[new_output] = nnode
199200
nnode.add_precedence(target, var_)
200201
n_.in_redirect(var_, new_output)
@@ -236,6 +237,10 @@ def debug_print(node_list):
236237

237238

238239
class Solution(object):
240+
"""
241+
Solution is the base class for solutions, and it has a basic function is to
242+
delete the node range of (begin, begin_n, end_p, end), where 'begin' and 'end' are excluded.
243+
"""
239244
def __init__(self, begin, begin_n, end_p, end):
240245
self.begin = begin
241246
self.begin_n = begin_n
@@ -255,23 +260,64 @@ def is_useless_transpose(perm):
255260
return perm == list(six.moves.range(len(perm)))
256261

257262
@staticmethod
258-
def delete_node(node_list, begin, node, end): # type: ([],LinkedNode, LinkedNode, LinkedNode)->[]
263+
def delete_node_nto1(node_list, begin, node, end): # type: ([],LinkedNode, LinkedNode, LinkedNode)->[]
264+
"""
265+
delete the node which has n-input and 1-output
266+
"""
267+
if begin is None:
268+
assert node is not None
269+
begin = node.precedence
270+
elif not isinstance(begin, list):
271+
begin = [begin]
272+
259273
if end.in_or_out:
274+
# if the end is output node, the output name will be kept to avoid the model output name updating.
275+
for nb_ in begin:
276+
nb_.out_redirect(node.single_input, node.single_output)
277+
else:
278+
for nb_ in begin:
279+
target_var_name = node.single_input
280+
assert target_var_name in nb_.output.values() # since the output info never be updated, except the final.
281+
end.in_redirect(node.single_output, target_var_name)
282+
283+
for nb_ in begin:
284+
nb_.successor = [end if v_ == node else v_ for v_ in nb_.successor]
285+
end.precedence = [v_ for v_ in end.precedence if v_ != node] + node.precedence
286+
287+
node_list.remove(node)
288+
return node_list
289+
290+
@staticmethod
291+
def delete_node_1ton(node_list, begin, node, end): # type: ([],LinkedNode, LinkedNode, LinkedNode)->[]
292+
"""
293+
delete the node which has 1-input and n-output
294+
"""
295+
if end is None:
296+
assert end is not None
297+
end = node.successor
298+
elif not isinstance(end, list):
299+
end = [end]
300+
301+
if any(e_.in_or_out for e_ in end):
260302
# if the end is output node, the output name will be kept to avoid the model output name updating.
261303
begin.out_redirect(node.single_input, node.single_output)
262304
else:
263-
target_var_name = node.single_input
264-
assert target_var_name in begin.output.values() # since the output info never be updated, except the final.
265-
end.in_redirect(node.single_output, target_var_name)
305+
for ne_ in end:
306+
target_var_name = node.single_input
307+
# since the output info never be updated, except the final.
308+
assert target_var_name in begin.output.values()
309+
ne_.in_redirect(node.single_output, target_var_name)
266310

267-
begin.successor = [end if v_ == node else v_ for v_ in begin.successor]
268-
end.precedence = [begin if v_ == node else v_ for v_ in end.precedence]
311+
begin.successor = [v_ for v_ in begin.successor if v_ != node] + node.successor
312+
for ne_ in end:
313+
ne_.precedence = [begin if v_ == node else v_ for v_ in ne_.precedence]
269314

270315
node_list.remove(node)
271316
return node_list
272317

273318
@staticmethod
274-
def add_siso_node(node_list, begin, end, begin_output_name, node): # type: ([], LinkedNode, LinkedNode, string, LinkedNode)->[]
319+
def add_siso_node(node_list, begin, end, begin_output_name, node):
320+
# type: ([], LinkedNode, LinkedNode, str, LinkedNode)->[]
275321
node.in_redirect(node.single_input, begin_output_name)
276322
end.in_redirect(begin_output_name, node.single_output)
277323
begin.successor[begin.successor.index(end)] = node
@@ -287,8 +333,11 @@ def apply(self, node_list):
287333
while node != self.end:
288334
assert len(node.successor) == 1
289335
end = node.successor[0]
290-
node_list = self.delete_node(node_list, self.begin, node, end)
291-
node = end
336+
if self.begin:
337+
node_list = self.delete_node_nto1(node_list, self.begin, node, end)
338+
else:
339+
node_list = self.delete_node_nto1(node_list, self.begin, node, end)
340+
node = self.end if self.end is None else end
292341

293342
return node_list
294343

@@ -306,10 +355,10 @@ def apply(self, node_list):
306355
# node.reshape_input_for_broadcast(perm0)
307356
node = node.successor[0]
308357

309-
node_list = self.delete_node(node_list, self.begin, self.begin_n, self.begin_n.successor[0])
310-
node_list = self.delete_node(node_list, self.end_p.precedence[0], self.end_p, self.end)
358+
node_list = self.delete_node_1ton(node_list, self.begin, self.begin_n, self.begin_n.successor[0])
359+
node_list = self.delete_node_1ton(node_list, self.end_p.precedence[0], self.end_p, self.end)
311360
else:
312-
node_list = self.delete_node(node_list, self.begin_n, self.end_p, self.end)
361+
node_list = self.delete_node_1ton(node_list, self.begin_n, self.end_p, self.end)
313362
self.begin_n.attribute['perm'] = perm_f
314363
return node_list
315364

@@ -346,7 +395,7 @@ def apply(self, node_list):
346395
FanOutSolution.number = FanOutSolution.number + 1
347396
node_list = Solution.add_siso_node(node_list, self.end_p, suc, list(suc.input.values())[0], nnode)
348397

349-
node_list = Solution.delete_node(node_list, self.begin, self.begin_n, self.end_p)
398+
node_list = Solution.delete_node_1ton(node_list, self.begin, self.begin_n, self.end_p)
350399
return node_list
351400

352401

@@ -368,7 +417,7 @@ def apply(self, node_list):
368417
precedence_list = self.begin.precedence.copy()
369418
node_list = Solution.add_siso_node(node_list, self.begin, self.begin_n, list(self.begin.output.values())[0], nnode)
370419
for branch in precedence_list:
371-
node_list = Solution.delete_node(node_list, branch.precedence[0], branch, self.begin)
420+
node_list = Solution.delete_node_1ton(node_list, branch.precedence[0], branch, self.begin)
372421
return node_list
373422

374423

onnxutils/tests/test_decast.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import unittest
2+
3+
from onnx import helper
4+
from onnx import onnx_pb as onnx_proto
5+
from onnxtk.decast import decast
6+
7+
8+
class DecastTestCase(unittest.TestCase):
9+
10+
def test_decast(self):
11+
nodes = []
12+
nodes[0:] = [helper.make_node('Identity', ['input1'], ['identity1'])]
13+
nodes[1:] = [helper.make_node('Cast', ['identity1'], ['cast0'], to=1)]
14+
nodes[2:] = [helper.make_node('ReduceSum', ['cast0'], ['reduce0'])]
15+
nodes[3:] = [helper.make_node('Cast', ['reduce0'], ['cast1'], to=6)]
16+
nodes[4:] = [helper.make_node('Identity', ['cast1'], ['output0'])]
17+
18+
input0 = helper.make_tensor_value_info('input1', onnx_proto.TensorProto.FLOAT, [1, 1, 2, 3])
19+
output0 = helper.make_tensor_value_info('output0', onnx_proto.TensorProto.FLOAT, [1, 1, 2, 3])
20+
21+
graph = helper.make_graph(nodes, 'test_graph', [input0], [output0])
22+
model = helper.make_model(graph)
23+
self.assertIsNotNone(model)
24+
25+
converted_model = decast(model, ['ReduceSum'])
26+
self.assertTrue(len(converted_model.graph.node) == 3)
27+
28+
29+
if __name__ == '__main__':
30+
unittest.main()

onnxutils/tests/test_opt.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,6 @@ def test_fan_in(self):
190190
self.assertEqual(len(new_nodes), 7)
191191
self.assertIsNotNone(model)
192192

193-
@unittest.skip('Need manually copy the real_yolov3.onnx, which is big.')
194-
def test_onnx_model(self):
195-
with open("./real_yolov3.onnx", 'rb') as fml:
196-
oxml= onnx.load_model(fml) # type: onnx.ModelProto
197-
oxml = optimize_onnx_model(oxml)
198-
onnx.save_model(oxml, 'real_yolov3_opt.onnx')
199-
200193

201194
if __name__ == '__main__':
202195
unittest.main()

0 commit comments

Comments
 (0)