|
| 1 | +# Copyright (c) Qualcomm Innovation Center, Inc. |
| 2 | +# All rights reserved |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +import types |
| 7 | +from contextlib import contextmanager |
| 8 | + |
| 9 | +import torch |
| 10 | +import torchao |
| 11 | +from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import ( |
| 12 | + PerBlockParamObserver, |
| 13 | +) |
| 14 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 15 | +from torchao.quantization.pt2e import PerChannelMinMaxObserver |
| 16 | + |
| 17 | + |
| 18 | +class SeqMseModule(torch.nn.Module): |
| 19 | + """ |
| 20 | + Args: |
| 21 | + nominal_weight: Tensor |
| 22 | + nominal parameters from operator |
| 23 | + nominal_bias: Tensor |
| 24 | + nominal parameters from operator |
| 25 | + operator: fx.Node |
| 26 | + operator to be executed |
| 27 | + observer: UniformQuantizationObserverBase |
| 28 | + parameter observer (specific for weight) |
| 29 | + num_candidates: int |
| 30 | + grids to search minimal mse loss |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + nominal_weight, |
| 36 | + nominal_bias, |
| 37 | + operator, |
| 38 | + observer, |
| 39 | + num_candidates, |
| 40 | + ): |
| 41 | + super().__init__() |
| 42 | + self.nominal_weight = nominal_weight |
| 43 | + self.nominal_bias = nominal_bias |
| 44 | + self.observer = observer |
| 45 | + self.steps = torch.linspace( |
| 46 | + 1 / num_candidates, 1, steps=num_candidates |
| 47 | + ).tolist() |
| 48 | + self.operator = self._make_operator(operator) |
| 49 | + self.best_candidate_step = 1.0 |
| 50 | + |
| 51 | + def _make_operator(self, aten_op): |
| 52 | + if aten_op.target == torch.ops.aten.conv2d.default: |
| 53 | + stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3] |
| 54 | + padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4] |
| 55 | + dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5] |
| 56 | + groups = 1 if len(aten_op.args) < 7 else aten_op.args[6] |
| 57 | + has_bias = self.nominal_bias is not None |
| 58 | + module = torch.nn.Conv2d( |
| 59 | + in_channels=self.nominal_weight.shape[1], |
| 60 | + out_channels=self.nominal_weight.shape[0], |
| 61 | + kernel_size=self.nominal_weight.shape[-2:], |
| 62 | + stride=stride, |
| 63 | + padding=padding, |
| 64 | + dilation=dilation, |
| 65 | + groups=groups, |
| 66 | + bias=has_bias, |
| 67 | + ) |
| 68 | + module.weight.data = self.nominal_weight |
| 69 | + if has_bias: |
| 70 | + module.bias.data = self.nominal_bias |
| 71 | + return module |
| 72 | + else: |
| 73 | + raise NotImplementedError(f"target of {aten_op.target} is not implemented") |
| 74 | + |
| 75 | + def _per_block_qdq(self, scale, zero_point): |
| 76 | + return torchao.quantization.quant_primitives._fake_quantize_affine( |
| 77 | + input=self.nominal_weight, |
| 78 | + block_size=self.observer.block_size, |
| 79 | + scale=scale, |
| 80 | + zero_point=zero_point, |
| 81 | + quant_dtype=self.observer.dtype, |
| 82 | + quant_min=self.observer.quant_min, |
| 83 | + quant_max=self.observer.quant_max, |
| 84 | + ) |
| 85 | + |
| 86 | + def _per_channel_qdq(self, scale, zero_point): |
| 87 | + return torch.fake_quantize_per_channel_affine( |
| 88 | + input=self.nominal_weight, |
| 89 | + scale=scale, |
| 90 | + zero_point=zero_point, |
| 91 | + axis=0, |
| 92 | + quant_min=self.observer.quant_min, |
| 93 | + quant_max=self.observer.quant_max, |
| 94 | + ) |
| 95 | + |
| 96 | + def _fake_quant(self, scale, zero_point): |
| 97 | + dispatcher = { |
| 98 | + PerChannelMinMaxObserver: self._per_channel_qdq, |
| 99 | + PerBlockParamObserver: self._per_block_qdq, |
| 100 | + } |
| 101 | + return dispatcher[type(self.observer)](scale, zero_point) |
| 102 | + |
| 103 | + def _find_best_candidate(self, nominal_input, nominal_output): |
| 104 | + # calculate current baseline |
| 105 | + scale, zero_point = self.observer.calculate_qparams() |
| 106 | + zero_point = zero_point.to(torch.int32) |
| 107 | + self.operator.weight.data = self._fake_quant(scale, zero_point) |
| 108 | + candidate, current_loss = ( |
| 109 | + 1, |
| 110 | + torch.nn.functional.mse_loss( |
| 111 | + self.operator(nominal_input), nominal_output |
| 112 | + ).item(), |
| 113 | + ) |
| 114 | + for step in self.steps: |
| 115 | + self.operator.weight.data = self._fake_quant(scale * step, zero_point) |
| 116 | + loss = torch.nn.functional.mse_loss( |
| 117 | + self.operator(nominal_input), nominal_output |
| 118 | + ).item() |
| 119 | + if loss < current_loss: |
| 120 | + candidate, current_loss = step, loss |
| 121 | + return candidate |
| 122 | + |
| 123 | + def forward(self, nominal_input, nominal_output): |
| 124 | + self.best_candidate_step = self._find_best_candidate( |
| 125 | + nominal_input=nominal_input, nominal_output=nominal_output |
| 126 | + ) |
| 127 | + |
| 128 | + |
| 129 | +class InsertSeqMse(ExportPass): |
| 130 | + """ |
| 131 | + Insert Seq Mse Observer to find the best quant config for certain node's weight. |
| 132 | + """ |
| 133 | + |
| 134 | + seq_mse_ops = {torch.ops.aten.conv2d.default} |
| 135 | + |
| 136 | + def __init__(self, num_candidates=1000): |
| 137 | + super(InsertSeqMse, self).__init__() |
| 138 | + self.num_candidates = num_candidates |
| 139 | + |
| 140 | + def _insert_seq_mse( |
| 141 | + self, graph_module: torch.fx.GraphModule |
| 142 | + ) -> torch.fx.GraphModule: |
| 143 | + count = 0 |
| 144 | + for node in graph_module.graph.nodes: |
| 145 | + if node.target in self.seq_mse_ops: |
| 146 | + # extract observer |
| 147 | + weight_node_obs = node.args[1] |
| 148 | + observer = getattr(graph_module, weight_node_obs.name) |
| 149 | + # extract parameters |
| 150 | + weight_node = weight_node_obs.args[0] |
| 151 | + weight_tensor = graph_module.get_parameter(weight_node.target).detach() |
| 152 | + bias_tensor = None |
| 153 | + if len(node.args) > 2 and node.args[2] is not None: |
| 154 | + bias_tensor = graph_module.get_parameter( |
| 155 | + node.args[2].args[0].target |
| 156 | + ).detach() |
| 157 | + |
| 158 | + with graph_module.graph.inserting_after(node): |
| 159 | + seq_mse_mod = SeqMseModule( |
| 160 | + nominal_weight=weight_tensor, |
| 161 | + nominal_bias=bias_tensor, |
| 162 | + operator=node, |
| 163 | + observer=observer, |
| 164 | + num_candidates=self.num_candidates, |
| 165 | + ) |
| 166 | + module_name = f"seq_mse_{count}" |
| 167 | + count += 1 |
| 168 | + setattr(graph_module, module_name, seq_mse_mod) |
| 169 | + input_nodes = (node.args[0], node) |
| 170 | + graph_module.graph.create_node( |
| 171 | + "call_module", module_name, input_nodes, {} |
| 172 | + ) |
| 173 | + |
| 174 | + def call(self, graph_module: torch.fx.GraphModule): |
| 175 | + self._insert_seq_mse(graph_module) |
| 176 | + graph_module.recompile() |
| 177 | + return PassResult(graph_module, True) |
| 178 | + |
| 179 | + |
| 180 | +class RemoveSeqMse(ExportPass): |
| 181 | + """ |
| 182 | + Remove Seq Mse before invoking convert_pt2e and update final quantization encoding. |
| 183 | + """ |
| 184 | + |
| 185 | + def __init__(self): |
| 186 | + super(RemoveSeqMse, self).__init__() |
| 187 | + |
| 188 | + def _remove_seq_mse( |
| 189 | + self, graph_module: torch.fx.GraphModule |
| 190 | + ) -> torch.fx.GraphModule: |
| 191 | + node_to_erase = [] |
| 192 | + for node in graph_module.graph.nodes: |
| 193 | + if node.op == "call_module": |
| 194 | + # try extracting SeqMse module |
| 195 | + module = getattr(graph_module, node.target) |
| 196 | + if isinstance(module, SeqMseModule): |
| 197 | + # rewrite observer method for pre-calculated scale |
| 198 | + scale, zero_point = module.observer.calculate_qparams() |
| 199 | + module.observer.updated_encoding = ( |
| 200 | + scale * module.best_candidate_step, |
| 201 | + zero_point, |
| 202 | + ) |
| 203 | + module.observer.calculate_qparams = types.MethodType( |
| 204 | + lambda s: s.updated_encoding, module.observer |
| 205 | + ) |
| 206 | + node_to_erase.append(node) |
| 207 | + |
| 208 | + for node in node_to_erase: |
| 209 | + graph_module.graph.erase_node(node) |
| 210 | + |
| 211 | + def call(self, graph_module: torch.fx.GraphModule): |
| 212 | + self._remove_seq_mse(graph_module) |
| 213 | + graph_module.recompile() |
| 214 | + return PassResult(graph_module, True) |
| 215 | + |
| 216 | + |
| 217 | +@contextmanager |
| 218 | +def SeqMSE(prepared_gm, num_candidates): |
| 219 | + prepared_gm = InsertSeqMse(num_candidates)(prepared_gm).graph_module |
| 220 | + try: |
| 221 | + yield |
| 222 | + finally: |
| 223 | + prepared_gm = RemoveSeqMse()(prepared_gm).graph_module |
0 commit comments