diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 9e754b25a62..0b7a756bf0d 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -41,6 +41,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "int4_weight_only_quantizer",
+    srcs = [
+        "int4_weight_only_quantizer.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        ":custom_ops_defs",
+        "//pytorch/ao:torchao",
+    ]
+)
+
 runtime.python_library(
     name = "vulkan_passes",
     srcs = [
@@ -50,6 +64,7 @@ runtime.python_library(
         "//executorch/backends/...",
     ],
     deps = [
+        ":int4_weight_only_quantizer",
        ":remove_local_scalar_dense",
     ]
 )
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index bded91094e7..080df836080 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -1,7 +1,11 @@
+from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
+    VkInt4WeightOnlyQuantizer,
+)
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
 
 __all__ = [
+    "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
 ]
diff --git a/backends/vulkan/_passes/custom_ops_defs.py b/backends/vulkan/_passes/custom_ops_defs.py
index fd586b665a0..2c16a331c04 100644
--- a/backends/vulkan/_passes/custom_ops_defs.py
+++ b/backends/vulkan/_passes/custom_ops_defs.py
@@ -9,6 +9,10 @@
 namespace = "et_vk"
 lib = torch.library.Library(namespace, "DEF")
 
+#####################
+## conv_with_clamp ##
+#####################
+
 
 def conv_with_clamp_impl(
     input,
@@ -47,6 +51,10 @@ def conv_with_clamp_impl(
 lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
 conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
 
+#########################
+## conv_with_clamp.out ##
+#########################
+
 
 def conv_with_clamp_out_impl(
     input,
@@ -84,6 +92,10 @@ def conv_with_clamp_out_impl(
 )
 lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")
 
+#################
+## grid_priors ##
+#################
+
 
 # The dimension of x should be larger than 1
 def grid_priors_impl(
@@ -125,3 +137,35 @@ def grid_priors_out_impl(
     f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd")
+
+########################
+## linear_weight_int4 ##
+########################
+
+
+def linear_weight_int4_impl(
+    x: torch.Tensor,
+    weights_4x8: torch.Tensor,
+    groupsize: int,
+    scales_and_zeros: torch.Tensor,
+    inner_k_tiles: int,
+):
+    original_x_size = x.size()
+    out_features = weights_4x8.size(0)
+    x = x.reshape(-1, original_x_size[-1])
+    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+        weights_4x8, inner_k_tiles
+    )
+    out = torch.ops.aten._weight_int4pack_mm(
+        x, weight_int4pack, groupsize, scales_and_zeros
+    )
+    out_shape = original_x_size[:-1] + (out_features,)
+    return out.reshape(out_shape)
+
+
+name = "linear_weight_int4"
+lib.define(
+    f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor"
+)
+lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
+linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
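A minimal sketch of how the new operator surfaces in eager mode (an illustration, assuming an environment where PyTorch and the ExecuTorch Vulkan passes are importable; the shape notes mirror the buffers registered by VkWeightOnlyInt4Linear further down):

import torch

# Importing custom_ops_defs runs the lib.define/lib.impl calls above and registers
# the operator under the et_vk namespace.
from executorch.backends.vulkan._passes import custom_ops_defs  # noqa: F401

op = torch.ops.et_vk.linear_weight_int4
print(op.default._schema)
# et_vk::linear_weight_int4(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor
#
# Expected argument shapes in the Vulkan convention:
#   self            [..., K]                 activations
#   mat2            [N, K // 2] uint8        two 4-bit values packed per byte
#   qScaleAndZeros  [K // qGroupSize, N, 2]  per-group (scale, zero) pairs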
out) -> Tensor(a!)" ) lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd") + +######################## +## linear_weight_int4 ## +######################## + + +def linear_weight_int4_impl( + x: torch.Tensor, + weights_4x8: torch.Tensor, + groupsize: int, + scales_and_zeros: torch.Tensor, + inner_k_tiles: int, +): + original_x_size = x.size() + out_features = weights_4x8.size(0) + x = x.reshape(-1, original_x_size[-1]) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weights_4x8, inner_k_tiles + ) + out = torch.ops.aten._weight_int4pack_mm( + x, weight_int4pack, groupsize, scales_and_zeros + ) + out_shape = original_x_size[:-1] + (out_features,) + return out.reshape(out_shape) + + +name = "linear_weight_int4" +lib.define( + f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor" +) +lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd") +linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py new file mode 100644 index 00000000000..a0d208bb63f --- /dev/null +++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py @@ -0,0 +1,257 @@ +import logging +from typing import Any, Callable, Dict, Optional, Type + +import torch +import torch.nn.functional as F + +from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa + linear_weight_int4_op, +) + +from torchao.quantization.GPTQ import _check_linear_int4_k +from torchao.quantization.unified import Quantizer +from torchao.quantization.utils import groupwise_affine_quantize_tensor + + +# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with +# changes at the annotated lines. +class VkWeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + # TODO: remove dtype field, not used + bias=False, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + if self.padding: + from torchao.quantization.utils import find_multiple + + self.origin_in_features = in_features + in_features = find_multiple(in_features, (1024,)) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.device = device + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + + if dtype is not None: + raise ValueError("Please specify 'precision' instead of 'dtype'") + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" + # In the original implementation, the weight buffer is registered with the packed + # sizes, i.e. the result of calling the _convert_weight_to_int4pack operator. + # However, the Vulkan implementation does not expect the weights to be packed + # therefore the weight tensor is registered with the unpacked sizes instead. + # Note that in_features is divided by 2 because each `uint8` tensor element + # contains 2 4-bit packed values. 
+# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
+# with small changes at the annotated locations.
+def _vk_replace_linear_int4(
+    module: torch.nn.Module,
+    groupsize: int,
+    inner_k_tiles: Optional[int],
+    padding_allowed: bool,
+    skip_layer_func: Optional[Callable] = None,
+    precision: torch.dtype = torch.bfloat16,
+    scales_precision: torch.dtype = torch.bfloat16,
+    # Use the custom Vulkan linear layer as the default replacement class
+    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
+    copy_weights: bool = False,
+    # Serves the same purpose as `tensor_dim_limit` in
+    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
+    feature_limit: int = 16384,
+):
+    for name, child in module.named_children():
+        if isinstance(child, torch.nn.Linear) and (
+            skip_layer_func is None or not skip_layer_func(child.weight)
+        ):
+            # Add an additional condition that the out/in features must not exceed the
+            # `feature_limit` argument.
+            if (
+                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
+                or padding_allowed
+            ) and (
+                child.out_features < feature_limit and child.in_features < feature_limit
+            ):
+                new_linear = linear_class(
+                    child.in_features,
+                    child.out_features,
+                    bias=False,
+                    device=child.weight.device,
+                    groupsize=groupsize,
+                    inner_k_tiles=inner_k_tiles,
+                    precision=precision,
+                    scales_precision=scales_precision,
+                )
+                if copy_weights and child.weight.device != torch.device("meta"):
+                    new_linear.weight = child.weight
+                setattr(module, name, new_linear)
+        else:
+            _vk_replace_linear_int4(
+                child,
+                groupsize,
+                inner_k_tiles,
+                padding_allowed,
+                skip_layer_func,
+                precision,
+                scales_precision,
+                linear_class,
+                copy_weights,
+            )
+
+
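For orientation, a minimal sketch of the swap in isolation (the layer size is illustrative; in practice VkInt4WeightOnlyQuantizer._convert_for_runtime below is what drives this call):

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
_vk_replace_linear_int4(
    model,
    groupsize=128,
    inner_k_tiles=8,
    padding_allowed=True,
    precision=torch.float32,
    scales_precision=torch.float32,
)
assert isinstance(model[0], VkWeightOnlyInt4Linear)  # the eager Linear has been swapped out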
+# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
+# with some changes at the annotated lines.
+class VkInt4WeightOnlyQuantizer(Quantizer):
+    def __init__(
+        self,
+        groupsize: int = 256,
+        padding_allowed: bool = True,
+        inner_k_tiles: Optional[int] = 8,
+        device: torch.device = torch.device("cpu"),  # noqa
+        precision: torch.dtype = torch.float32,
+        feature_limit: int = 16384,
+    ) -> None:
+        super().__init__()
+        assert inner_k_tiles in [2, 4, 8]
+        assert groupsize in [32, 64, 128, 256]
+
+        self.inner_k_tiles = inner_k_tiles
+        self.groupsize: int = groupsize
+        self.padding_allowed: bool = padding_allowed
+        self.device: torch.device = device
+        self.precision: torch.dtype = precision
+        # Serves the same purpose as `tensor_dim_limit` in
+        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
+        self.feature_limit = feature_limit
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(
+        self, model: torch.nn.Module
+    ) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            # Add an additional check to make sure the features do not exceed the
+            # feature limit.
+            if isinstance(mod, torch.nn.Linear) and (
+                mod.out_features < self.feature_limit
+                and mod.in_features < self.feature_limit
+            ):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.groupsize == 0
+                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
+
+                weight = mod.weight.data
+                if not _check_linear_int4_k(
+                    in_features, self.groupsize, self.inner_k_tiles
+                ):
+                    if self.padding_allowed:
+                        import torch.nn.functional as F
+
+                        from torchao.quantization.utils import find_multiple
+
+                        logging.warn(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = find_multiple(in_features, (1024,))
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        logging.warn(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                        )
+                        continue
+                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                    weight,
+                    4,  # n_bit
+                    self.groupsize,
+                    self.precision,  # dtype for scales_and_zeros
+                )
+                # In the original implementation, w_int4x8 is packed by calling the
+                # _convert_weight_to_int4pack operator before storing the weight. However,
+                # the Vulkan implementation does not expect the weights to be packed, so
+                # the w_int4x8 tensor is stored as the weight instead.
+                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
+                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
+                    self.device
+                )
+        return cur_state_dict
+
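The quantized state dict keeps the original parameter names, which is what lets quantize() below load it into the swapped model with strict=False. A sketch with a hypothetical single-layer model:

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
quantizer = VkInt4WeightOnlyQuantizer(groupsize=128)
sd = quantizer._create_quantized_state_dict(model)
list(sd.keys())  # ['0.weight', '0.scales_and_zeros']
# '0.weight' holds w_int4x8 exactly as produced by groupwise_affine_quantize_tensor
# (no int4pack packing); '0.scales_and_zeros' holds the per-group scales and zeros.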
+ cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device) + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( + self.device + ) + return cur_state_dict + + def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: + _vk_replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + skip_layer_func=None, + precision=self.precision, + scales_precision=self.precision, + ) + return model + + def quantize( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict(model) + model = self._convert_for_runtime(model) + model.load_state_dict(state_dict, strict=False) + return model diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 6bc568bfdd7..da50719ba33 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -83,6 +83,7 @@ def __contains__(self, op): exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.linear.default, + exir_ops.edge.et_vk.linear_weight_int4.default, # Reduction exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten._softmax.default,