Skip to content

Commit 4326e02

Browse files
committed
add graph optimizer - const_fold_optimizer
1 parent f2289e3 commit 4326e02

File tree

4 files changed

+165
-1
lines changed

4 files changed

+165
-1
lines changed

tests/test_optimizers.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from backend_test_base import Tf2OnnxBackendTestBase
1515
from common import unittest_main, group_nodes_by_type
1616

17-
1817
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
1918

2019
class OptimizerTests(Tf2OnnxBackendTestBase):
@@ -423,5 +422,69 @@ def test_duplicated_need_multiple_run(self):
423422
op_type="Log", remaining_op_num=3)
424423
# Merge Duplicated Nodes Optimizer Tests End
425424

425+
# Const Fold Optimizer Tests Start
426+
427+
def test_const_fold_trans_with_const1(self):
    """A Transpose fed only by a Constant should be folded away entirely."""
    shape = (6, 6)
    const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                      vals=np.random.randn(*shape).flatten().astype(np.float32))
    const_node = helper.make_node("Constant", [], ["const"], value=const_tensor)
    trans_node = helper.make_node("Transpose", ["const"], ["value1"])
    add_node = helper.make_node("Add", ["value1", "X"], ["res"])

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)]
    graph = helper.make_graph([const_node, trans_node, add_node],
                              "test_const_fold_trans_with_const1",
                              graph_inputs, graph_outputs)

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed_dict = {"X": np.random.randn(*shape).astype(np.float32)}
    # after constant folding no Transpose node should remain in the graph
    self.run_transpose_compare(["res"], feed_dict,
                               model_proto, remaining_transpose_num=0)
445+
446+
def test_const_fold_trans_with_const2(self):
    """Two chained Transposes of a Constant both fold; folding one enables the next."""
    # need multiple optimization run
    shape = (6, 6)
    const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                      vals=np.random.randn(*shape).flatten().astype(np.float32))
    const_node = helper.make_node("Constant", [], ["const"], value=const_tensor)
    trans_node1 = helper.make_node("Transpose", ["const"], ["value1"])
    trans_node2 = helper.make_node("Transpose", ["value1"], ["value2"])
    add_node = helper.make_node("Add", ["value2", "X"], ["res"])

    graph_inputs = [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)]
    graph = helper.make_graph([const_node, trans_node1, trans_node2, add_node],
                              "test_const_fold_trans_with_const2",
                              graph_inputs, graph_outputs)

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed_dict = {"X": np.random.randn(*shape).astype(np.float32)}
    # after constant folding no Transpose node should remain in the graph
    self.run_transpose_compare(["res"], feed_dict,
                               model_proto, remaining_transpose_num=0)
466+
467+
def test_const_fold_node_is_output(self):
    """A node producing a graph output is not const-folded; later passes still remove the Transposes."""
    # need multiple optimization run
    shape = (6, 6)
    const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                      vals=np.random.randn(*shape).flatten().astype(np.float32))
    const_node = helper.make_node("Constant", [], ["const"], value=const_tensor)
    trans_node1 = helper.make_node("Transpose", ["const"], ["value1"])
    trans_node2 = helper.make_node("Transpose", ["value1"], ["res"])

    # graph has no inputs at all: everything is derived from the constant
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)]
    graph = helper.make_graph([const_node, trans_node1, trans_node2],
                              "test_const_fold_node_is_output",
                              [], graph_outputs)

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    self.run_transpose_compare(["res"], {},
                               model_proto, remaining_transpose_num=0)
486+
# Const Fold Optimizer Tests End
487+
488+
426489
if __name__ == "__main__":
427490
unittest_main()

tf2onnx/optimizer/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@
1010
import traceback
1111
from collections import OrderedDict
1212

13+
from tf2onnx.optimizer.const_fold_optimizer import ConstFoldOptimizer
1314
from tf2onnx.optimizer.identity_optimizer import IdentityOptimizer
1415
from tf2onnx.optimizer.merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
1516
from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
1617

18+
1719
# pylint: disable=missing-docstring, broad-except
1820

