import logging
from typing import Any, Callable, Dict, Optional, Type

import torch
import torch.nn.functional as F

from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
    linear_weight_int4_op,
)

from torchao.quantization.GPTQ import _check_linear_int4_k
from torchao.quantization.unified import Quantizer
from torchao.quantization.utils import groupwise_affine_quantize_tensor


# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
# changes at the annotated lines.
class VkWeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # TODO: remove the dtype field; it is not used
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from torchao.quantization.utils import find_multiple

            self.origin_in_features = in_features
            # Pad in_features up to the next multiple of 1024 so the int4
            # divisibility requirements below are met.
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.precision = precision
        self.scales_precision = scales_precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (inner_k_tiles * 16) == 0"
        # In the original implementation, the weight buffer is registered with the
        # packed sizes, i.e. the result of calling the _convert_weight_to_int4pack
        # operator. However, the Vulkan implementation does not expect the weights
        # to be packed, so the weight tensor is registered with the unpacked sizes
        # instead. Note that in_features is divided by 2 because each `uint8`
        # tensor element contains 2 packed 4-bit values.
        self.register_buffer(
            "weight",
            torch.empty(
                (out_features, in_features // 2),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.dtype = dtype
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2),
                dtype=self.scales_precision,
                device=device,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # The forward method is replaced. In the original implementation, the
        # forward method is torchao.quantization.GPTQ.linear_forward_int4; here a
        # Vulkan custom operator is called instead.
        return torch.ops.et_vk.linear_weight_int4(
            input,
            self.weight,
            self.groupsize,
            self.scales_and_zeros,
            self.inner_k_tiles,
        )

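
# For illustration only (not part of the original module): a minimal sketch of how
# two 4-bit values share one `uint8` element in the unpacked weight layout described
# above, which is why the innermost dimension is halved. The nibble order shown is
# an assumption; the authoritative layout is whatever groupwise_affine_quantize_tensor
# produces and the et_vk.linear_weight_int4 operator consumes.
def _pack_int4_pair_example(first: int, second: int) -> int:
    assert 0 <= first < 16 and 0 <= second < 16
    # One byte holds two 4-bit values: `first` in the high nibble, `second` in the low.
    return (first << 4) | second
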

# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
# with small changes at the annotated locations.
def _vk_replace_linear_int4(
    module: torch.nn.Module,
    groupsize: int,
    inner_k_tiles: Optional[int],
    padding_allowed: bool,
    skip_layer_func: Optional[Callable] = None,
    precision: torch.dtype = torch.bfloat16,
    scales_precision: torch.dtype = torch.bfloat16,
    # Use the custom Vulkan linear layer as the default replacement
    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
    copy_weights: bool = False,
    # Serves the same purpose as `tensor_dim_limit` in
    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
    feature_limit: int = 16384,
):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear) and (
            skip_layer_func is None or not skip_layer_func(child.weight)
        ):
            # Additional condition: the out/in features must not exceed the
            # `feature_limit` argument.
            if (
                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                or padding_allowed
            ) and (
                child.out_features < feature_limit
                and child.in_features < feature_limit
            ):
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    device=child.weight.device,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    precision=precision,
                    scales_precision=scales_precision,
                )
                if copy_weights and child.weight.device != torch.device("meta"):
                    new_linear.weight = child.weight
                setattr(module, name, new_linear)
        else:
            _vk_replace_linear_int4(
                child,
                groupsize,
                inner_k_tiles,
                padding_allowed,
                skip_layer_func,
                precision,
                scales_precision,
                linear_class,
                copy_weights,
                # Propagate the limit so nested children are filtered consistently.
                feature_limit,
            )

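
# A minimal usage sketch (illustrative; the Sequential model is hypothetical):
#
#     model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))
#     _vk_replace_linear_int4(model, groupsize=128, inner_k_tiles=8, padding_allowed=True)
#     assert isinstance(model[0], VkWeightOnlyInt4Linear)
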

# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
# with some changes at the annotated lines.
class VkInt4WeightOnlyQuantizer(Quantizer):
    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = True,
        inner_k_tiles: Optional[int] = 8,
        device: torch.device = torch.device("cpu"),  # noqa
        precision: torch.dtype = torch.float32,
        feature_limit: int = 16384,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]

        self.inner_k_tiles = inner_k_tiles
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.device: torch.device = device
        self.precision: torch.dtype = precision
        # Serves the same purpose as `tensor_dim_limit` in
        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
        self.feature_limit = feature_limit

    @torch.no_grad()
    def _create_quantized_state_dict(
        self, model: torch.nn.Module
    ) -> Dict[str, torch.Tensor]:
        cur_state_dict = model.state_dict()
        for fqn, mod in model.named_modules():
            # Additional check to make sure the features do not exceed the
            # feature limit.
            if isinstance(mod, torch.nn.Linear) and (
                mod.out_features < self.feature_limit
                and mod.in_features < self.feature_limit
            ):
                assert mod.bias is None, "require bias=False"
                out_features = mod.out_features
                in_features = mod.in_features
                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")

                assert (
                    in_features % self.groupsize == 0
                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding_allowed:
                        from torchao.quantization.utils import find_multiple

                        logging.warning(
                            f"{fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        logging.warning(
                            f"{fqn} is skipped: int4 requires that in_features is 32, 64, or divisible by 1024, "
                            + "and that groupsize and inner_k_tiles * 16 evenly divide it"
                        )
                        continue
                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    weight,
                    4,  # n_bit
                    self.groupsize,
                    self.precision,  # dtype for scales_and_zeros
                )
                # In the original implementation, w_int4x8 is packed via the
                # _convert_weight_to_int4pack operator before being stored. However,
                # the Vulkan implementation does not expect the weights to be
                # packed, so the w_int4x8 tensor is stored as the weight instead.
                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                    self.device
                )
        return cur_state_dict
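
    # For illustration only (not part of the original code): assuming a recent
    # torchao where groupwise_affine_quantize_tensor returns packed uint8 values,
    # a hypothetical nn.Linear(128, 64, bias=False) with groupsize=32 yields
    #
    #     w_int4x8:         shape (64, 64),   dtype uint8 (two 4-bit values each)
    #     scales_and_zeros: shape (4, 64, 2), dtype self.precision
    #
    # matching the `weight` and `scales_and_zeros` buffers registered by
    # VkWeightOnlyInt4Linear above, so they load cleanly via load_state_dict.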

    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
        _vk_replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
            skip_layer_func=None,
            precision=self.precision,
            scales_precision=self.precision,
            # Keep module replacement consistent with the feature-limit check used
            # when building the quantized state dict.
            feature_limit=self.feature_limit,
        )
        return model

    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        state_dict = self._create_quantized_state_dict(model)
        model = self._convert_for_runtime(model)
        model.load_state_dict(state_dict, strict=False)
        return model
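

# A minimal end-to-end sketch (illustrative; the model below is hypothetical and
# its Linear layers must have bias=False):
#
#     model = torch.nn.Sequential(
#         torch.nn.Linear(1024, 1024, bias=False),
#         torch.nn.Linear(1024, 512, bias=False),
#     )
#     quantizer = VkInt4WeightOnlyQuantizer(groupsize=128)
#     model = quantizer.quantize(model)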