|
| 1 | +# Copyright (c) Qualcomm Innovation Center, Inc. |
| 2 | +# All rights reserved |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +from executorch.backends.qualcomm._passes.utils import append_qdq, copy_meta |
| 9 | +from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter |
| 10 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 11 | +from torch.fx import GraphModule |
| 12 | +from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix |
| 13 | + |
| 14 | + |
| 15 | +def _pad_list_to_4(lst): |
| 16 | + return lst + [1] * (4 - len(lst)) if len(lst) < 4 else lst[:4] |
| 17 | + |
| 18 | + |
| 19 | +class ConvertLinearToConv2d(ExportPass): |
| 20 | + """ |
| 21 | + Replace aten.linear.default with equivalent 1x1 conv2d using call_function nodes. |
| 22 | + """ |
| 23 | + |
| 24 | + def __init__(self, edge_program: torch.export.ExportedProgram): |
| 25 | + super().__init__() |
| 26 | + self.edge_program = edge_program |
| 27 | + self.per_block_dq = torch.ops.torchao.dequantize_affine.default |
| 28 | + |
| 29 | + def _register_tensor( |
| 30 | + self, |
| 31 | + gm: torch.fx.GraphModule, |
| 32 | + node: torch.fx.Node, |
| 33 | + tensor_constant: torch.Tensor, |
| 34 | + ) -> torch.fx.Node: |
| 35 | + new_node_name = get_new_attr_name_with_prefix(node.name)(gm) |
| 36 | + gm.register_buffer(new_node_name, tensor_constant) |
| 37 | + |
| 38 | + with gm.graph.inserting_before(node): |
| 39 | + get_attr_node = gm.graph.get_attr(new_node_name) |
| 40 | + get_attr_node.meta["val"] = tensor_constant |
| 41 | + return get_attr_node |
| 42 | + |
| 43 | + def _append_dq( |
| 44 | + self, |
| 45 | + graph_module: torch.fx.GraphModule, |
| 46 | + node: torch.fx.Node, |
| 47 | + qdq_node: torch.fx.Node, |
| 48 | + ): |
| 49 | + q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default |
| 50 | + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default |
| 51 | + |
| 52 | + if qdq_node.target not in {q_op, dq_op}: |
| 53 | + return node |
| 54 | + |
| 55 | + with graph_module.graph.inserting_after(node): |
| 56 | + dq_args = (node, *qdq_node.args[1:]) |
| 57 | + dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args) |
| 58 | + dq_node.meta = copy_meta(node.meta) |
| 59 | + return dq_node |
| 60 | + |
| 61 | + def _create_node( |
| 62 | + self, graph_module, target, args, meta_node, new_meta_val, qdq_node |
| 63 | + ): |
| 64 | + new_node = graph_module.graph.call_function(target, args) |
| 65 | + new_node.meta = copy_meta( |
| 66 | + meta_node.meta, |
| 67 | + lambda m, new_meta_val=new_meta_val: { |
| 68 | + **m, |
| 69 | + "val": new_meta_val, |
| 70 | + }, |
| 71 | + ) |
| 72 | + dq_node = append_qdq( |
| 73 | + graph_module=graph_module, |
| 74 | + node=new_node, |
| 75 | + qdq_node=qdq_node, |
| 76 | + ) |
| 77 | + return dq_node |
| 78 | + |
| 79 | + def _reshape_weight(self, graph_module, weight_node, dq_node): |
| 80 | + # After export, constant node will be placeholder from edge_program |
| 81 | + weight_val = get_parameter(weight_node, self.edge_program) |
| 82 | + assert weight_val is not None, "Cannot get the weight in linear node." |
| 83 | + |
| 84 | + weight_val = weight_val.reshape(*weight_val.shape, 1, 1) |
| 85 | + # Create the new weight node when several node share the same weight |
| 86 | + # such as embedding and lm_head in LLM. |
| 87 | + if len(list(weight_node.users)) > 1: |
| 88 | + weight_node = self._register_tensor(graph_module, weight_node, weight_val) |
| 89 | + dq_node = self._append_dq(graph_module, weight_node, dq_node) |
| 90 | + else: |
| 91 | + set_parameter( |
| 92 | + ( |
| 93 | + torch.nn.Parameter(weight_val) |
| 94 | + if weight_val.dtype == torch.float |
| 95 | + else weight_val |
| 96 | + ), |
| 97 | + weight_node, |
| 98 | + self.edge_program, |
| 99 | + ) |
| 100 | + |
| 101 | + # Update node meta val |
| 102 | + weight_node.meta["val"] = weight_node.meta["val"].reshape(weight_val.shape) |
| 103 | + dq_node.meta["val"] = dq_node.meta["val"].reshape(weight_val.shape) |
| 104 | + # Update block size for per-block quant |
| 105 | + if dq_node.target == self.per_block_dq: |
| 106 | + new_args = list(dq_node.args) |
| 107 | + # pad block size |
| 108 | + new_args[1] = _pad_list_to_4(list(new_args[1])) |
| 109 | + dq_node.args = tuple(new_args) |
| 110 | + |
| 111 | + return dq_node |
| 112 | + |
| 113 | + def call(self, graph_module: GraphModule): |
| 114 | + graph = graph_module.graph |
| 115 | + |
| 116 | + for node in list(graph.nodes): |
| 117 | + if node.target == torch.ops.aten.linear.default: |
| 118 | + input_node = node.args[0] |
| 119 | + # In quantization flow, weight_arg will be dq node. |
| 120 | + weight_arg = node.args[1] |
| 121 | + weight_node = ( |
| 122 | + weight_arg if weight_arg.op == "placeholder" else weight_arg.args[0] |
| 123 | + ) |
| 124 | + bias_arg = node.args[2] if len(node.args) > 2 else None |
| 125 | + |
| 126 | + input_meta_val = input_node.meta["val"] |
| 127 | + output_meta_val = node.meta["val"] |
| 128 | + if bias_arg: |
| 129 | + bias_meta_val = bias_arg.meta["val"] |
| 130 | + |
| 131 | + rank = input_meta_val.ndim |
| 132 | + with graph.inserting_before(node): |
| 133 | + # Step 1: reshape input |
| 134 | + # rank = 2: (dim, C) -> (1, C, 1, dim) |
| 135 | + # rank = 3: (N, dim, C) -> (N, C, 1, dim) |
| 136 | + # rank = 4: (N, H, W, C) -> (N, C, H, W) |
| 137 | + order = (0, 3, 1, 2) |
| 138 | + if rank <= 3: |
| 139 | + # (dim, C) -> (1, C, 1, dim) |
| 140 | + # (N, dim, C) -> (N, C, 1, dim) |
| 141 | + shape = ( |
| 142 | + (1, *input_meta_val.shape, 1) |
| 143 | + if rank == 2 |
| 144 | + else (*input_meta_val.shape, 1) |
| 145 | + ) |
| 146 | + x_meta_val = input_meta_val.reshape(shape) |
| 147 | + input_node = self._create_node( |
| 148 | + graph_module, |
| 149 | + torch.ops.aten.reshape.default, |
| 150 | + (input_node, shape), |
| 151 | + node, |
| 152 | + x_meta_val, |
| 153 | + input_node, |
| 154 | + ) |
| 155 | + order = (0, 2, 3, 1) |
| 156 | + |
| 157 | + x_meta_val = x_meta_val.permute(order) |
| 158 | + x = self._create_node( |
| 159 | + graph_module, |
| 160 | + torch.ops.aten.permute.default, |
| 161 | + (input_node, order), |
| 162 | + node, |
| 163 | + x_meta_val, |
| 164 | + input_node, |
| 165 | + ) |
| 166 | + |
| 167 | + # Step 2: reshape weight |
| 168 | + weight_arg = self._reshape_weight( |
| 169 | + graph_module, weight_node, weight_arg |
| 170 | + ) |
| 171 | + weight_meta_val = weight_arg.meta["val"] |
| 172 | + |
| 173 | + conv_args = [x, weight_arg] |
| 174 | + conv_args_meta_val = [x_meta_val, weight_meta_val] |
| 175 | + if bias_arg: |
| 176 | + conv_args.append(bias_arg) |
| 177 | + conv_args_meta_val.append(bias_meta_val) |
| 178 | + else: |
| 179 | + conv_args.append(None) |
| 180 | + conv_args_meta_val.append(None) |
| 181 | + |
| 182 | + conv_args.extend( |
| 183 | + [[1, 1], [0, 0], [1, 1], 1] |
| 184 | + ) # stride, padding, dilation, groups |
| 185 | + conv_node_val = torch.nn.functional.conv2d( |
| 186 | + *conv_args_meta_val, |
| 187 | + stride=(1, 1), |
| 188 | + padding=(0, 0), |
| 189 | + dilation=(1, 1), |
| 190 | + groups=1, |
| 191 | + ) |
| 192 | + conv_node = self._create_node( |
| 193 | + graph_module, |
| 194 | + torch.ops.aten.conv2d.default, |
| 195 | + tuple(conv_args), |
| 196 | + node, |
| 197 | + conv_node_val, |
| 198 | + list(node.users)[0], |
| 199 | + ) |
| 200 | + |
| 201 | + # Step 3: restore shape |
| 202 | + # rank = 2: (1, C, 1, dim) -> (dim, C) |
| 203 | + # rank = 3: (N, C, 1, dim) -> (N, dim C) |
| 204 | + # rank = 4: (N, C, H, W) -> (N, H, W, C) |
| 205 | + order = (0, 2, 3, 1) if rank == 4 else (0, 3, 1, 2) |
| 206 | + y_meta_val = conv_node_val.permute(order) |
| 207 | + y = self._create_node( |
| 208 | + graph_module, |
| 209 | + torch.ops.aten.permute.default, |
| 210 | + (conv_node, order), |
| 211 | + node, |
| 212 | + y_meta_val, |
| 213 | + list(node.users)[0], |
| 214 | + ) |
| 215 | + if rank <= 3: |
| 216 | + target_shape = output_meta_val.shape |
| 217 | + y_meta_val = y_meta_val.reshape(target_shape) |
| 218 | + y = self._create_node( |
| 219 | + graph_module, |
| 220 | + torch.ops.aten.reshape.default, |
| 221 | + (y, target_shape), |
| 222 | + node, |
| 223 | + y_meta_val, |
| 224 | + list(node.users)[0], |
| 225 | + ) |
| 226 | + |
| 227 | + node.replace_all_uses_with(y) |
| 228 | + graph.erase_node(node) |
| 229 | + |
| 230 | + graph.eliminate_dead_code() |
| 231 | + graph_module.recompile() |
| 232 | + return PassResult(graph_module, True) |
0 commit comments