Skip to content

Commit 7922978

Browse files
Add ReduceMean/ReduceMax -> GlobalAveragePool/GlobalMaxPool optimizer (#1465)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent fac2c91 commit 7922978

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,14 @@ def func(x):
17901790
return tf.identity(x_, name=_TFOUTPUT)
17911791
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
17921792

1793+
def test_reducemax_global_max_pool(self):
    """A max-reduction over every trailing spatial axis should convert cleanly
    (the optimizer maps it to GlobalMaxPool); check both keepdims settings."""
    for keepdims in (True, False):
        x_val = make_xval((2, 3, 4, 5, 6))

        def func(x):
            reduced = tf.reduce_max(x, axis=[2, 3, 4], keepdims=keepdims)
            return tf.add(reduced, 0, name=_TFOUTPUT)

        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1800+
17931801
@skip_caffe2_backend()
17941802
def test_reduceprod(self):
17951803
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
@@ -1805,6 +1813,14 @@ def func(x):
18051813
return tf.identity(x_, name=_TFOUTPUT)
18061814
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
18071815

1816+
def test_reducemean_global_avg_pool(self):
    """A mean-reduction over every trailing spatial axis should convert cleanly
    (the optimizer maps it to GlobalAveragePool); check both keepdims settings."""
    for keepdims in (True, False):
        x_val = make_xval((2, 3, 4, 5))

        def func(x):
            reduced = tf.reduce_mean(x, axis=[2, 3], keepdims=keepdims)
            return tf.add(reduced, 0, name=_TFOUTPUT)

        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1823+
18081824
@skip_caffe2_backend()
18091825
@check_onnxruntime_incompatibility("Pow")
18101826
def test_pow_scalar(self):

tf2onnx/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .upsample_optimizer import UpsampleOptimizer
1919
from .const_dequantize_optimizer import ConstDequantizeOptimizer
2020
from .reshape_optimizer import ReshapeOptimizer
21+
from .global_pool_optimizer import GlobalPoolOptimizer
2122
from .. import logging
2223

2324
# optimizer sequence need to be considered carefully
@@ -33,6 +34,7 @@
3334
("reshape_optimizer", ReshapeOptimizer),
3435
("remove_identity", IdentityOptimizer),
3536
("remove_back_to_back", BackToBackOptimizer),
37+
("global_pool_optimizer", GlobalPoolOptimizer),
3638
])
3739

3840

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""global pool optimizer
5+
Replaces ReduceMean and ReduceMax patterns with GlobalAveragePool and GlobalMaxPool
6+
"""
7+
8+
from onnx import TensorProto
9+
from tf2onnx.graph_builder import GraphBuilder
10+
from .optimizer_base import GraphOptimizerBase
11+
12+
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
13+
14+
15+
class GlobalPoolOptimizer(GraphOptimizerBase):
    """Rewrites ReduceMean/ReduceMax nodes that reduce over all spatial axes
    (axes 2..rank-1 of an NC... input) into GlobalAveragePool/GlobalMaxPool.

    When the reduce node has keepdims=0, a Squeeze is appended after the pool
    node, since Global*Pool always keeps the reduced dimensions as size 1.
    """

    # Reduce op -> equivalent global pooling op.
    _OP_MAP = {"ReduceMean": "GlobalAveragePool", "ReduceMax": "GlobalMaxPool"}

    def __init__(self):  # pylint: disable=useless-super-delegation
        super(GlobalPoolOptimizer, self).__init__()

    def _optimize(self, graph):
        return self._apply_optimization(graph, self._optimize_at_current_graph_level)

    def _optimize_at_current_graph_level(self, graph):
        # Keep sweeping until a full pass over the graph makes no change.
        graph_changed = True
        while graph_changed:
            graph_changed = False
            for op in graph.get_nodes():
                if op.type in self._OP_MAP and self._optimize_reduce(op, graph):
                    graph_changed = True
                    self.graph_been_opt = True
        return graph

    def _optimize_reduce(self, node, graph):
        """Try to replace a single ReduceMean/ReduceMax node with a global pool.

        Returns True if the node was rewritten, False if the pattern does not
        apply (wrong dtype, graph output, unknown rank, or axes that are not
        exactly the spatial axes).
        """
        # Global pool ops are only defined for float tensors.
        if graph.get_dtype(node.output[0]) not in [TensorProto.FLOAT, TensorProto.DOUBLE]:
            return False
        if node.output[0] in graph.outputs:
            # Replacement is unsafe
            return False
        axes = node.get_attr_value('axes')
        inp_rank = graph.get_rank(node.input[0])
        # Need an explicit axes list and a known rank; rank >= 3 is required so
        # there is at least one spatial axis for the pool to reduce over.
        if axes is None or inp_rank is None or inp_rank < 3:
            return False
        # Normalize negative axes (allowed by the ONNX spec) and sort: the
        # reduction result is independent of the order axes are listed in.
        axes = sorted(a + inp_rank if a < 0 else a for a in axes)
        if axes != list(range(2, inp_rank)):
            return False
        node.type = self._OP_MAP[node.type]
        del node.attr['axes']
        if not node.get_attr_value('keepdims', True):
            # Global*Pool keeps the reduced dims as size 1; emulate keepdims=0
            # by squeezing them out after the pool node.
            out_shapes = node.output_shapes
            out_dtypes = node.output_dtypes
            new_out_shape = graph.get_shape(node.input[0])[:2] + [1] * len(axes)
            graph.set_shape(node.output[0], new_out_shape)
            squeeze_node = GraphBuilder(graph).make_squeeze(
                {'data': node.output[0], 'axes': axes}, shapes=out_shapes, dtypes=out_dtypes,
                return_node=True, op_name_scope=node.name)
            graph.insert_node_on_output(squeeze_node, node.output[0])
        if 'keepdims' in node.attr:
            del node.attr['keepdims']
        return True

0 commit comments

Comments
 (0)