diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py
index 6e11c36bfb0..40212c35c27 100644
--- a/backends/vulkan/quantizer/vulkan_quantizer.py
+++ b/backends/vulkan/quantizer/vulkan_quantizer.py
@@ -14,11 +14,12 @@
 import torch
 from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import (
     _convert_scalars_to_attrs,
+    bits_to_range,
     OP_TO_ANNOTATOR,
     propagate_annotation,
 )
 from torch.fx import Node
-from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from torchao.quantization.pt2e import PerChannelMinMaxObserver, PlaceholderObserver
 from torchao.quantization.pt2e.quantizer import (
     QuantizationConfig,
     QuantizationSpec,
@@ -28,50 +29,86 @@
 
 __all__ = [
     "VulkanQuantizer",
-    "get_linear_weight_qcs_qspec",
-    "get_linear_weight_only_qcs_xnn_qconfig",
+    "get_symmetric_quantization_config",
 ]
 
 
-def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec:
+@functools.lru_cache
+def get_symmetric_quantization_config(
+    is_dynamic: bool = False,
+    weight_bits: int = 8,
+    act_bits: int = 8,
+    act_qmin: Optional[int] = None,
+    act_qmax: Optional[int] = None,
+    weight_qmin: Optional[int] = None,
+    weight_qmax: Optional[int] = None,
+) -> QuantizationConfig:
     """
-    Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization
-    of weight tensors of linear layers to the number of bits specified by quant_bits.
+    Return a QuantizationConfig for Vulkan quantizer.
+
+    Args:
+        is_dynamic: If False, weight-only quantization. If True, dynamic quantization (activation + weight)
+        weight_bits: Number of bits for weight quantization (4 or 8)
+        act_bits: Number of bits for activation quantization (8)
+        act_qmin: Minimum quantization value for activations (auto-calculated if None)
+        act_qmax: Maximum quantization value for activations (auto-calculated if None)
+        weight_qmin: Minimum quantization value for weights (auto-calculated if None)
+        weight_qmax: Maximum quantization value for weights (auto-calculated if None)
     """
-    weight_observer = PerChannelMinMaxObserver
-    assert quant_bits in {
+    assert weight_bits in {
         8,
         4,
-    }, f"Unsupported weight quantization bits: {quant_bits}"
+    }, f"Unsupported weight quantization bits: {weight_bits}"
+
+    assert act_bits in {
+        8,
+    }, f"Unsupported activation quantization bits: {act_bits}"
 
-    quant_min = -(2 ** (quant_bits - 1))
-    quant_max = 2 ** (quant_bits - 1) - 1
-    qscheme = torch.per_channel_symmetric
+    # Auto-calculate weight ranges if not provided
+    if weight_qmin is None or weight_qmax is None:
+        weight_range = bits_to_range(weight_bits)
+        weight_qmin = weight_qmin if weight_qmin is not None else weight_range[0]
+        weight_qmax = weight_qmax if weight_qmax is not None else weight_range[1]
 
-    return QuantizationSpec(
+    # Weight quantization: per-channel symmetric for Vulkan
+    weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        qscheme=qscheme,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer,
+        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
     )
 
-
-@functools.lru_cache
-def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig:
-    """
-    Return a XNNPACKQuantizer QuantizationConfig class instance that specifies
-    quantizing the weight tensors of linear layers using per-channel symmetric (qcs)
-    quantization to the number of bits specified by quant_bits.
-    """
-    weight_qspec = get_linear_weight_qcs_qspec(quant_bits)
+    # Configure activation quantization based on is_dynamic
+    if not is_dynamic:
+        # Weight-only quantization: no activation quantization
+        act_quantization_spec = None
+        output_activation_spec = None
+    else:
+        # Dynamic quantization: per-token input quantization, no output quantization
+        # Auto-calculate activation ranges if not provided
+        if act_qmin is None or act_qmax is None:
+            act_range = bits_to_range(act_bits)
+            act_qmin = act_qmin if act_qmin is not None else act_range[0]
+            act_qmax = act_qmax if act_qmax is not None else act_range[1]
+
+        act_observer_or_fake_quant_ctr = PlaceholderObserver
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=act_qmin,
+            quant_max=act_qmax,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=True,
+            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr,
+        )
+        output_activation_spec = None
 
     return QuantizationConfig(
-        input_activation=None,
-        output_activation=None,
-        weight=weight_qspec,
+        input_activation=act_quantization_spec,
+        output_activation=output_activation_spec,
+        weight=weight_quantization_spec,
         bias=None,
         is_qat=False,
     )
@@ -99,12 +136,11 @@ def transform_for_annotation(
         return _convert_scalars_to_attrs(model)
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        # currently only support static quant on Vulkan
-        model = self._annotate_for_static_quantization_config(model)
+        model = self._annotate_for_quantization_config(model)
         propagate_annotation(model)
         return model
 
-    def _annotate_all_static_patterns(
+    def _annotate_all_patterns(
         self,
         model: torch.fx.GraphModule,
         quantization_config: Optional[QuantizationConfig],
@@ -117,10 +153,10 @@ def _annotate_all_static_patterns(
             OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
         return model
 
-    def _annotate_for_static_quantization_config(
+    def _annotate_for_quantization_config(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        self._annotate_all_static_patterns(
+        self._annotate_all_patterns(
             model,
             self.global_config,
         )
diff --git a/backends/vulkan/quantizer/vulkan_quantizer_utils.py b/backends/vulkan/quantizer/vulkan_quantizer_utils.py
index 7fa549b57cb..c0b6ab39e84 100644
--- a/backends/vulkan/quantizer/vulkan_quantizer_utils.py
+++ b/backends/vulkan/quantizer/vulkan_quantizer_utils.py
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 
 import torch
 from torch.fx import Node
@@ -27,9 +27,26 @@
     "OP_TO_ANNOTATOR",
     "propagate_annotation",
     "_convert_scalars_to_attrs",
+    "bits_to_range",
 ]
 
 
+def bits_to_range(bits: int) -> Tuple[int, int]:
+    """
+    Calculate quantization range for given number of bits.
+
+    Args:
+        bits: Number of quantization bits
+
+    Returns:
+        Tuple of (qmin, qmax) for the given bit width
+    """
+    return (
+        -(2 ** (bits - 1)),
+        (2 ** (bits - 1) - 1),
+    )
+
+
 AnnotatorType = Callable[
     [
         torch.fx.GraphModule,
diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py
index ff9e2d85a96..0de1fd5d69a 100644
--- a/backends/vulkan/test/test_vulkan_passes.py
+++ b/backends/vulkan/test/test_vulkan_passes.py
@@ -7,7 +7,7 @@
 from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform
 
 from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
-    get_linear_weight_only_qcs_xnn_qconfig,
+    get_symmetric_quantization_config,
     VulkanQuantizer,
 )
 
@@ -101,7 +101,9 @@ def test_fuse_int8pack_mm(self):
         sample_inputs = model.get_sample_inputs()
 
         quantizer = VulkanQuantizer()
-        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))
+        quantizer.set_global(
+            get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
+        )
 
         edge_manager = quantize_and_lower_module(
             model,
@@ -129,7 +131,9 @@ def test_fuse_linear_qcs4w(self):
         sample_inputs = model.get_sample_inputs()
 
         quantizer = VulkanQuantizer()
-        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4))
+        quantizer.set_global(
+            get_symmetric_quantization_config(is_dynamic=False, weight_bits=4)
+        )
 
         edge_manager = quantize_and_lower_module(
             model,
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index b94feb5a1ae..d87c722363f 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -12,7 +12,7 @@
 import torch
 
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
+    get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack,
     XNNPACKQuantizer,
 )
 
@@ -127,11 +127,11 @@ def check_embedding_byte_registered():
                 "At the moment only per channel weight quantization is supported."
             )
         if quant_params.quantize_linear.is_qc4:
-            operator_config_dynamic = get_symmetric_quantization_config(
+            operator_config_dynamic = get_symmetric_quantization_config_xnnpack(
                 is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
             )
         else:
-            operator_config_dynamic = get_symmetric_quantization_config(
+            operator_config_dynamic = get_symmetric_quantization_config_xnnpack(
                 is_per_channel=True, is_dynamic=True
             )
         dynamic_quantizer.set_global(operator_config_dynamic)
@@ -247,13 +247,13 @@ def get_coreml_quantizer(pt2e_quantize: str):
         raise NotImplementedError("4-bit Core ML quantizer is still under development")
 
     elif pt2e_quantize == "coreml_baseline_8a_c8w":
-        config = get_symmetric_quantization_config(
+        config = get_symmetric_quantization_config_xnnpack(
             is_per_channel=True, is_dynamic=False
         )
         quantizer = XNNPACKQuantizer().set_global(config)
 
     elif pt2e_quantize == "coreml_baseline_8a_c4w":
-        config = get_symmetric_quantization_config(
+        config = get_symmetric_quantization_config_xnnpack(
             is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
         )
         quantizer = XNNPACKQuantizer().set_global(config)
@@ -266,12 +266,14 @@ def get_coreml_quantizer(pt2e_quantize: str):
 
 def get_vulkan_quantizer(pt2e_quantize: str):
     from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
-        get_linear_weight_only_qcs_xnn_qconfig,
+        get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan,
         VulkanQuantizer,
     )
 
     if pt2e_quantize == "vulkan_8w":
-        config = get_linear_weight_only_qcs_xnn_qconfig(8)
+        config = get_symmetric_quantization_config_vulkan(
+            is_dynamic=False, weight_bits=8
+        )
     else:
         raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")
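Usage sketch (not part of the patch): the snippet below exercises the renamed `get_symmetric_quantization_config` API and the new `bits_to_range` helper in the same way the updated tests and `quantizer_lib.py` call them. The surrounding export flow (tracing, prepare/convert, lowering) is intentionally omitted, and the variable names are illustrative only.

```python
# Minimal sketch assuming the modules touched in this diff are importable as-is.
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_symmetric_quantization_config,
    VulkanQuantizer,
)
from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import bits_to_range

# bits_to_range(b) returns the signed symmetric (qmin, qmax) range for b bits.
assert bits_to_range(8) == (-128, 127)
assert bits_to_range(4) == (-8, 7)

# Weight-only quantization, replacing the removed
# get_linear_weight_only_qcs_xnn_qconfig(8) call.
weight_only_config = get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)

# Dynamic quantization: dynamically quantized 8-bit activations plus
# 8-bit per-channel symmetric weights (the new code path in this diff).
dynamic_config = get_symmetric_quantization_config(is_dynamic=True, weight_bits=8)

# Apply a config globally, as done in the updated tests.
quantizer = VulkanQuantizer()
quantizer.set_global(weight_only_config)
```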