|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +"""Loop Optimizer. |
| 5 | + some op in loop's body graph can be moved out to the loop |
| 6 | +""" |
| 7 | + |
| 8 | +from tf2onnx.utils import make_name, make_sure |
| 9 | +from .optimizer_base import GraphOptimizerBase |
| 10 | + |
| 11 | + |
| 12 | +# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ |
| 13 | + |
| 14 | + |
| 15 | +class LoopOptimizer(GraphOptimizerBase): |
| 16 | + """Loop Optimizer.""" |
| 17 | + |
| 18 | + # a lot of terms used here come from loop's onnx spec |
| 19 | + # https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop |
| 20 | + def __init__(self): # pylint: disable=useless-super-delegation |
| 21 | + super(LoopOptimizer, self).__init__() |
| 22 | + |
| 23 | + def _optimize(self, graph): |
| 24 | + return self._apply_optimization(graph, self._optimize_at_current_graph_level) |
| 25 | + |
| 26 | + def _optimize_at_current_graph_level(self, g): |
| 27 | + has_update = True |
| 28 | + while has_update: |
| 29 | + has_update = False |
| 30 | + nodes = [n for n in g.get_nodes() if n.type == "Loop"] |
| 31 | + for n in nodes: |
| 32 | + has_update_tmp = self._try_move_transpose_out_of_body_graph(n) |
| 33 | + if has_update_tmp: |
| 34 | + has_update = True |
| 35 | + return g |
| 36 | + |
| 37 | + @staticmethod |
| 38 | + def consumer_nodes_num(graph, node): |
| 39 | + make_sure(len(node.output) == 1, "only consider node with only one output") |
| 40 | + res = len(graph.find_output_consumers(node.output[0])) |
| 41 | + return res |
| 42 | + |
| 43 | + def _try_move_transpose_out_of_body_graph(self, loop_node): |
| 44 | + # output node of body graph can be loop-carried-dependent, if so it can't be move out of the body graph |
| 45 | + # return True if moving some nodes successfully |
| 46 | + # for now, we only consider moving transpose |
| 47 | + body_graph = loop_node.get_body_graphs()["body"] |
| 48 | + parent_graph = loop_node.graph |
| 49 | + scan_nodes_name_in_body, scan_node_in_parent = self._scan_outputs(loop_node) |
| 50 | + scan_nodes = [body_graph.get_node_by_output(name) for name in scan_nodes_name_in_body] |
| 51 | + graph_is_changed = False |
| 52 | + for node, name_in_parent in zip(scan_nodes, scan_node_in_parent): |
| 53 | + # 1 delete node in body graph if possible |
| 54 | + # only consider two case: trans is output, or transpose > identity > output |
| 55 | + need_process = False |
| 56 | + if node.type == "Transpose" and self.consumer_nodes_num(body_graph, node) <= 1: |
| 57 | + trans = node |
| 58 | + new_output = node.input[0] |
| 59 | + body_graph.remove_node(node.name) |
| 60 | + need_process = True |
| 61 | + elif node.type == "Identity" and node.inputs[0].type == "Transpose" \ |
| 62 | + and self.consumer_nodes_num(body_graph, node) <= 1\ |
| 63 | + and self.consumer_nodes_num(body_graph, node.inputs[0]) <= 1: |
| 64 | + trans = node.inputs[0] |
| 65 | + new_output = node.inputs[0].input[0] |
| 66 | + body_graph.remove_node(node.inputs[0].name) |
| 67 | + body_graph.remove_node(node.name) |
| 68 | + need_process = True |
| 69 | + |
| 70 | + if need_process: |
| 71 | + # 2 correct body graph's output |
| 72 | + body_outputs = body_graph.outputs |
| 73 | + body_outputs[body_outputs.index(node.output[0])] = new_output |
| 74 | + # 3 insert new node in parent graph |
| 75 | + ori_perm = list(trans.get_attr("perm").ints) |
| 76 | + new_perm = [0] + [i + 1 for i in ori_perm] # body output's rank is m > rank of loop's output is m+1 |
| 77 | + name = make_name("trans_moved_from_loop_body") |
| 78 | + _ = parent_graph.insert_new_node_on_output("Transpose", name_in_parent, name, perm=new_perm) |
| 79 | + graph_is_changed = True |
| 80 | + |
| 81 | + return graph_is_changed |
| 82 | + |
| 83 | + @classmethod |
| 84 | + def _scan_outputs(cls, loop): |
| 85 | + # loop has 2+N inputs; loop has N+K outputs; |
| 86 | + # loop's body graph has 1+N+K outputs |
| 87 | + loop_carried = len(loop.input) - 2 |
| 88 | + body_graph = loop.get_body_graphs()["body"] |
| 89 | + return body_graph.outputs[loop_carried + 1:], loop.output[loop_carried:] |
0 commit comments