1921
# optimizer sequence need to be considered carefully
2022
_optimizers = OrderedDict([
2123
("transpose_opt", TransposeOptimizer),
24+
("fold_const", ConstFoldOptimizer),
2225
# merge_duplicated_nodes should be used after transpose_opt
2326
# for transpose_opt may have some trans nodes that can be merge
2427
("merge_duplicated_nodes", MergeDuplicatedNodesOptimizer),
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""const fold Optimizer.
5+
if op's inputs are all const then do op computation when building the graph to improve performance
6+
for example, input of transpose node is const then we can do transpose statically instead of at runtime
7+
"""
8+
9+
import logging
10+
11+
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
12+
from tf2onnx import utils
13+
14+
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
15+
16+
# Registry mapping an ONNX op type name to its const-folding handler.
_func_map = {}


def _register_func(onnx_op):
    """Decorator factory: register the decorated function as the folding
    handler for ``onnx_op`` and return the function unchanged."""
    def _wrapper(func):
        _func_map[onnx_op] = func
        return func
    return _wrapper
25+
26+
class ConstFoldOptimizer(GraphOptimizerBase):
    """Fold nodes whose inputs are all constant into pre-computed Const nodes.

    Per-op folding logic is looked up in the module-level ``_func_map``
    registry (populated via ``_register_func``); ops without a registered
    handler are reported with a warning so support can be added later.
    """

    def __init__(self, debug=False):
        super(ConstFoldOptimizer, self).__init__("ConstFoldOptimizer", debug)
        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)

    def _optimize(self, graph):
        return self._apply_optimization(graph, self._optimize_at_current_graph_level)

    def _optimize_at_current_graph_level(self, graph):
        # Iterate to a fixed point: folding one node can make a downstream
        # node's inputs all-constant, enabling further folds.
        graph_changed = True
        while graph_changed:
            graph_changed = False
            ops = graph.get_nodes()
            for op in ops:
                if self._can_skip(op):
                    continue
                if self._fold_node(op, graph):
                    graph_changed = True
        return graph

    @staticmethod
    def _can_skip(node):
        """Return True for nodes that never need folding (consts, graph inputs, Identity)."""
        if node.is_const() or node.is_graph_input():
            return True

        skip_type = ["Identity"]
        if node.type in skip_type:
            return True

        return False

    def _fold_node(self, node, graph):
        """Fold ``node`` when all of its inputs are const and it is not a graph output.

        Returns True when the node was folded, indicating the graph changed.
        Nodes producing graph outputs are left alone so the graph's external
        interface is preserved.
        """
        if self._all_inputs_are_const(node.inputs) and not self._is_graph_output(node, graph):
            process_func = _func_map.get(node.type, self._try_fold)
            return process_func(node, graph)

        return False

    @staticmethod
    def _all_inputs_are_const(nodes):
        # None entries (optional / missing inputs) are ignored by the filter.
        return all(node.is_const() for node in nodes if node)

    @staticmethod
    def _is_graph_output(node, graph):
        """Return True if any of the node's outputs is a graph output.

        Fix: return an explicit bool instead of the (merely truthy)
        intersection set the original returned.
        """
        node_out_set = set(node.output)
        graph_out_set = set(graph.outputs)
        return not node_out_set.isdisjoint(graph_out_set)

    @staticmethod
    def _replace_node_with_const(node, graph, val):
        """Replace ``node`` with a freshly created Const node carrying ``val``,
        rewiring every consumer of the node's first output."""
        const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
        graph.replace_all_inputs(graph.get_nodes(), node.output[0], const_node.output[0])
        graph.remove_node(node.name)

    @staticmethod
    @_register_func(onnx_op="Transpose")
    def _fold_transpose(node, graph):
        """Statically transpose the const input; always folds, so returns True."""
        const_val = node.inputs[0].get_tensor_value(as_list=False)
        perm_attr = node.get_attr("perm")
        # a missing "perm" attribute means "reverse the axes", which matches
        # numpy's behavior when axes is None
        perm = perm_attr.ints if perm_attr else None
        const_val_after_trans = const_val.transpose(perm)
        ConstFoldOptimizer._replace_node_with_const(node, graph, const_val_after_trans)
        return True

    def _try_fold(self, node, graph):
        # Fallback when no folding function is registered for this op type:
        # warn (so coverage gaps are visible) and leave the graph unchanged.
        self._log.warning("need to add function to fold op %s whose op_type is %s", node.name, node.type)
        return False

tf2onnx/optimizer/optimizer_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def __init__(self, name, debug=False):
1717
def optimize(self, graph):
    """Run this optimizer on *graph*, prune unused nodes, re-sort, and log stats."""
    stats_before = graph.dump_node_statistics()
    graph = self._optimize(graph)
    # Drop nodes made unreachable by the optimization, then restore a valid
    # topological order before collecting the "after" statistics.
    graph.delete_unused_nodes(graph.outputs)
    graph.topological_sort(graph.get_nodes())
    stats_after = graph.dump_node_statistics()
    self._print_stat_diff(stats_before, stats_after)
    return graph

0 commit comments

Comments
 (0)