|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from functools import lru_cache |
| 8 | +from typing import Callable, List, Optional |
| 9 | + |
| 10 | +import executorch.backends.vulkan.utils as utils |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn.functional as F |
| 14 | + |
| 15 | +from executorch.backends.transforms.utils import get_param_tensor, is_param_node |
| 16 | + |
| 17 | +from executorch.backends.vulkan.patterns.pattern_registry import ( |
| 18 | + register_pattern_graph, |
| 19 | + register_pattern_replacement, |
| 20 | +) |
| 21 | + |
| 22 | +from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge |
| 23 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 24 | + |
| 25 | +from torch.export import export |
| 26 | +from torch.fx.passes.utils.matcher_utils import InternalMatch |
| 27 | + |
| 28 | +from torchao.quantization.granularity import PerGroup |
| 29 | +from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ |
| 30 | +from torchao.utils import unwrap_tensor_subclass |
| 31 | + |
| 32 | + |
| 33 | +class TorchAOWeightOnlyQuantizedLinearPattern(torch.nn.Module): |
| 34 | + """ |
| 35 | + Quantized linear pattern produced when quantizing linear layers using |
| 36 | + `torchao.quantization.quant_api.quantize_()` with IntxWeightOnlyConfig. |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + in_features: int = 512, |
| 42 | + out_features: int = 256, |
| 43 | + bias: bool = False, |
| 44 | + group_size: int = 64, |
| 45 | + weight_bits: int = 4, |
| 46 | + granularity_class: Optional[Callable] = None, |
| 47 | + ) -> None: |
| 48 | + super().__init__() |
| 49 | + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) |
| 50 | + self.group_size = group_size |
| 51 | + self.weight_bits = weight_bits |
| 52 | + |
| 53 | + if self.weight_bits == 4: |
| 54 | + # pyre-ignore[16] |
| 55 | + self.weight_dtype = torch.int4 |
| 56 | + else: |
| 57 | + self.weight_dtype = torch.int8 |
| 58 | + |
| 59 | + if granularity_class is not None: |
| 60 | + self.quant_granularity = granularity_class(self.group_size) |
| 61 | + else: |
| 62 | + self.quant_granularity = PerGroup(self.group_size) |
| 63 | + |
| 64 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 65 | + return self.linear(x) |
| 66 | + |
| 67 | + def apply_quantization(self): |
| 68 | + q_config = IntxWeightOnlyConfig( |
| 69 | + weight_dtype=self.weight_dtype, |
| 70 | + granularity=self.quant_granularity, |
| 71 | + ) |
| 72 | + quantize_(self, q_config) |
| 73 | + unwrap_tensor_subclass(self) |
| 74 | + return self |
| 75 | + |
| 76 | + |
| 77 | +@lru_cache(maxsize=None) |
| 78 | +@register_pattern_graph("torchao_wo_quantized_linear") |
| 79 | +def get_torchao_wo_quantized_linear_graphs() -> List[torch.fx.GraphModule]: |
| 80 | + graphs = [] |
| 81 | + |
| 82 | + # Different configurations to test |
| 83 | + configs = [ |
| 84 | + # gemv pattern |
| 85 | + (1, 1, 128, 128, False, 64, 4, PerGroup), |
| 86 | + # gemm pattern |
| 87 | + (1, 8, 128, 128, False, 64, 4, PerGroup), |
| 88 | + ] |
| 89 | + |
| 90 | + for ( |
| 91 | + batch_size, |
| 92 | + seq_len, |
| 93 | + in_features, |
| 94 | + out_features, |
| 95 | + bias, |
| 96 | + group_size, |
| 97 | + weight_bits, |
| 98 | + granularity_class, |
| 99 | + ) in configs: |
| 100 | + for dtype in [torch.float32]: |
| 101 | + xs = [] |
| 102 | + xs.append(torch.randn(batch_size, seq_len, in_features, dtype=dtype)) |
| 103 | + if batch_size == 1: |
| 104 | + xs.append(torch.randn(seq_len, in_features, dtype=dtype)) |
| 105 | + |
| 106 | + for x in xs: |
| 107 | + # Create and quantize the pattern |
| 108 | + pattern = TorchAOWeightOnlyQuantizedLinearPattern( |
| 109 | + in_features=in_features, |
| 110 | + out_features=out_features, |
| 111 | + bias=bias, |
| 112 | + group_size=group_size, |
| 113 | + weight_bits=weight_bits, |
| 114 | + granularity_class=granularity_class, |
| 115 | + ) |
| 116 | + |
| 117 | + # Apply quantization |
| 118 | + pattern = pattern.apply_quantization() |
| 119 | + |
| 120 | + # Export the quantized pattern |
| 121 | + edge = to_edge( |
| 122 | + export( |
| 123 | + pattern, |
| 124 | + (x,), |
| 125 | + ), |
| 126 | + compile_config=EdgeCompileConfig(_check_ir_validity=False), |
| 127 | + ) |
| 128 | + gm = edge.exported_program().graph_module |
| 129 | + graphs.append(gm) |
| 130 | + |
| 131 | + return graphs |
| 132 | + |
| 133 | + |
| 134 | +def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: |
| 135 | + """ |
| 136 | + Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed |
| 137 | + weight tensor by packing 2 4-bit values in one unsigned 8-bit value. |
| 138 | +
|
| 139 | + An input weight tensor of shape (M, K) will produce a packed weight tensor of shape |
| 140 | + (M, K / 2). |
| 141 | +
|
| 142 | + The packing implemented here is the same as the packing produced by |
| 143 | + backends/vulkan/_passes/int4_weight_only_quantizer.py |
| 144 | + """ |
| 145 | + |
| 146 | + # Assert we got a properly quantized tensor. |
| 147 | + min, max = inp.min().item(), inp.max().item() |
| 148 | + assert ( |
| 149 | + max <= 7 and min >= -8 |
| 150 | + ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]" |
| 151 | + |
| 152 | + # Assuming we have a 2d tensor |
| 153 | + if inp.ndim != 2: |
| 154 | + inp = inp.squeeze() |
| 155 | + assert ( |
| 156 | + inp.ndim == 2 |
| 157 | + ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}" |
| 158 | + |
| 159 | + # pad ic |
| 160 | + if inp.shape[-1] % 2 != 0: |
| 161 | + inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0) |
| 162 | + |
| 163 | + # Shape after padding |
| 164 | + oc, ic = inp.shape |
| 165 | + assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even" |
| 166 | + |
| 167 | + # Adjust inp tensor for zp |
| 168 | + inp = inp.to(dtype=torch.uint8) + 8 |
| 169 | + # Pack each 4-bit value into a single 8-bit value |
| 170 | + return inp[::, ::2] << 4 | inp[::, 1::2] |
| 171 | + |
| 172 | + |
| 173 | +def make_combined_scales_and_zeros_tensor( |
| 174 | + scales: torch.Tensor, zeros: torch.Tensor |
| 175 | +) -> torch.Tensor: |
| 176 | + """ |
| 177 | + Given a scales and zeros tensor, create a combined tensor by stacking them into a |
| 178 | + single tensor. |
| 179 | +
|
| 180 | + The scales and zeros tensors are expected to be 2D tensors of shape |
| 181 | + (OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape |
| 182 | + (NUM_GROUPS, OUTPUT_CHANNELS, 2). |
| 183 | +
|
| 184 | + This is the scales and zeros format produced by |
| 185 | + backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the scales |
| 186 | + and zeros format expected by the _weight_int4pack_mm op in ATen. |
| 187 | + """ |
| 188 | + scales_reshaped = scales.transpose(0, 1).unsqueeze(2) |
| 189 | + zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2) |
| 190 | + |
| 191 | + zeros_scaled = zeros_reshaped * scales_reshaped * -1 |
| 192 | + return torch.cat((scales_reshaped, zeros_scaled), dim=2) |
| 193 | + |
| 194 | + |
| 195 | +def identify_wo_quantized_linear_io_nodes( # noqa: C901 |
| 196 | + ep: ExportedProgram, |
| 197 | + graph_module: torch.fx.GraphModule, |
| 198 | + match: InternalMatch, |
| 199 | +) -> Optional[List[torch.fx.Node]]: |
| 200 | + dequant_node = None |
| 201 | + # First, find the dequant node |
| 202 | + for node in match.nodes_map.values(): |
| 203 | + if utils.is_dequant_node(node): |
| 204 | + dequant_node = node |
| 205 | + break |
| 206 | + |
| 207 | + if dequant_node is None: |
| 208 | + return None |
| 209 | + |
| 210 | + quantized_weight = dequant_node.args[0] |
| 211 | + quant_scales = dequant_node.args[2] |
| 212 | + quant_zeros = dequant_node.args[3] |
| 213 | + |
| 214 | + if not isinstance(quantized_weight, torch.fx.Node) or not is_param_node( |
| 215 | + ep, quantized_weight |
| 216 | + ): |
| 217 | + return None |
| 218 | + if not isinstance(quant_scales, torch.fx.Node) or not is_param_node( |
| 219 | + ep, quant_scales |
| 220 | + ): |
| 221 | + return None |
| 222 | + if not isinstance(quant_zeros, torch.fx.Node) or not is_param_node(ep, quant_zeros): |
| 223 | + return None |
| 224 | + |
| 225 | + input_nodes = match.placeholder_nodes |
| 226 | + if len(input_nodes) != 4: |
| 227 | + return None |
| 228 | + |
| 229 | + in_tensor_node = None |
| 230 | + for node in input_nodes: |
| 231 | + if node not in dequant_node.args: |
| 232 | + in_tensor_node = node |
| 233 | + break |
| 234 | + |
| 235 | + if in_tensor_node is None: |
| 236 | + return None |
| 237 | + |
| 238 | + output_nodes = match.returning_nodes |
| 239 | + |
| 240 | + if len(output_nodes) != 1: |
| 241 | + return None |
| 242 | + |
| 243 | + out_tensor_node = output_nodes[0] |
| 244 | + if not isinstance(out_tensor_node, torch.fx.Node): |
| 245 | + return None |
| 246 | + |
| 247 | + return [ |
| 248 | + in_tensor_node, |
| 249 | + quantized_weight, |
| 250 | + quant_scales, |
| 251 | + quant_zeros, |
| 252 | + out_tensor_node, |
| 253 | + ] |
| 254 | + |
| 255 | + |
| 256 | +# wo = "weight only" |
| 257 | +@register_pattern_replacement("torchao_wo_quantized_linear") |
| 258 | +def create_wo_quantized_linear_custom_op( |
| 259 | + ep: ExportedProgram, |
| 260 | + graph_module: torch.fx.GraphModule, |
| 261 | + match: InternalMatch, |
| 262 | +): |
| 263 | + io_nodes = identify_wo_quantized_linear_io_nodes(ep, graph_module, match) |
| 264 | + if io_nodes is None: |
| 265 | + return |
| 266 | + |
| 267 | + assert len(io_nodes) == 5 |
| 268 | + in_tensor, quantized_weight, quant_scales, quant_zeros, out_tensor = io_nodes |
| 269 | + |
| 270 | + quantized_weight_tensor = get_param_tensor(ep, quantized_weight) |
| 271 | + if not isinstance(quantized_weight_tensor, torch.Tensor): |
| 272 | + return |
| 273 | + packed_quantized_weight_tensor = pack_4bit_weight_tensor(quantized_weight_tensor) |
| 274 | + utils.update_program_state_dict( |
| 275 | + ep, quantized_weight.name, packed_quantized_weight_tensor |
| 276 | + ) |
| 277 | + quantized_weight.meta["val"] = quantized_weight.meta["val"][:, ::2].to(torch.uint8) |
| 278 | + |
| 279 | + quant_scales_tensor = get_param_tensor(ep, quant_scales) |
| 280 | + quant_zeros_tensor = get_param_tensor(ep, quant_zeros) |
| 281 | + |
| 282 | + assert quantized_weight_tensor is not None |
| 283 | + assert quant_scales_tensor is not None |
| 284 | + assert quant_zeros_tensor is not None |
| 285 | + |
| 286 | + group_size = quantized_weight_tensor.shape[1] // quant_scales_tensor.shape[1] |
| 287 | + |
| 288 | + combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor( |
| 289 | + quant_scales_tensor, quant_zeros_tensor |
| 290 | + ) |
| 291 | + |
| 292 | + combined_scales_zeros_name = f"{quantized_weight.name}_scales_zeros" |
| 293 | + graph_module.register_parameter( |
| 294 | + combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor) |
| 295 | + ) |
| 296 | + |
| 297 | + with graph_module.graph.inserting_before(out_tensor): |
| 298 | + combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name) |
| 299 | + wo_qlinear = graph_module.graph.create_node( |
| 300 | + "call_function", |
| 301 | + exir_ops.edge.et_vk.linear_weight_int4.default, |
| 302 | + args=(in_tensor, quantized_weight, group_size, combined_scales_zeros, 1), |
| 303 | + ) |
| 304 | + |
| 305 | + if hasattr(out_tensor, "meta") and "val" in out_tensor.meta: |
| 306 | + wo_qlinear.meta["val"] = out_tensor.meta["val"] |
| 307 | + |
| 308 | + out_tensor.replace_all_uses_with(wo_qlinear) |
0 commit comments