# SPDX-License-Identifier: Apache-2.0


"""QDQ optimizer.
   Pushes QuantizeLinear ops up and DequantizeLinear ops down to maximize
   DQ -> op -> Q patterns for ONNX Runtime (ORT).
   Does not work for per-channel quantization yet.
"""

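# Illustrative sketch of the intended rewrite (assuming matching quantization params):
#   before:  ... -> DequantizeLinear -> Transpose -> QuantizeLinear -> ...
#   after:   ... -> Transpose -> ...   (Transpose now runs directly on the quantized tensor)
# Pushing QuantizeLinear up and DequantizeLinear down through such tensor-shuffling ops
# lets ORT keep the data quantized and fuse the remaining DQ -> op -> Q groups.
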
from .optimizer_base import GraphOptimizerBase

# pylint: disable=logging-not-lazy,unused-argument,missing-docstring


class QDQOptimizer(GraphOptimizerBase):

    def __init__(self):  # pylint: disable=useless-super-delegation
        super(QDQOptimizer, self).__init__()

    def _optimize(self, graph):
        return self._apply_optimization(graph, self._optimize_at_current_graph_level)

    def _optimize_at_current_graph_level(self, graph):
        graph_changed = True
        while graph_changed:
            graph_changed = False
            ops = graph.get_nodes()
            for op in ops:
                if op.type == "QuantizeLinear" and self._optimize_quantize(op, graph):
                    graph_changed = True
                    self.graph_been_opt = True
                elif op.type == "DequantizeLinear" and self._optimize_dequantize(op, graph):
                    graph_changed = True
                    self.graph_been_opt = True
        return graph

    def _optimize_quantize(self, quant_node, graph):
        if 'axis' in quant_node.attr:
            # Per-channel quantization (axis attribute) is not handled yet
            return False
        node = quant_node.inputs[0]
        if node.type == "DequantizeLinear":
            # Remove a redundant DQ -> Q pair: when the quantization params match,
            # consumers of the Q can read the original quantized tensor directly.
            if not self.has_same_quantization_params(quant_node, node):
                return False
            if quant_node.output[0] in graph.outputs or node.output[0] in graph.outputs:
                return False
            graph.replace_all_inputs(quant_node.output[0], node.input[0])
            if not graph.find_output_consumers(quant_node.output[0]):
                graph.remove_node(quant_node.name)
            if not graph.find_output_consumers(node.output[0]):
                graph.remove_node(node.name)
            return True

        # Push quantize nodes up
        tensor_idx = is_tensor_op(graph, node)
        if tensor_idx is None:
            return False
        inp_indices, out_indices = tensor_idx
        for i in out_indices:
            consumers = graph.find_output_consumers(node.output[i])
            if node.output[i] in graph.outputs:
                return False
            for c in consumers:
                if c.type != "QuantizeLinear":
                    return False
                if not self.has_same_quantization_params(c, quant_node):
                    return False
                if c.output[0] in graph.outputs:
                    return False
        # All outputs are quantized. Push quantization up to input.
        for i in inp_indices:
            inp_q = self.make_q_or_dq(graph, "QuantizeLinear", node.input[i], quant_node, node.name)
            graph.replace_input(node, node.input[i], inp_q.output[0], i)

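        # The node's inputs are now quantized, so its outputs already carry quantized
        # data; the downstream QuantizeLinear consumers are redundant and are bypassed
        # by rewiring their consumers to read the node's output directly.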
        for i in out_indices:
            graph.copy_dtype(quant_node.output[0], node.output[i])
            consumers = graph.find_output_consumers(node.output[i])
            for c in consumers:
                graph.replace_all_inputs(c.output[0], node.output[i])

        return True

    def _optimize_dequantize(self, dequant_node, graph):
        if 'axis' in dequant_node.attr:
            # Per-channel quantization (axis attribute) is not handled yet
            return False
        # Push dequantize nodes down
        consumers = graph.find_output_consumers(dequant_node.output[0])
        for node in consumers:
            if self._optimize_dequantize_and_node(dequant_node, node, graph):
                return True
        return False

    def _optimize_dequantize_and_node(self, dequant_node, node, graph):
        tensor_idx = is_tensor_op(graph, node)
        if tensor_idx is None:
            return False
        inp_indices, out_indices = tensor_idx
        for i in inp_indices:
            inp = node.inputs[i]
            if inp.type != "DequantizeLinear":
                return False
            if not self.has_same_quantization_params(inp, dequant_node):
                return False
            if inp.output[0] in graph.outputs:
                return False
        for i in out_indices:
            if node.output[i] in graph.outputs:
                return False
        # All inputs are dequantized. Push dequantization down to output.
        for i in inp_indices:
            # Skip the dequantize on the input
            graph.replace_input(node, node.input[i], node.inputs[i].input[0], i)

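        # The node now consumes the quantized tensors directly, so a DequantizeLinear
        # with the original quantization params is re-created on each of its outputs.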
        for i in out_indices:
            graph.copy_dtype(dequant_node.input[0], node.output[i])
            out_dq = self.make_q_or_dq(graph, "DequantizeLinear", node.output[i], dequant_node, node.name)
            graph.insert_node_on_output(out_dq, node.output[i])

        return True

    def has_same_quantization_params(self, node1, node2):
        if node1.get_attr_value("axis") != node2.get_attr_value("axis"):
            return False
        # Inputs 1 and 2 are the scale and zero-point tensors. Constant merging will
        # ensure these are the same nodes if they are equal, so comparing names suffices.
        return node1.input[1:] == node2.input[1:]

    def make_q_or_dq(self, graph, op_type, inp, reference_node, name_scope):
        """Makes a QuantizeLinear or DequantizeLinear with quantization params copied from the reference_node"""
        axis = reference_node.get_attr_value("axis")
        if axis is None:
            attr = {}
        else:
            attr = {'axis': axis}
        return graph.make_node(op_type, [inp] + reference_node.input[1:], attr=attr, op_name_scope=name_scope)


def is_tensor_op(g, node):
    """Detects ops that reshape/shuffle tensor elements without computing/changing them (Transpose, Gather, etc.)
    Returns None or a tuple (inp_indices, out_indices) such that each listed output of the node depends only on
    elements of the listed inputs, and all other inputs/outputs are passed through unchanged.
    WARNING: the transpose optimizer pushes Transpose down, so be careful when swapping ops to avoid an infinite loop."""
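    # For example, Transpose only rearranges the elements of input 0 into output 0, so it
    # reports ([0], [0]); Concat mixes all of its inputs into output 0.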
    if node.type in ["Identity", "Reshape", "Flatten", "Expand", "Transpose", "Squeeze", "Unsqueeze", "Slice"]:
        return ([0], [0])
    if node.type in ["Gather", "GatherND", "GatherElements"]:
        # Output elements are taken from the data input (0) as long as indices (input 1) are left unchanged
        return ([0], [0])
    if node.type in ["Scatter", "ScatterND", "ScatterElements"]:
        # Output elements are taken from data (input 0) and updates (input 2) as long as indices are left unchanged
        return ([0, 2], [0])
    if node.type == "Concat":
        return (list(range(len(node.input))), [0])
    if node.type == "Split":
        return ([0], list(range(len(node.output))))
    if node.type in ["Compress", "Tile", "ReverseSequence", "DepthToSpace"]:
        return ([0], [0])
    return None