Commit 68e6abb

review feedback, merge master
2 parents ccc6dfd + 11c9246 commit 68e6abb

File tree

11 files changed: +204 -35 lines changed

LICENSE

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) Microsoft Corporation
+Copyright (c) ONNX Project Contributors
 All rights reserved.

 Permission is hereby granted, free of charge, to any person obtaining a copy

tests/test_backend.py

Lines changed: 2 additions & 2 deletions

@@ -1912,14 +1912,14 @@ def test_ceil(self):
     def test_softplus(self):
         x_val = np.array([-1, 0, 1], dtype=np.float32)
         x = tf.placeholder(tf.float32, [3], name=_TFINPUT)
-        x_ = tf.math.softplus(x)
+        x_ = tf.nn.softplus(x)
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

     def test_softsign(self):
         x_val = np.array([-1, 0, 1], dtype=np.float32)
         x = tf.placeholder(tf.float32, [3], name=_TFINPUT)
-        x_ = tf.math.softsign(x)
+        x_ = tf.nn.softsign(x)
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

tests/test_optimizers.py

Lines changed: 64 additions & 1 deletion

@@ -14,7 +14,6 @@
 from backend_test_base import Tf2OnnxBackendTestBase
 from common import unittest_main, group_nodes_by_type

-
 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test

 class OptimizerTests(Tf2OnnxBackendTestBase):
@@ -423,5 +422,69 @@ def test_duplicated_need_multiple_run(self):
                              op_type="Log", remaining_op_num=3)
     # Merge Duplicated Nodes Optimizer Tests End

+    # Const Fold Optimizer Tests Start
+
+    def test_const_fold_trans_with_const1(self):
+        shape = (6, 6)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(*shape).flatten().astype(np.float32))
+        node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        node2 = helper.make_node("Transpose", ["const"], ["value1"])
+        node3 = helper.make_node("Add", ["value1", "X"], ["res"])
+
+        graph = helper.make_graph(
+            [node1, node2, node3],
+            "test_const_fold_trans_with_const1",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
+    def test_const_fold_trans_with_const2(self):
+        # needs multiple optimization runs
+        shape = (6, 6)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(*shape).flatten().astype(np.float32))
+        node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        node2 = helper.make_node("Transpose", ["const"], ["value1"])
+        node3 = helper.make_node("Transpose", ["value1"], ["value2"])
+        node4 = helper.make_node("Add", ["value2", "X"], ["res"])
+
+        graph = helper.make_graph(
+            [node1, node2, node3, node4],
+            "test_const_fold_trans_with_const2",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
+    def test_const_fold_node_is_output(self):
+        # needs multiple optimization runs
+        shape = (6, 6)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(*shape).flatten().astype(np.float32))
+        node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        node2 = helper.make_node("Transpose", ["const"], ["value1"])
+        node3 = helper.make_node("Transpose", ["value1"], ["res"])
+
+        graph = helper.make_graph(
+            [node1, node2, node3],
+            "test_const_fold_node_is_output",
+            [],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {},
+                                   model_proto, remaining_transpose_num=0)
+    # Const Fold Optimizer Tests End
+
+

 if __name__ == "__main__":
     unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 1 deletion

@@ -133,15 +133,16 @@ def version_4(cls, ctx, node, **kwargs):
         else:
             del node.attr["axis"]

-        shape = ctx.get_shape(node.input[0])
         if axis and axis.ints:
             axis = axis.ints
             neg_axis = any([val < 0 for val in axis])
             if neg_axis:
+                shape = ctx.get_shape(node.input[0])
                 utils.make_sure(shape is not None, "squeeze input shape cannot be None")
                 shape_len = len(shape)
                 axis = [a + shape_len if a < 0 else a for a in axis]
         else:
+            shape = ctx.get_shape(node.input[0])
             utils.make_sure(shape is not None, "squeeze input shape cannot be None")
             axis = [i for i, j in enumerate(shape) if j == 1]
         node.set_attr("axes", axis)
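
Note: the hunk above narrows the ctx.get_shape lookup to the two branches that actually consult the shape; the axis handling itself is plain Python. A minimal standalone sketch of the two normalizations, with made-up values (not from the commit):

# branch 1: explicit axes may be negative; shift them by the input rank
shape = [1, 3, 1, 5]   # hypothetical input shape
axis = [-2, 0]         # hypothetical squeeze axes
shape_len = len(shape)
axis = [a + shape_len if a < 0 else a for a in axis]
print(axis)            # [2, 0]

# branch 2: no axes given, so squeeze every dimension of size 1
axis = [i for i, j in enumerate(shape) if j == 1]
print(axis)            # [0, 2]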

tf2onnx/optimizer/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -10,15 +10,18 @@
 import traceback
 from collections import OrderedDict

+from tf2onnx.optimizer.const_fold_optimizer import ConstFoldOptimizer
 from tf2onnx.optimizer.identity_optimizer import IdentityOptimizer
 from tf2onnx.optimizer.merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
 from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer

+
 # pylint: disable=missing-docstring, broad-except

 # the optimizer sequence needs to be considered carefully
 _optimizers = OrderedDict([
     ("transpose_opt", TransposeOptimizer),
+    ("fold_const", ConstFoldOptimizer),
     # merge_duplicated_nodes should be used after transpose_opt,
     # because transpose_opt may leave transpose nodes that can be merged
     ("merge_duplicated_nodes", MergeDuplicatedNodesOptimizer),
tf2onnx/optimizer/const_fold_optimizer.py (new file)

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""const fold Optimizer.
+   if all of an op's inputs are const, do the computation while building the graph to improve performance;
+   for example, if the input of a Transpose node is const, we can transpose it statically instead of at runtime
+"""
+
+from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
+from tf2onnx import utils
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
+
+# key is op_type, value is the function to compute outputs;
+# the schema of the function is: inputs are (node, graph), output is a list of constant values
+_func_map = {}
+
+
+def _register_func(op_type):
+    def _internal_fun(func):
+        _func_map[op_type] = func
+        return func
+    return _internal_fun
+
+
+class ConstFoldOptimizer(GraphOptimizerBase):
+
+    def __init__(self, debug=False):
+        super(ConstFoldOptimizer, self).__init__("ConstFoldOptimizer", debug)
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, graph):
+        graph_changed = True
+        while graph_changed:
+            graph_changed = False
+            ops = graph.get_nodes()
+            for op in ops:
+                if self._should_skip(op):
+                    continue
+                if self._fold_node(op, graph):
+                    graph_changed = True
+        return graph
+
+    @staticmethod
+    def _should_skip(node):
+        # only official onnx ops are supported for now; ops in other domains are not
+        if not utils.is_onnx_domain(node.domain):
+            return True
+
+        if node.is_const() or node.is_graph_input():
+            return True
+
+        skip_type = ["Identity"]
+        if node.type in skip_type:
+            return True
+
+        return False
+
+    def _fold_node(self, node, graph):
+        """ if all of a node's inputs are const and it is not a graph output, it can be folded.
+            if the node can be folded, True is returned, indicating that the graph has changed.
+        """
+        if self._all_inputs_are_const(node.inputs) and not self._is_graph_output(node, graph):
+            process_func = _func_map.get(node.type, None)
+            if process_func:
+                const_outputs = process_func(node, graph)
+                self._replace_node_with_const(node, graph, const_outputs)
+                return True
+            self.log.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
+        return False
+
+    @staticmethod
+    def _all_inputs_are_const(nodes):
+        return all(node.is_const() for node in nodes if node)
+
+    @staticmethod
+    def _is_graph_output(node, graph):
+        node_out_set = set(node.output)
+        graph_out_set = set(graph.outputs)
+        return node_out_set.intersection(graph_out_set)
+
+    @staticmethod
+    def _replace_node_with_const(node, graph, vals):
+        utils.make_sure(len(node.output) == len(vals), "length of node outputs and const vals should be same")
+        for old_input, val in zip(node.output, vals):
+            const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
+            graph.set_dtype(const_node.output[0], utils.map_numpy_to_onnx_dtype(val.dtype))
+            graph.set_shape(const_node.output[0], val.shape)
+            graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
+        graph.remove_node(node.name)
+
+    @staticmethod
+    @_register_func("Transpose")
+    def _fold_transpose(node, graph) -> list:
+        const_val = node.inputs[0].get_tensor_value(as_list=False)
+        perm_attr = node.get_attr("perm")
+        perm = perm_attr.ints if perm_attr else None
+        const_val_after_trans = const_val.transpose(perm)
+        return [const_val_after_trans]
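
The _func_map registry makes the pass easy to extend: a fold function only has to compute the node's outputs as numpy arrays. As a hypothetical extension (not part of this commit), folding ONNX Neg could be registered at module level next to _func_map, reusing only the node APIs visible above:

# hypothetical extension: fold a Neg node whose input is const
@_register_func("Neg")
def _fold_neg(node, graph):
    const_val = node.inputs[0].get_tensor_value(as_list=False)  # numpy array of the const input
    return [-const_val]  # negate once at conversion time; one array per node output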

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 9 additions & 16 deletions

@@ -6,14 +6,9 @@
 """

 from __future__ import unicode_literals
-import logging

 from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase

-
-log = logging.getLogger("tf2onnx.optimizer.identity_optimizer")
-
-
 # pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ


@@ -22,7 +17,6 @@ class IdentityOptimizer(GraphOptimizerBase):

     def __init__(self, debug=False):
         super(IdentityOptimizer, self).__init__("IdentityOptimizer", debug)
-
         self._g = None

     def optimize(self, graph):
@@ -31,8 +25,8 @@ def optimize(self, graph):
         self._optimize_recursively(self._g)
         current_counter = self._g.dump_node_statistics()
         identity_cnt = current_counter["Identity"]
-        current_counter.subtract(previous_counter)
-        log.info(" %d identity op(s) left, ops diff after identity optimization: %s", identity_cnt, current_counter)
+        self.log.info(" %d identity op(s) left", identity_cnt)
+        self._print_stat_diff(previous_counter, current_counter)
         return self._g

     def _optimize_recursively(self, g):
@@ -42,9 +36,9 @@ def _optimize_recursively(self, g):
             body_graphs = n.get_body_graphs()
             if body_graphs:
                 for attr, b_g in body_graphs.items():
-                    log.debug("start handling subgraph of %s's attribute %s", n.name, attr)
+                    self.log.debug("start handling subgraph of %s's attribute %s", n.name, attr)
                     self._optimize_recursively(b_g)
-                    log.debug("finish handling subgraph of %s's attribute %s", n.name, attr)
+                    self.log.debug("finish handling subgraph of %s's attribute %s", n.name, attr)

     def _optimize(self, g):
         has_update = True
@@ -53,7 +47,7 @@ def _optimize(self, g):
             nodes = [n for n in g.get_nodes() if n.type == "Identity"]
             for n in nodes:
                 if n.graph is None:
-                    log.info("node has been removed from this graph, skip")
+                    self.log.info("node has been removed from this graph, skip")
                     continue

                 graph_outputs = set(n.output).intersection(g.outputs)
@@ -72,19 +66,18 @@ def _handle_non_graph_output_identity(graph, identity):
         graph.remove_node(identity.name)
         return True

-    @staticmethod
-    def _handle_graph_output_identity(graph, identity, graph_outputs):
+    def _handle_graph_output_identity(self, graph, identity, graph_outputs):
         input_id = identity.input[0]
         input_node = identity.inputs[0]

         if input_node.graph != graph:
             # If input node is in parent graph, we don't handle it now
-            log.debug("input node in parent graph, skip")
+            self.log.debug("input node in parent graph, skip")
             return False

         if input_node.is_graph_input():
             # Identity between input and output should not be removed.
-            log.debug("skip identity between input and output")
+            self.log.debug("skip identity between input and output")
             return False

         output_id = identity.output[0]
@@ -93,7 +86,7 @@ def _handle_graph_output_identity(graph, identity, graph_outputs):
         if input_id in graph.outputs:
             # input_id is already a graph output, so we cannot make it another graph output.
             # this Identity must be kept.
-            log.debug("identity input already be graph output")
+            self.log.debug("identity input already be graph output")
             return False

         graph.remove_node(identity.name)

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 0 additions & 2 deletions

@@ -8,7 +8,6 @@
 """

 from collections import defaultdict, namedtuple
-import logging

 from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase

@@ -23,7 +22,6 @@ class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
     """
     def __init__(self, debug=False):
         super(MergeDuplicatedNodesOptimizer, self).__init__("MergeDuplicatedNodesOptimizer", debug)
-        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
         # used internally
         self._graph_can_be_optimized = True

tf2onnx/optimizer/optimizer_base.py

Lines changed: 8 additions & 1 deletion

@@ -4,6 +4,7 @@
 """Graph Optimizer Base"""

 from __future__ import unicode_literals
+import logging


 class GraphOptimizerBase(object):
@@ -13,10 +14,12 @@ class GraphOptimizerBase(object):
     def __init__(self, name, debug=False):
         self._debug = debug
         self._name = name
+        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)

     def optimize(self, graph):
         original_node_statistics = graph.dump_node_statistics()
         graph = self._optimize(graph)
+        graph.delete_unused_nodes(graph.outputs)
         node_statistics = graph.dump_node_statistics()
         self._print_stat_diff(original_node_statistics, node_statistics)
         return graph
@@ -28,6 +31,10 @@ def _optimize(self, graph):
     def name(self):
         return self._name

+    @property
+    def log(self):
+        return self._log
+
     @staticmethod
     def _apply_optimization(graph, optimize_func):
         """
@@ -52,4 +59,4 @@ def _print_stat_diff(self, nodes_original, nodes_after_optimized):
         for key, value in nodes_after_optimized.items():
             if value != 0:
                 res[key] = value
-        self._log.info("after optimized, the optimization_statistics is %s", res)
+        self.log.info("the optimization gain is %s", res)
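
For context, dump_node_statistics evidently returns a collections.Counter (the old identity-optimizer code above called .subtract() on it), so the "optimization gain" log boils down to a counter diff filtered to nonzero entries. A standalone sketch with hypothetical numbers:

from collections import Counter

nodes_original = Counter({"Transpose": 7, "Identity": 5, "Add": 3})
nodes_after_optimized = Counter({"Transpose": 1, "Identity": 0, "Add": 3})

diff = nodes_after_optimized.copy()
diff.subtract(nodes_original)   # per-op delta; negative means ops were removed

res = {key: value for key, value in diff.items() if value != 0}
print(res)                      # {'Transpose': -6, 'Identity': -5}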
