|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | + |
| 4 | +"""global pool optimizer |
| 5 | + Replaces ReduceMean and ReduceMax patterns with GlobalAveragePool and GlobalMaxPool |
| 6 | +""" |
| 7 | + |
| 8 | +from onnx import TensorProto |
| 9 | +from tf2onnx.graph_builder import GraphBuilder |
| 10 | +from .optimizer_base import GraphOptimizerBase |
| 11 | + |
| 12 | +# pylint: disable=logging-not-lazy,unused-argument,missing-docstring |
| 13 | + |
| 14 | + |
| 15 | +class GlobalPoolOptimizer(GraphOptimizerBase): |
| 16 | + |
| 17 | + def __init__(self): # pylint: disable=useless-super-delegation |
| 18 | + super(GlobalPoolOptimizer, self).__init__() |
| 19 | + |
| 20 | + def _optimize(self, graph): |
| 21 | + return self._apply_optimization(graph, self._optimize_at_current_graph_level) |
| 22 | + |
| 23 | + def _optimize_at_current_graph_level(self, graph): |
| 24 | + graph_changed = True |
| 25 | + while graph_changed: |
| 26 | + graph_changed = False |
| 27 | + ops = graph.get_nodes() |
| 28 | + for op in ops: |
| 29 | + if op.type in ["ReduceMean", "ReduceMax"] and self._optimize_reduce(op, graph): |
| 30 | + graph_changed = True |
| 31 | + self.graph_been_opt = True |
| 32 | + return graph |
| 33 | + |
| 34 | + def _optimize_reduce(self, node, graph): |
| 35 | + if graph.get_dtype(node.output[0]) not in [TensorProto.FLOAT, TensorProto.DOUBLE]: |
| 36 | + return False |
| 37 | + if node.output[0] in graph.outputs: |
| 38 | + # Replacement is unsafe |
| 39 | + return False |
| 40 | + axes = node.get_attr_value('axes') |
| 41 | + inp_rank = graph.get_rank(node.input[0]) |
| 42 | + if inp_rank is None: |
| 43 | + return False |
| 44 | + if axes != list(range(2, inp_rank)): |
| 45 | + return False |
| 46 | + op_map = {"ReduceMean": "GlobalAveragePool", "ReduceMax": "GlobalMaxPool"} |
| 47 | + node.type = op_map[node.type] |
| 48 | + del node.attr['axes'] |
| 49 | + if not node.get_attr_value('keepdims', True): |
| 50 | + out_shapes = node.output_shapes |
| 51 | + out_dtypes = node.output_dtypes |
| 52 | + new_out_shape = graph.get_shape(node.input[0])[:2] + [1] * len(axes) |
| 53 | + graph.set_shape(node.output[0], new_out_shape) |
| 54 | + squeeze_node = GraphBuilder(graph).make_squeeze( |
| 55 | + {'data': node.output[0], 'axes': axes}, shapes=out_shapes, dtypes=out_dtypes, |
| 56 | + return_node=True, op_name_scope=node.name) |
| 57 | + graph.insert_node_on_output(squeeze_node, node.output[0]) |
| 58 | + if 'keepdims' in node.attr: |
| 59 | + del node.attr['keepdims'] |
| 60 | + return True |
0 commit comments