From 8abb889ab28ff14a16b4024788354faea14c53ab Mon Sep 17 00:00:00 2001 From: morelos Date: Tue, 29 Jul 2025 23:43:25 -0700 Subject: [PATCH 1/4] [ET-VK] Migrate off of xnnpack_quantizer_utils Pull Request resolved: https://github.com/pytorch/executorch/pull/12572 # Context Eventually as the vulkan_quantizer file expands, we will need to migrate into a custom utils file and stop depending on the xnnpack_quantizer_utils. We migrate only the minimal amount of functions necessary to ensure the vulkan_quantizer works. # Changes We create a new file `vulkan_quantizer_utils.py` and migrate off of `xnnpack_quantizer_utils.py` in `vulkan_quantizer`. There are no specific modifications necessary to work separate from xnnpack utils except bits_to_range to allow not needing to specify the ranges everytime. ghstack-source-id: 299473612 @exported-using-ghexport Differential Revision: [D78290055](https://our.internmc.facebook.com/intern/diff/D78290055/) --- backends/vulkan/quantizer/TARGETS | 12 +- backends/vulkan/quantizer/vulkan_quantizer.py | 2 +- .../quantizer/vulkan_quantizer_utils.py | 206 ++++++++++++++++++ 3 files changed, 216 insertions(+), 4 deletions(-) create mode 100644 backends/vulkan/quantizer/vulkan_quantizer_utils.py diff --git a/backends/vulkan/quantizer/TARGETS b/backends/vulkan/quantizer/TARGETS index 5650f2bd728..2c3ae37923a 100644 --- a/backends/vulkan/quantizer/TARGETS +++ b/backends/vulkan/quantizer/TARGETS @@ -4,11 +4,17 @@ oncall("executorch") python_library( name = "vulkan_quantizer", - srcs = [ - "vulkan_quantizer.py", + srcs = ["vulkan_quantizer.py"], + deps = [ + ":vulkan_quantizer_utils", + "//caffe2:torch", ], +) + +python_library( + name = "vulkan_quantizer_utils", + srcs = ["vulkan_quantizer_utils.py"], deps = [ "//caffe2:torch", - "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils", ], ) diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index a82c2091cf6..6e11c36bfb0 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -12,7 +12,7 @@ from typing import Callable, Optional import torch -from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( +from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import ( _convert_scalars_to_attrs, OP_TO_ANNOTATOR, propagate_annotation, diff --git a/backends/vulkan/quantizer/vulkan_quantizer_utils.py b/backends/vulkan/quantizer/vulkan_quantizer_utils.py new file mode 100644 index 00000000000..7fa549b57cb --- /dev/null +++ b/backends/vulkan/quantizer/vulkan_quantizer_utils.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Callable, Optional + +import torch +from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec, + get_bias_qspec, + get_input_act_qspec, + get_output_act_qspec, + get_weight_qspec, + QuantizationAnnotation, + QuantizationConfig, + SharedQuantizationSpec, +) +from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix + +__all__ = [ + "OP_TO_ANNOTATOR", + "propagate_annotation", + "_convert_scalars_to_attrs", +] + + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[list[list[Node]]], +] +OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {} + + +def register_annotator(op: str) -> Callable[[AnnotatorType], None]: + def decorator(annotator: AnnotatorType) -> None: + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +def _is_annotated(nodes: list[Node]) -> bool: + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: list[Node]) -> None: + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if bias_node: + annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +def _is_share_obs_or_fq_op(op: Callable[..., torch.Tensor]) -> bool: + return op in [ + torch.ops.aten.relu.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + 
torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.flatten.using_ints, + ] + + +def propagate_annotation(model: torch.fx.GraphModule) -> None: + for n in model.graph.nodes: + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): + continue + + prev_node = n.args[0] + if not isinstance(prev_node, Node): + continue + + quantization_annotation = prev_node.meta.get("quantization_annotation", None) + if not quantization_annotation: + continue + + output_qspec = quantization_annotation.output_qspec + if not output_qspec: + continue + + # make sure current node is not annotated + if ( + "quantization_annotation" in n.meta + and n.meta["quantization_annotation"]._annotated + ): + continue + + shared_qspec = SharedQuantizationSpec(prev_node) + # propagate the previous output_qspec to the current node + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + prev_node: shared_qspec, + }, + output_qspec=shared_qspec, + _annotated=True, + ) + + +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ]: + continue + args = list(n.args) + new_args = [] + for i in range(len(args)): + if isinstance(args[i], torch.fx.Node): + new_args.append(args[i]) + continue + prefix = "_tensor_constant_" + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + tensor_constant_name = get_new_attr_name(model) + float_tensor = torch.tensor(float(args[i])) + model.register_buffer(tensor_constant_name, float_tensor) + fake_mode = n.meta["val"].fake_mode + with model.graph.inserting_before(n): + get_attr_node = model.graph.create_node( + "get_attr", tensor_constant_name, (), {} + ) + get_attr_node.meta["val"] = fake_mode.from_tensor( + float_tensor, static_shapes=True + ) + new_args.append(get_attr_node) + n.args = tuple(new_args) + model.recompile() + return model From 9c28a88a1c2a4c8a527a957360a197a27a15aefa Mon Sep 17 00:00:00 2001 From: morelos Date: Tue, 29 Jul 2025 23:43:26 -0700 Subject: [PATCH 2/4] [ET-VK] Creating get_symmetric_quantization_config Pull Request resolved: https://github.com/pytorch/executorch/pull/12573 # Context Eventually dynamic quantization will be enabled in the vulkan_quantizer (particularly 8bit dyn act with 8bit weights). In order to enable this functionality we need to utilize a similar method as XNNPack with how they define their quantization config. This diff aims to align with XNNPack quantizer logic and also migrate away from utilizing the old static quantization config logic. # Changes A few noticable changes is that we migrate away from `get_linear_weight_only_qcs_xnn_qconfig`, and we now define a symmetric config that has parameters to define whether it's dynamically quantized or not. Furthermore, we also incorporate bits_to_range so that we can automatically designate the min and max quant ranges without having to set them during initialization. We also change some wording from using just static as we are now enabling dynamic quantization as well. Furthermore, we change internally other codebases that are calling our existing legacy config, and move them into the more universal symmetric config. Since this follows the same naming scheme as XNNPack, I have decided to just add aliases in cases where its being imported directly along with XNNPack. 
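A minimal usage sketch of the new config API (the call pattern mirrors the updated call sites in `test_vulkan_passes.py` and `quantizer_lib.py` below; the dynamic variant is included only to illustrate the new parameters):

```python
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_symmetric_quantization_config,
    VulkanQuantizer,
)

# Weight-only quantization: per-channel symmetric weights, no activation specs.
wo_8bit_config = get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
wo_4bit_config = get_symmetric_quantization_config(is_dynamic=False, weight_bits=4)

# Dynamic quantization: additionally attaches a dynamic int8 input-activation spec
# (PlaceholderObserver, no output quantization). qmin/qmax default to
# bits_to_range(bits): (-128, 127) for 8 bits and (-8, 7) for 4 bits.
dyn_8bit_config = get_symmetric_quantization_config(is_dynamic=True, weight_bits=8, act_bits=8)

quantizer = VulkanQuantizer()
quantizer.set_global(wo_8bit_config)
```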
ghstack-source-id: 299473613 @exported-using-ghexport Differential Revision: [D78291249](https://our.internmc.facebook.com/intern/diff/D78291249/) --- backends/vulkan/quantizer/vulkan_quantizer.py | 104 ++++++++++++------ .../quantizer/vulkan_quantizer_utils.py | 19 +++- backends/vulkan/test/test_vulkan_passes.py | 10 +- extension/llm/export/quantizer_lib.py | 16 +-- 4 files changed, 104 insertions(+), 45 deletions(-) diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index 6e11c36bfb0..40212c35c27 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -14,11 +14,12 @@ import torch from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import ( _convert_scalars_to_attrs, + bits_to_range, OP_TO_ANNOTATOR, propagate_annotation, ) from torch.fx import Node -from torchao.quantization.pt2e import PerChannelMinMaxObserver +from torchao.quantization.pt2e import PerChannelMinMaxObserver, PlaceholderObserver from torchao.quantization.pt2e.quantizer import ( QuantizationConfig, QuantizationSpec, @@ -28,50 +29,86 @@ __all__ = [ "VulkanQuantizer", - "get_linear_weight_qcs_qspec", - "get_linear_weight_only_qcs_xnn_qconfig", + "get_symmetric_quantization_config", ] -def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec: +@functools.lru_cache +def get_symmetric_quantization_config( + is_dynamic: bool = False, + weight_bits: int = 8, + act_bits: int = 8, + act_qmin: Optional[int] = None, + act_qmax: Optional[int] = None, + weight_qmin: Optional[int] = None, + weight_qmax: Optional[int] = None, +) -> QuantizationConfig: """ - Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization - of weight tensors of linear layers to the number of bits specified by quant_bits. + Return a QuantizationConfig for Vulkan quantizer. + + Args: + is_dynamic: If False, weight-only quantization. 
If True, dynamic quantization (activation + weight) + weight_bits: Number of bits for weight quantization (4 or 8) + act_bits: Number of bits for activation quantization (8) + act_qmin: Minimum quantization value for activations (auto-calculated if None) + act_qmax: Maximum quantization value for activations (auto-calculated if None) + weight_qmin: Minimum quantization value for weights (auto-calculated if None) + weight_qmax: Maximum quantization value for weights (auto-calculated if None) """ - weight_observer = PerChannelMinMaxObserver - assert quant_bits in { + assert weight_bits in { 8, 4, - }, f"Unsupported weight quantization bits: {quant_bits}" + }, f"Unsupported weight quantization bits: {weight_bits}" + + assert act_bits in { + 8, + }, f"Unsupported activation quantization bits: {act_bits}" - quant_min = -(2 ** (quant_bits - 1)) - quant_max = 2 ** (quant_bits - 1) - 1 - qscheme = torch.per_channel_symmetric + # Auto-calculate weight ranges if not provided + if weight_qmin is None or weight_qmax is None: + weight_range = bits_to_range(weight_bits) + weight_qmin = weight_qmin if weight_qmin is not None else weight_range[0] + weight_qmax = weight_qmax if weight_qmax is not None else weight_range[1] - return QuantizationSpec( + # Weight quantization: per-channel symmetric for Vulkan + weight_quantization_spec = QuantizationSpec( dtype=torch.int8, - quant_min=quant_min, - quant_max=quant_max, - qscheme=qscheme, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=torch.per_channel_symmetric, ch_axis=0, is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, ) - -@functools.lru_cache -def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig: - """ - Return a XNNPACKQuantizer QuantizationConfig class instance that specifies - quantizing the weight tensors of linear layers using per-channel symmetric (qcs) - quantization to the number of bits specified by quant_bits. 
- """ - weight_qspec = get_linear_weight_qcs_qspec(quant_bits) + # Configure activation quantization based on is_dynamic + if not is_dynamic: + # Weight-only quantization: no activation quantization + act_quantization_spec = None + output_activation_spec = None + else: + # Dynamic quantization: per-token input quantization, no output quantization + # Auto-calculate activation ranges if not provided + if act_qmin is None or act_qmax is None: + act_range = bits_to_range(act_bits) + act_qmin = act_qmin if act_qmin is not None else act_range[0] + act_qmax = act_qmax if act_qmax is not None else act_range[1] + + act_observer_or_fake_quant_ctr = PlaceholderObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=True, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr, + ) + output_activation_spec = None return QuantizationConfig( - input_activation=None, - output_activation=None, - weight=weight_qspec, + input_activation=act_quantization_spec, + output_activation=output_activation_spec, + weight=weight_quantization_spec, bias=None, is_qat=False, ) @@ -99,12 +136,11 @@ def transform_for_annotation( return _convert_scalars_to_attrs(model) def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - # currently only support static quant on Vulkan - model = self._annotate_for_static_quantization_config(model) + model = self._annotate_for_quantization_config(model) propagate_annotation(model) return model - def _annotate_all_static_patterns( + def _annotate_all_patterns( self, model: torch.fx.GraphModule, quantization_config: Optional[QuantizationConfig], @@ -117,10 +153,10 @@ def _annotate_all_static_patterns( OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) return model - def _annotate_for_static_quantization_config( + def _annotate_for_quantization_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - self._annotate_all_static_patterns( + self._annotate_all_patterns( model, self.global_config, ) diff --git a/backends/vulkan/quantizer/vulkan_quantizer_utils.py b/backends/vulkan/quantizer/vulkan_quantizer_utils.py index 7fa549b57cb..c0b6ab39e84 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer_utils.py +++ b/backends/vulkan/quantizer/vulkan_quantizer_utils.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import torch from torch.fx import Node @@ -27,9 +27,26 @@ "OP_TO_ANNOTATOR", "propagate_annotation", "_convert_scalars_to_attrs", + "bits_to_range", ] +def bits_to_range(bits: int) -> Tuple[int, int]: + """ + Calculate quantization range for given number of bits. 
+ + Args: + bits: Number of quantization bits + + Returns: + Tuple of (qmin, qmax) for the given bit width + """ + return ( + -(2 ** (bits - 1)), + (2 ** (bits - 1) - 1), + ) + + AnnotatorType = Callable[ [ torch.fx.GraphModule, diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index ff9e2d85a96..0de1fd5d69a 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -7,7 +7,7 @@ from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_linear_weight_only_qcs_xnn_qconfig, + get_symmetric_quantization_config, VulkanQuantizer, ) @@ -101,7 +101,9 @@ def test_fuse_int8pack_mm(self): sample_inputs = model.get_sample_inputs() quantizer = VulkanQuantizer() - quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8)) + quantizer.set_global( + get_symmetric_quantization_config(is_dynamic=False, weight_bits=8) + ) edge_manager = quantize_and_lower_module( model, @@ -129,7 +131,9 @@ def test_fuse_linear_qcs4w(self): sample_inputs = model.get_sample_inputs() quantizer = VulkanQuantizer() - quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4)) + quantizer.set_global( + get_symmetric_quantization_config(is_dynamic=False, weight_bits=4) + ) edge_manager = quantize_and_lower_module( model, diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index b94feb5a1ae..d87c722363f 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -12,7 +12,7 @@ import torch from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, + get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack, XNNPACKQuantizer, ) @@ -127,11 +127,11 @@ def check_embedding_byte_registered(): "At the moment only per channel weight quantization is supported." 
) if quant_params.quantize_linear.is_qc4: - operator_config_dynamic = get_symmetric_quantization_config( + operator_config_dynamic = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7 ) else: - operator_config_dynamic = get_symmetric_quantization_config( + operator_config_dynamic = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=True ) dynamic_quantizer.set_global(operator_config_dynamic) @@ -247,13 +247,13 @@ def get_coreml_quantizer(pt2e_quantize: str): raise NotImplementedError("4-bit Core ML quantizer is still under development") elif pt2e_quantize == "coreml_baseline_8a_c8w": - config = get_symmetric_quantization_config( + config = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=False ) quantizer = XNNPACKQuantizer().set_global(config) elif pt2e_quantize == "coreml_baseline_8a_c4w": - config = get_symmetric_quantization_config( + config = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7 ) quantizer = XNNPACKQuantizer().set_global(config) @@ -266,12 +266,14 @@ def get_coreml_quantizer(pt2e_quantize: str): def get_vulkan_quantizer(pt2e_quantize: str): from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_linear_weight_only_qcs_xnn_qconfig, + get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan, VulkanQuantizer, ) if pt2e_quantize == "vulkan_8w": - config = get_linear_weight_only_qcs_xnn_qconfig(8) + config = get_symmetric_quantization_config_vulkan( + is_dynamic=False, weight_bits=8 + ) else: raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}") From b7b06046b703effeb117f6a440716bc3ef805fd2 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 30 Jul 2025 13:09:49 -0400 Subject: [PATCH 3/4] [ET-VK] linear_qta8a_qga4w graph pass (#13000) This PR was created by the merge bot to help merge the original PR into the main branch. 
ghstack PR number: https://github.com/pytorch/executorch/pull/12574 by @ahmtox ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/ahmtox/42/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/42/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/ahmtox/41/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/42/orig @diff-train-skip-merge --------- Co-authored-by: morelos Co-authored-by: ahmtox <69552192+ahmtox@users.noreply.github.com> --- backends/vulkan/_passes/fuse_quantized_ops.py | 283 ++++++ backends/vulkan/custom_ops_lib.py | 89 ++ backends/vulkan/op_registry.py | 7 +- .../graph/ops/glsl/dequantize_buffer.glsl | 152 ++-- .../graph/ops/glsl/dequantize_buffer.yaml | 2 + .../graph/ops/glsl/dequantize_texture.glsl | 46 +- .../graph/ops/glsl/dequantize_texture.yaml | 2 + .../graph/ops/glsl/quantize_buffer.glsl | 142 +-- .../graph/ops/glsl/quantize_buffer.yaml | 2 + .../graph/ops/glsl/quantize_texture.glsl | 157 ++-- .../graph/ops/glsl/quantize_texture.yaml | 2 + .../runtime/graph/ops/impl/Dequantize.cpp | 301 ++++-- .../runtime/graph/ops/impl/Quantize.cpp | 196 +++- .../ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp | 9 +- backends/vulkan/test/TARGETS | 1 + .../test/op_tests/quantize_affine_test.cpp | 859 ++++++++++++++++++ backends/vulkan/test/op_tests/targets.bzl | 6 + backends/vulkan/test/test_vulkan_passes.py | 54 ++ backends/vulkan/utils.py | 15 + 19 files changed, 2012 insertions(+), 313 deletions(-) create mode 100644 backends/vulkan/test/op_tests/quantize_affine_test.cpp diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index 805a5c1f744..aa4829d9c90 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -210,6 +210,278 @@ def fuse_into_linear_qcnw_node( graph_module.graph.erase_node(dq_weight_node) +######################### +## linear_qta8a_qga4w ## +######################### + + +def _is_dequantize_affine_node(node: torch.fx.Node) -> bool: + """Check if a node is a dequantize_affine operation.""" + return ( + node.op == "call_function" + and node.target is not None + and hasattr(node.target, "__name__") + and "dequantize_affine" in getattr(node.target, "__name__", "") + ) + + +def _is_view_copy_node(node: torch.fx.Node) -> bool: + """Check if a node is a view_copy operation.""" + return ( + node.op == "call_function" + and node.target is not None + and hasattr(node.target, "__name__") + and "view_copy" in getattr(node.target, "__name__", "") + ) + + +def _validate_qta8a_qga4w_nodes( + input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument +) -> Optional[torch.fx.Node]: + """ + Validate input and weight nodes for QTA8A_QGA4W pattern. + Returns the actual input node (after handling view operations) or None if invalid. 
+ """ + # Type checking - ensure we have torch.fx.Node objects + if not isinstance(weight_node, torch.fx.Node) or not isinstance( + input_node, torch.fx.Node + ): + return None + + # Input may be preprocessed with a view node + actual_input_node = input_node + if _is_view_copy_node(input_node): + actual_input_node = input_node.args[0] + if not isinstance(actual_input_node, torch.fx.Node): + return None + + # Check if input is dequantized with dequantize_affine (from dynamic quantization) + if not _is_dequantize_affine_node(actual_input_node): + return None + + # Check if weight is dequantized with dequantize_affine + if not _is_dequantize_affine_node(weight_node): + return None + + return actual_input_node + + +def _extract_weight_params( + program: ExportedProgram, weight_node: torch.fx.Node +) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]: + """Extract and validate weight parameters from dequantize_affine node.""" + # Get the original quantized weight and quantization parameters + if len(weight_node.args) < 4: + return None + + orig_weight = weight_node.args[0] + weight_scales = weight_node.args[2] + weight_zeros = weight_node.args[3] + + # Type checking + if not isinstance(orig_weight, torch.fx.Node) or not is_param_node( + program, orig_weight + ): + return None + if not isinstance(weight_scales, torch.fx.Node) or not is_param_node( + program, weight_scales + ): + return None + if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node( + program, weight_zeros + ): + return None + + return orig_weight, weight_scales, weight_zeros + + +def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool: + """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range).""" + quant_min = weight_tensor.min().item() + quant_max = weight_tensor.max().item() + return quant_min >= -8 and quant_max <= 7 + + +def _calculate_group_size( + orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor +) -> Optional[int]: + """Calculate and validate group size from weight and scales tensors.""" + out_features, in_features = orig_weight_tensor.shape + + if len(weight_scales_tensor.shape) != 2: + return None + + scales_out_features, num_groups = weight_scales_tensor.shape + + if scales_out_features != out_features: + return None + + group_size = in_features // num_groups + if in_features % group_size != 0: + return None + + return group_size + + +def matches_linear_qta8a_qga4w_pattern( + program: ExportedProgram, node: torch.fx.Node +) -> Optional[Tuple[int, int]]: + """ + Checks if the nodes surrounding a linear node matches the pattern for dynamic + activation + grouped weight quantized linear (QTA8A_QGA4W). + + This pattern involves: + 1. Dynamic quantization of input activations (8-bit) + 2. Grouped quantization of weights (4-bit with group size) + + The expected pattern from Int8DynActInt4WeightQuantizer is: + scale, zero_point = choose_qparams_affine(input) + quantized_input = quantize_affine(input, scale, zero_point) + dequantized_input = dequantize_affine(quantized_input, ...) + dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros) + output = linear(dequantized_input, dequantized_weight) + + If the pattern matches, return (group_size, weight_bits), otherwise None. 
+ """ + if not utils.is_linear_node(node): + return None + + input_node = node.args[0] + weight_node = node.args[1] + + # Validate nodes and get actual input node + actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node) + if actual_input_node is None: + return None + + # Extract weight parameters + if not isinstance(weight_node, torch.fx.Node): + return None + weight_params = _extract_weight_params(program, weight_node) + if weight_params is None: + return None + + orig_weight, weight_scales, weight_zeros = weight_params + + # Get tensors to analyze the quantization scheme + orig_weight_tensor = get_param_tensor(program, orig_weight) + weight_scales_tensor = get_param_tensor(program, weight_scales) + weight_zeros_tensor = get_param_tensor(program, weight_zeros) + + if not isinstance(orig_weight_tensor, torch.Tensor): + return None + if not isinstance(weight_scales_tensor, torch.Tensor): + return None + if not isinstance(weight_zeros_tensor, torch.Tensor): + return None + + # Check if weight is quantized to 4 bits + if not _validate_4bit_quantization(orig_weight_tensor): + return None + + # Calculate group size + group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor) + if group_size is None: + return None + + # Verify this is 4-bit grouped quantization + weight_bits = 4 + + return group_size, weight_bits + + +def fuse_into_linear_qta8a_qga4w_node( + program: ExportedProgram, + graph_module: torch.fx.GraphModule, + linear_node: torch.fx.Node, + group_size: int, + weight_bits: int, +) -> None: + """ + Fuse the dynamic activation + grouped weight quantized linear pattern into + a single linear_qta8a_qga4w operator. + + The pattern: + dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...) + dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...) 
+ output = linear(dequantized_input, dequantized_weight) + + Becomes: + output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point, + weight, group_size, weight_scales, weight_zeros) + """ + dq_input_node = linear_node.args[0] + dq_weight_node = linear_node.args[1] + + assert isinstance(dq_input_node, torch.fx.Node) + + input_view_node = None + # Input may be preprocessed with a view node + if ( + dq_input_node.op == "call_function" + and dq_input_node.target is not None + and hasattr(dq_input_node.target, "__name__") + and "view_copy" in getattr(dq_input_node.target, "__name__", "") + ): + input_view_node = dq_input_node + dq_input_node = dq_input_node.args[0] + assert isinstance(dq_input_node, torch.fx.Node) + + assert isinstance(dq_input_node, torch.fx.Node) + assert isinstance(dq_weight_node, torch.fx.Node) + + # Get the quantized input and quantization parameters from the input dequantize_affine node + # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype) + quantized_input = dq_input_node.args[0] + input_scale = dq_input_node.args[2] # scale is the 3rd argument + input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None + + # Get the weight and its quantization parameters from dequantize_affine + # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype) + orig_weight = dq_weight_node.args[0] + weight_scales = dq_weight_node.args[2] + weight_zeros = dq_weight_node.args[3] + + # Pack the 4-bit weight tensor for efficient storage + assert isinstance(orig_weight, torch.fx.Node) + orig_weight_tensor = get_param_tensor(program, orig_weight) + assert isinstance(orig_weight_tensor, torch.Tensor) + packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor) + utils.update_program_state_dict( + program, + orig_weight.name, + packed_weight_tensor, + ) + # Update the metadata to reflect the new packed shape + orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8) + + # Create the linear_qta8a_qga4w node + with graph_module.graph.inserting_before(linear_node): + linear_qta8a_qga4w_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.linear_qta8a_qga4w.default, + ( + quantized_input, # quantized input (int8) + input_scale, # mat1_scale + input_zero_point, # mat1_zero_point + orig_weight, # mat2_data (packed 4-bit weights) + group_size, # group_size (int) + weight_scales, # weight_scales + weight_zeros, # weight_zeros + ), + ) + + # Replace the linear node with the new fused node + linear_node.replace_all_uses_with(linear_qta8a_qga4w_node) + + # Erase nodes in the correct order (users first, then dependencies) + graph_module.graph.erase_node(linear_node) + if input_view_node is not None: + graph_module.graph.erase_node(input_view_node) + graph_module.graph.erase_node(dq_weight_node) + graph_module.graph.erase_node(dq_input_node) + + class FuseQuantizedOpsTransform(ExportPass): def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() @@ -217,12 +489,23 @@ def __init__(self, exported_program: ExportedProgram) -> None: def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in graph_module.graph.nodes: + # Check for linear_qcnw pattern (weight-only quantization) qcnw_details = matches_linear_qcnw_pattern(self.program, node) if qcnw_details is not None: qcnw_method, qcnw_nbits = qcnw_details fuse_into_linear_qcnw_node( self.program, graph_module, node, qcnw_method, qcnw_nbits ) + 
continue + + # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization) + qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node) + if qta8a_qga4w_details is not None: + group_size, weight_bits = qta8a_qga4w_details + fuse_into_linear_qta8a_qga4w_node( + self.program, graph_module, node, group_size, weight_bits + ) + continue graph_module.recompile() dead_code_elimination_pass(graph_module) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index af6fcbfbb14..c9b884e5b86 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -231,6 +231,95 @@ def linear_qcs4w( lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd") linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) +######################## +## linear_qta8a_qga4w ## +######################## + + +def linear_qta8a_qga4w( + x_quantized: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weights_4bit: torch.Tensor, + group_size: int, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, +): + """ + Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W). + + Args: + x_quantized: Already quantized input tensor (int8, per-token quantized) + input_scale: Scale for per-token quantization of input (shape: [batch_size]) + input_zero_point: Zero point for per-token quantization of input (shape: [batch_size]) + weights_4bit: Packed 4-bit quantized weights + group_size: Group size for weight quantization (int) + weight_scales: Per-group scales for weights + weight_zeros: Per-group zero points for weights + """ + original_x_shape = x_quantized.shape + feature_dim = original_x_shape[-1] + + # Reshape for processing + x_quantized_2d = x_quantized.reshape(-1, feature_dim) + + # Unpack 4-bit weights + unpacked_weights_shape = weights_4bit.shape + out_features = unpacked_weights_shape[0] + in_features = unpacked_weights_shape[1] + + weights_unpacked = torch.empty( + (out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device + ) + + weights_unpacked[:, ::2] = weights_4bit >> 4 + weights_unpacked[:, 1::2] = weights_4bit & 0x0F + + # Convert to signed 4-bit range [-8, 7] + weights_unpacked = torch.where( + weights_unpacked > 7, weights_unpacked - 16, weights_unpacked + ) + + # Dequantize weights using grouped quantization + actual_in_features = in_features * 2 + num_groups = actual_in_features // group_size + + # Reshape weights for grouped dequantization + weights_grouped = weights_unpacked.view(out_features, num_groups, group_size) + + # Expand scales and zeros to match grouped weights + scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size) + zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size) + + # Dequantize: (quantized - zero_point) * scale + dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded + dq_weights = dq_weights_grouped.view(out_features, actual_in_features) + + # Dequantize input (per-token) + # For per-token quantization, each token (row) has its own scale and zero_point + x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token( + x_quantized_2d, + input_scale, + input_zero_point, + -128, + 127, + torch.int8, + torch.float32, + ) + + # Perform linear operation + out = torch.nn.functional.linear(x_dequantized, dq_weights) + out_shape = original_x_shape[:-1] + (out_features,) + return out.reshape(out_shape) + + +name = "linear_qta8a_qga4w" +lib.define( + f"{name}(Tensor self, 
Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor" +) +lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd") +linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 19594002cf2..178cc9ea08b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -487,7 +487,12 @@ def register_int8_mm_op(features: OpFeatures): return features -@update_features(exir_ops.edge.et_vk.linear_weight_int4.default) +@update_features( + [ + exir_ops.edge.et_vk.linear_weight_int4.default, + exir_ops.edge.et_vk.linear_qta8a_qga4w.default, + ] +) def register_int4_mm_op(features: OpFeatures): features.buffer_impl = True features.texture_impl = TextureImplFeatures( diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl index 94072dfbfea..43e62eadeee 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -53,6 +53,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -71,68 +82,60 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); /* - * DEQUANTIZATION SHADER (BUFFER STORAGE) - * - * This shader converts n-bit integer tensor values back to floating-point representations - * using pre-computed quantization parameters (scale and zero_point). The dequantization - * reconstructs the original floating-point values from their discrete integer representations - * with minimal precision loss. - * - * ALGORITHM: - * 1. Load quantized integer value from buffer - * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale - * 3. 
Store reconstructed floating-point value to output buffer - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access - * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering - * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping - * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) - * - * DEQUANTIZATION FORMULA VISUALIZATION: - * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: - * - * Integer Domain: Floating Point Domain: - * quant_min ──────────────► min_val - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * quant_max ──────────────► max_val - * - * Dequantization Process: - * Input: -103 (int8) - * Step 1: qvalue - zero_point = -103 - (-128) = 25 - * Step 2: result * scale = 25 * 0.1 = 2.5 - * Output: 2.5 (float) - * - * PER-TENSOR DEQUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All elements use same dequantization parameters - * - Parameters passed as push constants for efficiency - * - Formula: value = (qvalue - zero_point) * scale - * - * PER-TOKEN DEQUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates its token_id from tensor coordinates - * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id] - * - * Token ID calculation for element at tensor index (w, z, y, x): - * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y - * - 3D tensor: token_id = z * sizes.y + y - * - 2D tensor: token_id = y - * - 1D tensor: token_id = 0 - */ + Dequantization Shader (Buffer Storage) + This shader converts n-bit integer tensor values back to floating-point representations + using pre-computed quantization parameters (scale and zero_point). The dequantization + reconstructs the original floating-point values from their discrete integer representations + with minimal precision loss. + + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - dequantize_per_tensor + This mode reverses the uniform quantization applied across the entire tensor by using the + single scale and zero_point values to convert quantized integer values back to their original + floating-point representation. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_per_token + This mode reverses the quantization applied individually to each token (or element) in the + input by using separate scale and zero_point values for each token. 
For a tensor of shape + [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting + quantized values back to their original floating-point representation for each group of H + elements independently. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_per_channel + This mode reverses the quantization applied separately to each channel of the input tensor + by using distinct scale and zero_point values for each channel. For a tensor of shape + [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C + channels, converting quantized values back to their original floating-point representation + independently for each channel. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_block_wise + This mode reverses the block-wise quantization applied to groups of elements by using separate + scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the + inverse affine transformation per block to convert quantized values back to their original + floating-point representation. For example, if the tensor shape is [6, 9, 4] and + blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements, + and dequantization is performed independently on each block. + + (*) global_wg_size: default + (*) local_wg_size: default + + Dequantization Formula: + value = (qvalue - zero_point) * scale +*/ #ifdef per_tensor @@ -187,7 +190,7 @@ void dequantize_per_token() { t_out[out_bufi] = value; } -#else // per_channel +#elif defined(per_channel) void dequantize_per_channel() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -226,6 +229,29 @@ void dequantize_per_channel() { t_out[out_bufi] = value; } +#else // block_wise + +void dequantize_block_wise() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + const ivec4 bcoord = out_tidx / blockSize; + + const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + + t_out[out_bufi] = value; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index b9a53217452..999c59d3b79 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -19,3 +19,5 @@ dequantize_buffer: MODE: per_token - NAME: dequantize_per_channel_buffer MODE: per_channel + - NAME: dequantize_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index 5c978c61846..20bf6c87e26 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -56,6 +56,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, 
bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} @@ -201,7 +212,7 @@ void dequantize_per_token() { write_texel(t_out, pos, outtex); } -#else // per_channel +#elif defined(per_channel) void dequantize_per_channel() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -292,6 +303,39 @@ void dequantize_per_channel() { write_texel(t_out, pos, outtex); } +#else // block_wise + +void dequantize_block_wise() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) + return; + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); + int foldedZ = pos.z; + + int C_total = numBlocks.z * blockSize.z; + + [[unroll]] for (int i = 0; i < 4; ++i) { + ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); + + ivec4 bcoord = tidx / blockSize; + int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index 88ccc6e3274..9b624762192 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -19,3 +19,5 @@ dequantize_texture: MODE: per_token - NAME: dequantize_per_channel_texture3d MODE: per_channel + - NAME: dequantize_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl index 9834a539667..9a342d8e057 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -53,6 +53,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -71,64 +82,54 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); /* - * QUANTIZATION SHADER (BUFFER STORAGE) - * - * This shader converts floating-point tensor values to n-bit integer representations - * using pre-computed quantization parameters (scale and zero_point). The quantization - * maps floating-point values to a discrete integer range while preserving the - * original data distribution as much as possible. - * - * ALGORITHM: - * 1. Load floating-point input value from buffer - * 2. Apply quantization formula: qvalue = round(value / scale) + zero_point - * 3. 
Clamp result to [quant_min, quant_max] range - * 4. Store quantized integer value to output buffer - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Per-Tensor Config: Uses linear buffer indexing with stride-based tensor access - * - and supports any tensor layout through stride calculations and dimension ordering - * - Per-Token Config: Assumes width-packed layout (packed_dim = 0) - * - since that is how token index is calculated - * - * QUANTIZATION FORMULA VISUALIZATION: - * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: - * - * Floating Point Domain: Integer Domain: - * min_val ────────────────► quant_min - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * max_val ────────────────► quant_max - * - * Quantization Process: - * Input: 2.5 (float) - * Step 1: value / scale = 2.5 / 0.1 = 25.0 - * Step 2: round(25.0) + zero_point = 25 + (-128) = -103 - * Step 3: clamp(-103, -128, 127) = -103 - * Output: -103 (int8) - * - * PER-TENSOR QUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All elements use same quantization parameters - * - Parameters passed as push constants for efficiency - * - Formula: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max) - * - * PER-TOKEN QUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates its token_id from tensor coordinates - * - Formula: qvalue = clamp(round(value / scale[token_id]) + zero_point[token_id], quant_min, quant_max) - */ + Quantization Shader (Buffer Storage) + This shader converts floating-point tensor values to n-bit integer representations + using pre-computed quantization parameters (scale and zero_point). The quantization + maps floating-point values to a discrete integer range while preserving the original + data distribution as much as possible. + + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - quantize_per_tensor + This mode applies uniform quantization across the entire tensor using a single scale + and zero_point value. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_token + This mode applies quantization individually to each token (or element) in the input, + using separate scale and zero_point values for each token. For instance if we have + a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_channel + This mode applies quantization separately to each channel of the input tensor, using + distinct scale and zero_point values for each channel. 
For example, if the tensor shape + is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing + each channel to be quantized independently. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_block_wise + This mode applies quantization in blocks or groups of elements, allowing different scale + and zero_point values for each block. It is equivalent to quantize_affine, where quantization + parameters are affine transformations applied per block. For example, if the tensor shape + is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. + + (*) global_wg_size: default + (*) local_wg_size: default + + Quantization Formula: + qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). +*/ #ifdef per_tensor @@ -183,7 +184,7 @@ void quantize_per_token() { t_out[out_bufi] = qvalue; } -#else // per_channel +#elif defined(per_channel) void quantize_per_channel() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -222,6 +223,29 @@ void quantize_per_channel() { t_out[out_bufi] = qvalue; } +#else // block_wise + +void quantize_block_wise() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + const ivec4 bcoord = out_tidx / blockSize; + + const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + const OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]); + + t_out[out_bufi] = qvalue; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 1dd8e6e2ffe..5b479c2f90f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -19,3 +19,5 @@ quantize_buffer: MODE: per_token - NAME: quantize_per_channel_buffer MODE: per_channel + - NAME: quantize_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl index 148fa85eb2b..69f219ef329 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -58,6 +58,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict BlockPC { + ivec4 blockSize; // WHCN + ivec4 numBlocks; // (#W,#H,#C,#N) + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} @@ -70,68 +81,58 @@ ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* - * QUANTIZATION SHADER (TEXTURE STORAGE) - * - * This shader converts floating-point tensor values to n-bit integer representations - * using pre-computed quantization parameters (scale and zero_point). 
The quantization - * maps floating-point values to a discrete integer range while preserving the - * original data distribution as much as possible. - * - * ALGORITHM: - * 1. Load floating-point texel (4 values) from 3D texture - * 2. Apply quantization formula to each component: qvalue = round(value / scale) + zero_point - * 3. Clamp each result to [quant_min, quant_max] range - * 4. Store quantized integer texel to output texture - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with texel-based processing - * - Assumes width-packed layout (packed_dim = 0) in current implementation - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - For per-token mode: scale/zero_point tensors must use buffer storage - * - * QUANTIZATION FORMULA VISUALIZATION: - * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: - * - * Floating Point Domain: Integer Domain: - * min_val ────────────────► quant_min - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * max_val ────────────────► quant_max - * - * Texel Quantization Process: - * Input Texel: [2.5, -1.0, 0.5, 3.2] (float4) - * Per-component quantization with scale=0.1, zero_point=-128: - * Component 0: round(2.5 / 0.1) + (-128) = 25 + (-128) = -103 - * Component 1: round(-1.0 / 0.1) + (-128) = -10 + (-128) = -138 → clamp to -128 - * Component 2: round(0.5 / 0.1) + (-128) = 5 + (-128) = -123 - * Component 3: round(3.2 / 0.1) + (-128) = 32 + (-128) = -96 - * Output Texel: [-103, -128, -123, -96] (int4) - * - * PER-TENSOR QUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All texel components use same quantization parameters - * - Parameters passed as push constants for efficiency - * - Each thread processes one texel (4 elements) independently - * - Formula: qvalue[i] = clamp(round(value[i] / scale) + zero_point, quant_min, quant_max) - * - * PER-TOKEN QUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates token_id from its 3D texture position - * - Scale/zero_point buffers accessed directly (not as textures) - * - Formula: qvalue[i] = clamp(round(value[i] / scale[token_id]) + zero_point[token_id], quant_min, quant_max) - */ + Quantization Shader (Texture Storage) + This shader converts floating-point tensor values to n-bit integer representations + using pre-computed quantization parameters (scale and zero_point). The quantization + maps floating-point values to a discrete integer range while preserving the original + data distribution as much as possible. 
+ + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - quantize_per_tensor + This mode applies uniform quantization across the entire tensor using a single scale + and zero_point value. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_token + This mode applies quantization individually to each token (or element) in the input, + using separate scale and zero_point values for each token. For instance if we have + a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_channel + This mode applies quantization separately to each channel of the input tensor, using + distinct scale and zero_point values for each channel. For example, if the tensor shape + is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing + each channel to be quantized independently. + + (*) global_wg_size: default + (*) local_wg_size: Default with special handling for batch dimension. When quantizing along + the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, + uses standard workgroup size derived from global workgroup dimensions. + + - quantize_block_wise + This mode applies quantization in blocks or groups of elements, allowing different scale + and zero_point values for each block. It is equivalent to quantize_affine, where quantization + parameters are affine transformations applied per block. For example, if the tensor shape + is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. + + (*) global_wg_size: default + (*) local_wg_size: Default with special handling for batch dimension. When quantizing along + the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, + uses standard workgroup size derived from global workgroup dimensions. + + Quantization Formula: + qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). 
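  The block-wise indexing described above can be sketched host-side as follows. This is a
  minimal Python sketch of the mapping, not the shader itself; the function name and the
  WHCN-ordered tuples are illustrative stand-ins for the blockSize / numBlocks / blockStride
  push constants.

```python
def quantize_block_wise_ref(value, tidx, block_size, num_blocks,
                            scale, zero_point, quant_min, quant_max):
    # blockStride = {1, #W, #W*#H, #W*#H*#C}: linear strides over the block grid
    block_stride = [1]
    for n in num_blocks[:-1]:
        block_stride.append(block_stride[-1] * n)
    # bcoord = tidx / blockSize (element-wise); block_id = dot(bcoord, blockStride)
    bcoord = [t // b for t, b in zip(tidx, block_size)]
    block_id = sum(c * s for c, s in zip(bcoord, block_stride))
    # qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)
    q = round(value / scale[block_id]) + zero_point[block_id]
    return max(quant_min, min(quant_max, q))
```

  For the [6, 9, 4] example with blockSize = [3, 3, 2] there are 12 blocks of 18 elements,
  so scale and zero_point each hold 12 entries and every element of a block shares one pair.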
+*/ #ifdef per_tensor @@ -192,7 +193,7 @@ void quantize_per_token() { write_texel(t_out, pos, outtex); } -#else // per_channel +#elif defined(per_channel) void quantize_per_channel() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -270,6 +271,36 @@ void quantize_per_channel() { write_texel(t_out, pos, outtex); } +#else // block_wise + +void quantize_block_wise() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) + return; + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); + int foldedZ = pos.z; + + int C_total = numBlocks.z * blockSize.z; + + [[unroll]] for (int i = 0; i < 4; ++i) { + ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); + + ivec4 bcoord = tidx / blockSize; + int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]); + outtex[i] = qvalue; + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 47e532be8b9..2e40ac90794 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -19,3 +19,5 @@ quantize_texture: MODE: per_token - NAME: quantize_per_channel_texture3d MODE: per_channel + - NAME: quantize_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 7edb9b2f70d..61fd76145a4 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -17,38 +17,59 @@ namespace vkcompute { -void resize_dequantize_output( +void resize_dequantize_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { (void)extra_args; - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(in)); + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + out->virtual_resize(in->sizes()); } -utils::uvec3 dequantize_per_channel_global_wg_size( +utils::uvec3 dequantize_per_channel_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { + (void)args; (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + const ValueRef input = args.at(1).refs.at(0); - return global_wg_size; + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. We need to ensure that we dispatch the correct + // number of workgroups in the Z dimension to cover all batch-channel + // combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. 
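As a plain-Python illustration of the dispatch arithmetic this workaround targets (div_up stands in for the ceiling division performed by CommandBuffer::dispatch; the sizes are made up):

```python
def div_up(a, b):
    return (a + b - 1) // b

global_wg = (8, 8, 6)  # Z covers 6 batch-channel combinations
for local_z in (4, 1):
    print(local_z, div_up(global_wg[2], local_z))  # 4 -> 2 workgroups, 1 -> 6 workgroups
# Forcing local_wg_size[2] = 1 dispatches exactly global_wg[2] workgroups in Z,
# which is what the comment above requires.
```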
+ const auto input_sizes = graph->sizes_of(input); + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { + local_wg_size[2] = 1; + } + + return local_wg_size; } -utils::uvec3 dequantize_per_channel_local_wg_size( +utils::uvec3 dequantize_block_wise_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { - (void)args; + (void)shader; (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); utils::uvec3 local_wg_size = @@ -56,16 +77,17 @@ utils::uvec3 dequantize_per_channel_local_wg_size( // WORKAROUND: The CommandBuffer::dispatch function divides // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel dequantization along the batch - // axis, we need to ensure that we dispatch the correct number of workgroups - // in the Z dimension to cover all batch-channel combinations. + // workgroups to dispatch. We need to ensure that we dispatch the correct + // number of workgroups in the Z dimension to cover all batch-channel + // combinations. // // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], // local_wg_size[2]) might reduce the number of workgroups dispatched. To // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, // we set local_wg_size[2] = 1. const auto input_sizes = graph->sizes_of(input); - if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { local_wg_size[2] = 1; } @@ -131,7 +153,7 @@ void add_dequantize_per_tensor_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); } void add_dequantize_per_token_node( @@ -161,25 +183,18 @@ void add_dequantize_per_token_node( graph.sizes_ubo(input), graph.strides_ubo(input), graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.strides_ubo(output)}; } else { param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; } + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + vkapi::SpecVarList spec_vars = { graph.hashed_layout_of(output), graph.hashed_layout_of(input), @@ -203,7 +218,7 @@ void add_dequantize_per_token_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); } void add_dequantize_per_channel_node( @@ -252,27 +267,19 @@ void add_dequantize_per_channel_node( graph.sizes_ubo(input), graph.strides_ubo(input), graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.strides_ubo(output)}; } else { param_ubos = { - 
graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; } + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + vkapi::SpecVarList spec_vars = { graph.hashed_layout_of(output), graph.hashed_layout_of(input), @@ -281,7 +288,7 @@ void add_dequantize_per_channel_node( graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - dequantize_per_channel_global_wg_size, + default_pick_global_wg_size, dequantize_per_channel_local_wg_size, // Inputs and Outputs {{output, vkapi::kWrite}, @@ -296,7 +303,94 @@ void add_dequantize_per_channel_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); +} + +void add_dequantize_block_wise_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& block_size, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // Convert dimensions to WHCN order for shader + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) + utils::ivec4 num_blocks_vec = { + tensor_size_whcn[0] / block_size_vec[0], + tensor_size_whcn[1] / block_size_vec[1], + tensor_size_whcn[2] / block_size_vec[2], + tensor_size_whcn[3] / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + 
default_pick_global_wg_size, + dequantize_block_wise_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_node)); } void dequantize_per_tensor_impl( @@ -308,31 +402,39 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is an integer type VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } add_dequantize_per_tensor_node( graph, input, scale, zero_point, quant_min, quant_max, output); @@ -347,12 +449,11 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; @@ -366,15 +467,8 @@ void dequantize_per_token_impl( VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); - // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -430,12 +524,11 @@ void dequantize_per_channel_impl( 
const ValueRef axis = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; @@ -449,15 +542,8 @@ void dequantize_per_channel_impl( VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); - // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -513,8 +599,7 @@ void dequantize_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef block_size = args[arg_idx++]; const ValueRef scale = args[arg_idx++]; const ValueRef zero_point = args[arg_idx++]; const ValueRef input_dtype = args[arg_idx++]; @@ -529,33 +614,61 @@ void dequantize_affine_impl( // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is an integer type VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization + // Verify block_size is valid (each dimension must divide evenly into input + // size) const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); + for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); + if ((*block_size_list)[i] > 1) { + VK_CHECK_COND( + input_sizes[i] % (*block_size_list)[i] == 0, + "Input size at dimension ", + i, + " (", + input_sizes[i], + ") must be divisible by block_size at dimension ", + i, + " (", + 
(*block_size_list)[i], + ")"); + } } - // Default to per-tensor dequantization for TorchAO affine ops - add_dequantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); + add_dequantize_block_wise_node( + graph, + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + output); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index d786981e1fc..92719505a0f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -17,40 +17,60 @@ namespace vkcompute { -void resize_quantize_output( +void resize_quantize_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { (void)extra_args; - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(in)); + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + out->virtual_resize(in->sizes()); } -utils::uvec3 quantize_per_channel_global_wg_size( +utils::uvec3 quantize_per_channel_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { (void)shader; + (void)args; (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + const ValueRef input = args.at(1).refs.at(0); + + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. For per-channel quantization along the batch axis, + // we need to ensure that we dispatch the correct number of workgroups in the + // Z dimension to cover all batch-channel combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. + const auto input_sizes = graph->sizes_of(input); + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { + local_wg_size[2] = 1; + } - return global_wg_size; + return local_wg_size; } -utils::uvec3 quantize_per_channel_local_wg_size( +utils::uvec3 quantize_block_wise_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { (void)shader; - (void)args; (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); utils::uvec3 local_wg_size = @@ -67,7 +87,8 @@ utils::uvec3 quantize_per_channel_local_wg_size( // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, // we set local_wg_size[2] = 1. 
const auto input_sizes = graph->sizes_of(input); - if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { local_wg_size[2] = 1; } @@ -133,7 +154,7 @@ void add_quantize_per_tensor_node( // Resize Args {}, // Resizing Logic - resize_quantize_output)); + resize_quantize_node)); } void add_quantize_per_token_node( @@ -205,7 +226,7 @@ void add_quantize_per_token_node( // Resize Args {}, // Resizing Logic - resize_quantize_output)); + resize_quantize_node)); } void add_quantize_per_channel_node( @@ -283,7 +304,7 @@ void add_quantize_per_channel_node( graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - quantize_per_channel_global_wg_size, + default_pick_global_wg_size, quantize_per_channel_local_wg_size, // Inputs and Outputs {{output, vkapi::kWrite}, @@ -298,7 +319,94 @@ void add_quantize_per_channel_node( // Resize Args {}, // Resizing Logic - resize_quantize_output)); + resize_quantize_node)); +} + +void add_quantize_block_wise_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& block_size, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // Convert PyTorch dimensions to WHCN order for shader + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) + utils::ivec4 num_blocks_vec = { + tensor_size_whcn[0] / block_size_vec[0], + tensor_size_whcn[1] / block_size_vec[1], + tensor_size_whcn[2] / block_size_vec[2], + tensor_size_whcn[3] / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + quantize_block_wise_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + 
{input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_node)); } void quantize_per_tensor_impl( @@ -310,7 +418,7 @@ void quantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warning - dtype is inferred from output @@ -339,7 +447,7 @@ void quantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warning - dtype is inferred from output @@ -412,7 +520,7 @@ void quantize_per_channel_impl( const ValueRef axis = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warning - dtype is inferred from output @@ -485,8 +593,7 @@ void quantize_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef block_size = args[arg_idx++]; const ValueRef scale = args[arg_idx++]; const ValueRef zero_point = args[arg_idx++]; const ValueRef output_dtype = args[arg_idx++]; @@ -499,6 +606,8 @@ void quantize_affine_impl( // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is a floating point type @@ -507,18 +616,51 @@ void quantize_affine_impl( graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Verify block_size is valid (each dimension must divide evenly into input + // size) const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); + for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); + if ((*block_size_list)[i] > 1) { + VK_CHECK_COND( + input_sizes[i] % (*block_size_list)[i] == 0, + "Input size at dimension ", + i, + " (", 
+ input_sizes[i], + ") must be divisible by block_size at dimension ", + i, + " (", + (*block_size_list)[i], + ")"); + } } - // Default to per-tensor quantization for TorchAO affine ops - add_quantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); + add_quantize_block_wise_node( + graph, + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + output); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp index a47c58b7ef6..728d38c3e2d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp @@ -68,11 +68,10 @@ void check_linear_qta8a_qga4w_args( const auto mat1_scale_sizes = graph.sizes_of(mat1_scale); const auto mat1_zero_point_sizes = graph.sizes_of(mat1_zero_point); - VK_CHECK_COND(mat1_scale_sizes.size() == 1); - VK_CHECK_COND(mat1_zero_point_sizes.size() == 1); - - VK_CHECK_COND(mat1_scale_sizes[0] == input_num_tokens); - VK_CHECK_COND(mat1_zero_point_sizes[0] == input_num_tokens); + VK_CHECK_COND( + utils::val_at(-1, mat1_scale_sizes) == input_num_tokens); + VK_CHECK_COND( + utils::val_at(-1, mat1_zero_point_sizes) == input_num_tokens); // Verify weight scales and zeros have the same shape const auto weight_scales_sizes = graph.sizes_of(weight_scales); diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index 7f535a0001b..ef429ff21fa 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -35,6 +35,7 @@ python_unittest( "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/backends/vulkan/quantizer:vulkan_quantizer", "//executorch/backends/vulkan:vulkan_preprocess", + "//pytorch/ao:torchao", # @manual ] ) diff --git a/backends/vulkan/test/op_tests/quantize_affine_test.cpp b/backends/vulkan/test/op_tests/quantize_affine_test.cpp new file mode 100644 index 00000000000..8a54774d703 --- /dev/null +++ b/backends/vulkan/test/op_tests/quantize_affine_test.cpp @@ -0,0 +1,859 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include "test_utils.h" + +#include +#include +#include + +static inline void +_check_dims(c10::string_view name, int64_t expected, int64_t actual) { + VK_CHECK_COND( + expected == actual, + name, + " has rank ", + actual, + " but block_size has length ", + expected); +} + +at::Tensor quantize_affine_reference_impl( + const at::Tensor& input_, + const std::vector& block_size, + const at::Tensor& scale, + const c10::optional& zero_point_opt, + int64_t quant_min, + int64_t quant_max, + at::ScalarType out_dtype, + c10::optional zero_point_domain_opt = std::string("INT")) { + constexpr float kEps = 1e-7f; + + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kFloat || input_.scalar_type() == at::kHalf || + input_.scalar_type() == at::kBFloat16, + "Unsupported input dtype: ", + input_.dtype()); + + auto zero_point_domain = + zero_point_domain_opt.has_value() ? 
*zero_point_domain_opt : "INT"; + + bool has_zp = zero_point_opt.has_value(); + VK_CHECK_COND( + has_zp || zero_point_domain == "NONE" || zero_point_domain == "", + "zero_point must be supplied unless zero_point_domain is NONE or null"); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor scale_b = + scale.view(shape_after_reduction).to(input_reshaped.scalar_type()); + + at::Tensor zp_b; + if (has_zp) { + zp_b = (*zero_point_opt).view(shape_after_reduction).toType(at::kFloat); + } + + scale_b = scale_b.clamp_min(kEps); + at::Tensor inv_scale = 1.0f / scale_b; + + at::Tensor q; + if (zero_point_domain == "INT") { + VK_CHECK_COND(has_zp, "INT zero_point_domain requires zero_point tensor"); + q = at::round(input_reshaped * inv_scale) + zp_b; + } else if (zero_point_domain == "NONE" || zero_point_domain.empty()) { + VK_CHECK_COND( + !has_zp, "zero_point must be None when domain is NONE / null"); + q = at::round(input_reshaped * inv_scale); + } else { + VK_CHECK_COND( + has_zp && zero_point_domain == "FLOAT", + "zero_point_domain must be INT, FLOAT, NONE or null"); + const float mid_point = (quant_max + quant_min + 1) * 0.5f; + at::Tensor min_val = zp_b - scale_b * mid_point; + q = at::round((input_reshaped - min_val) / scale_b); + } + + q = at::clamp(q, (double)quant_min, (double)quant_max); + + q = q.view(in_sizes).to(out_dtype); + + return q; +} + +at::Tensor dequantize_affine_reference_impl( + const at::Tensor& input_, + const std::vector& block_size, + const at::Tensor& scale, + const c10::optional& zero_point_opt, + int64_t quant_min, + int64_t quant_max, + at::ScalarType out_dtype, + c10::optional zero_point_domain_opt = std::string("INT")) { + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kByte || input_.scalar_type() == at::kChar || + input_.scalar_type() == at::kShort || + input_.scalar_type() == at::kInt, + "Unsupported input dtype: ", + input_.dtype()); + + VK_CHECK_COND( + out_dtype == at::kFloat || out_dtype == at::kHalf || + out_dtype == at::kBFloat16, + "Unsupported output dtype: ", + out_dtype); + + auto zero_point_domain = + zero_point_domain_opt.has_value() ? 
*zero_point_domain_opt : "INT"; + + bool has_zp = zero_point_opt.has_value(); + VK_CHECK_COND( + has_zp || zero_point_domain == "NONE" || zero_point_domain == "", + "zero_point must be supplied unless zero_point_domain is NONE or null"); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor scale_b = scale.view(shape_after_reduction).to(out_dtype); + + at::Tensor zp_b; + if (has_zp) { + zp_b = (*zero_point_opt).view(shape_after_reduction).to(out_dtype); + } + + at::Tensor input_fp = input_reshaped.to(out_dtype); + at::Tensor dq; + + if (zero_point_domain == "INT") { + VK_CHECK_COND(has_zp, "INT zero_point_domain requires zero_point tensor"); + dq = (input_fp - zp_b) * scale_b; + } else if (zero_point_domain == "NONE" || zero_point_domain.empty()) { + VK_CHECK_COND( + !has_zp, "zero_point must be None when domain is NONE / null"); + dq = input_fp * scale_b; + } else { + VK_CHECK_COND( + has_zp && zero_point_domain == "FLOAT", + "zero_point_domain must be INT, FLOAT, NONE or null"); + const float mid_point = (quant_max + quant_min + 1) * 0.5f; + at::Tensor min_val = zp_b - scale_b * mid_point; + dq = input_fp * scale_b + min_val; + } + + dq = dq.view(in_sizes); + + return dq; +} + +// Wrapper function to maintain compatibility with existing test code (above is +// a good reference for how the python implementation works) +at::Tensor quantize_affine_reference_impl( + const at::Tensor& input, + const std::vector& block_size, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + return quantize_affine_reference_impl( + input, + block_size, + scale, + c10::optional(zero_point), + quant_min, + quant_max, + dtype, + std::string("INT")); +} + +// Wrapper function for dequantize_affine +at::Tensor dequantize_affine_reference_impl( + const at::Tensor& input, + const std::vector& block_size, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + return dequantize_affine_reference_impl( + input, + block_size, + scale, + c10::optional(zero_point), + quant_min, + quant_max, + dtype, + std::string("INT")); +} + +void test_vulkan_quantize_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + // Create input tensor with random values 
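The reference implementations above rely on the shape_for_reduction / reduction_dims bookkeeping to turn each block into its own reduction group. A small torch sketch of that bookkeeping (mirroring the C++ loops above; the concrete sizes are only an example):

```python
import torch

x = torch.arange(48, dtype=torch.float32).reshape(8, 6)  # input [8, 6]
block_size = (2, 3)                                       # -> 4 * 2 = 8 blocks

shape_for_reduction, reduction_dims, cur = [], [], 0
for dim, blk in zip(x.shape, block_size):
    if blk != dim and blk > 1:
        shape_for_reduction += [dim // blk, blk]  # split dim into (outer, block)
        reduction_dims.append(cur + 1)            # reduce over the block part
        cur += 2
    else:
        shape_for_reduction.append(dim)
        if blk != 1:
            reduction_dims.append(cur)            # the whole dim is one block
        cur += 1

x_r = x.view(shape_for_reduction)                              # [4, 2, 2, 3]
block_min = x_r.amin(dim=tuple(reduction_dims), keepdim=True)  # [4, 1, 2, 1]
print(shape_for_reduction, reduction_dims, block_min.shape)
```

Scale and zero_point are then viewed with the reduced dimensions kept as size 1, so they broadcast back over every element of their block, which is why the tests in this file pass exactly one scale/zero_point value per block.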
+ std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); + + // Get reference output + at::Tensor reference_out = quantize_affine_reference_impl( + input, + block_size, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_output_dtype = + graph.add_scalar(static_cast(dtype)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("torchao.quantize_affine.default") + (graph, + { + r_input.value, + r_block_size, + r_scale.value, + r_zero_point.value, + r_output_dtype, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Copy scale tensor to GPU + graph.copy_into_staging( + r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); + + // Copy zero_point tensor to GPU + graph.copy_into_staging( + r_zero_point.staging, + zero_point_tensor.const_data_ptr(), + zero_point_tensor.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); + if (!output_correct) { + std::cout << "\nFailed with parameters:" << std::endl; + std::cout << " input_sizes: ["; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " block_size: ["; + for (size_t i = 0; i < block_size.size(); i++) { + std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " scales: ["; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << scales[i] << (i < scales.size() - 1 ? 
", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " zero_points: ["; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << zero_points[i] << (i < zero_points.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl << input << std::endl; + std::cout << "reference:" << std::endl << reference_int << std::endl; + std::cout << "vulkan:" << std::endl << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_affine( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +TEST(VulkanQuantizeAffineTest, test_1d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 1D: 1x1x1x12 Tensor, block_size is 3 + test_vulkan_quantize_affine( + {12}, // input_sizes + {3}, // block_size + {0.1f, 0.2f, 0.15f, 0.25f}, // scales (4 blocks) + {10, -20, 5, 30}, // zero_points (4 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizeAffineTest, test_2d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 2D: 1x1x8x6 Tensor, block_size is 1x1x2x3 (8/2=4, 6/3=2, so 4*2=8 blocks) + test_vulkan_quantize_affine( + {8, 6}, // input_sizes + {2, 3}, // block_size (1/1=1, 1/1=1, 8/2=4, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f}, // scales (8 blocks) + {-10, 15, 0, 25, -5, 20, 10, -15}, // zero_points (8 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizeAffineTest, test_3d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 3D: 1x6x4x6 Tensor, block_size is 3x2x2 (6/3=2, 4/2=2, 6/2=3, so 2*2*3=12 + // blocks) + test_vulkan_quantize_affine( + {6, 4, 6}, // input_sizes (changed 7->6 so divisible by 3) + {3, + 2, + 2}, // block_size (6 divisible by 3, 4 divisible by 2, 6 divisible by 2) + {0.1f, + 0.2f, + 0.15f, + 0.25f, + 0.3f, + 0.05f, + 0.4f, + 0.35f, + 0.12f, + 0.18f, + 0.22f, + 0.28f}, // scales (12 blocks) + {-15, 10, 5, -25, 20, -10, 15, -5, 8, -12, 18, -8}, // zero_points (12 + // blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizeAffineTest, test_4d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + 
->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 4D: 8x6x6x6 Tensor, block_size is 2x3x2x3 (8/2=4, 6/3=2, 6/2=3, 6/3=2, so + // 4*2*3*2=48 blocks) + test_vulkan_quantize_affine( + {8, 6, 6, 6}, // input_sizes + {2, 3, 2, 3}, // block_size (8/2=4, 6/3=2, 6/2=3, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f, 0.12f, 0.18f, + 0.22f, 0.28f, 0.32f, 0.08f, 0.45f, 0.38f, 0.14f, 0.24f, 0.16f, 0.26f, + 0.34f, 0.06f, 0.44f, 0.36f, 0.11f, 0.21f, 0.13f, 0.23f, 0.31f, 0.07f, + 0.41f, 0.37f, 0.19f, 0.29f, 0.17f, 0.27f, 0.33f, 0.09f, 0.43f, 0.39f, + 0.10f, 0.20f, 0.14f, 0.24f, 0.30f, 0.04f, 0.40f, 0.34f}, // scales (48 + // blocks) + {-20, 10, 5, -15, 25, -10, 15, -5, 8, -12, 18, -8, 22, + -18, 12, -22, -25, 15, 0, -20, 30, -5, 20, -10, 5, -25, + 10, -15, 35, -15, 25, -35, -30, 20, -5, -25, 40, 0, 30, + -40, 10, -30, 15, -10, 45, -20, 35, -45}, // zero_points (48 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +void test_vulkan_dequantize_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kChar, + at::ScalarType out_dtype = at::kFloat, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + // Create input tensor with random integer values within quant_min and + // quant_max + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = at::randint( + quant_min, + quant_max + 1, + input_sizes_int64, + at::device(at::kCPU).dtype(in_dtype)); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); + + // Get reference output + at::Tensor reference_out = dequantize_affine_reference_impl( + input, + block_size, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + out_dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Create block_size as IntList instead of Tensor + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + // Create input_dtype scalar + const ValueRef r_input_dtype = + graph.add_scalar(static_cast(in_dtype)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + const ValueRef r_output_dtype = + graph.add_scalar(static_cast(out_dtype)); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + // Match the argument order in dequantize_affine_impl in Dequantize.cpp: + // input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, + // output_dtype, output + 
VK_GET_OP_FN("torchao.dequantize_affine.default") + (graph, + { + r_input.value, + r_block_size, + r_scale.value, + r_zero_point.value, + r_input_dtype, + r_quant_min, + r_quant_max, + r_output_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Copy scale tensor to GPU + graph.copy_into_staging( + r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); + + // Copy zero_point tensor to GPU + graph.copy_into_staging( + r_zero_point.staging, + zero_point_tensor.const_data_ptr(), + zero_point_tensor.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); + if (!output_correct) { + std::cout << "\nFailed with parameters:" << std::endl; + std::cout << " input_sizes: ["; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " block_size: ["; + for (size_t i = 0; i < block_size.size(); i++) { + std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " scales: ["; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << scales[i] << (i < scales.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " zero_points: ["; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << zero_points[i] << (i < zero_points.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl << input << std::endl; + std::cout << "reference:" << std::endl << reference_out << std::endl; + std::cout << "vulkan:" << std::endl << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_affine( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kChar, + at::ScalarType out_dtype = at::kFloat) { + // Test with buffer storage + test_vulkan_dequantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +TEST(VulkanDequantizeAffineTest, test_1d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 1D: 1x1x1x12 Tensor, block_size is 3 + test_vulkan_dequantize_affine( + {12}, // input_sizes + {3}, // block_size + {0.1f, 0.2f, 0.15f, 0.25f}, // scales (4 blocks) + {10, -20, 5, 30}, // zero_points (4 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_2d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 2D: 1x1x8x6 Tensor, block_size is 1x1x2x3 (8/2=4, 6/3=2, so 4*2=8 blocks) + test_vulkan_dequantize_affine( + {8, 6}, // input_sizes + {2, 3}, // block_size (1/1=1, 1/1=1, 8/2=4, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f}, // scales (8 blocks) + {-10, 15, 0, 25, -5, 20, 10, -15}, // zero_points (8 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_3d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 3D: 1x6x4x6 Tensor, block_size is 3x2x2 (6/3=2, 4/2=2, 6/2=3, so 2*2*3=12 + // blocks) + test_vulkan_dequantize_affine( + {6, 4, 6}, // input_sizes (changed 7->6 so divisible by 3) + {3, + 2, + 2}, // block_size (6 divisible by 3, 4 divisible by 2, 6 divisible by 2) + {0.1f, + 0.2f, + 0.15f, + 0.25f, + 0.3f, + 0.05f, + 0.4f, + 0.35f, + 0.12f, + 0.18f, + 0.22f, + 0.28f}, // scales (12 blocks) + {-15, 10, 5, -25, 20, -10, 15, -5, 8, -12, 18, -8}, // zero_points (12 + // blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_4d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 4D: 8x6x6x6 Tensor, block_size is 2x3x2x3 (8/2=4, 6/3=2, 6/2=3, 6/3=2, so + // 4*2*3*2=48 blocks) + test_vulkan_dequantize_affine( + {8, 6, 6, 6}, // input_sizes + {2, 3, 2, 3}, // block_size (8/2=4, 6/3=2, 6/2=3, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f, 0.12f, 0.18f, + 0.22f, 0.28f, 0.32f, 0.08f, 0.45f, 0.38f, 0.14f, 
0.24f, 0.16f, 0.26f, + 0.34f, 0.06f, 0.44f, 0.36f, 0.11f, 0.21f, 0.13f, 0.23f, 0.31f, 0.07f, + 0.41f, 0.37f, 0.19f, 0.29f, 0.17f, 0.27f, 0.33f, 0.09f, 0.43f, 0.39f, + 0.10f, 0.20f, 0.14f, 0.24f, 0.30f, 0.04f, 0.40f, 0.34f}, // scales (48 + // blocks) + {-20, 10, 5, -15, 25, -10, 15, -5, 8, -12, 18, -8, 22, + -18, 12, -22, -25, 15, 0, -20, 30, -5, 20, -10, 5, -25, + 10, -15, 35, -15, 25, -35, -30, 20, -5, -25, 40, 0, 30, + -40, 10, -30, 15, -10, 45, -20, 35, -45}, // zero_points (48 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 9eac90ac33d..b9386f92772 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -216,3 +216,9 @@ def define_common_targets(is_fbcode = False): ":test_utils", ] ) + define_test_targets( + "quantize_affine_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 0de1fd5d69a..4f54bc638ba 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -16,6 +16,7 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) +from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightQuantizer from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import Quantizer @@ -153,3 +154,56 @@ def test_fuse_linear_qcs4w(self): self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1) self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) + + def test_fuse_linear_qta8a_qga4w(self): + """Test fusion of dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).""" + K = 256 + N = 256 + model = SingleLinearModule(K, N) + sample_inputs = model.get_sample_inputs() + + # Use source transform quantizer for dynamic activation + grouped weight quantization + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=128, # Group size for 4-bit weights + padding_allowed=False, + precision=torch.float32, + scales_precision=torch.float32, + device=torch.device("cpu"), + ) + + # Apply source transform quantization + quantized_model = quantizer.quantize(model) + + # Export the quantized model + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, + _check_ir_validity=False, + ) + + program = torch.export.export_for_training( + quantized_model, sample_inputs, strict=True + ).module() + + program = torch.export.export(program, sample_inputs) + + edge_manager = to_edge( + program, + compile_config=edge_compile_config, + ) + + ep = edge_manager._edge_programs["forward"] + edge_manager.transform( + [ + AddmmToLinearTransform(), + FuseQuantizedOpsTransform(ep), + ] + ) + + gm = ep.graph_module + + # Check that the linear_qta8a_qga4w operator was created + self.assertEqual(op_node_count(gm, "linear_qta8a_qga4w.default"), 1) + # Check that the original quantization/dequantization nodes were removed + self.assertEqual(op_node_count(gm, "quantize_per_token.default"), 0) + self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) + self.assertEqual(op_node_count(gm, "linear.default"), 0) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index d71c0a35776..9086b2d0792 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ 
-38,6 +38,14 @@ "dequantize_affine.default", } +_Q_OPS = { + "quantize_per_tensor.tensor", + "quantize_per_tensor.default", + "quantize_per_channel.default", + "quantize_per_token.default", + "quantize_affine.default", +} + ## ## Node type determination ## @@ -50,6 +58,13 @@ def is_dequant_node(node: torch.fx.Node) -> bool: return node_name in _DQ_OPS +def is_quant_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name in _Q_OPS + + def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False From 328836f41ce727cce0c0e915e01b9931b02d296d Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 30 Jul 2025 13:17:47 -0400 Subject: [PATCH 4/4] [ET-VK][Ops] torchao.choose_qparams_affine vulkan impl and shader (buffer only) and cleanup (#13003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Changes * Implement `torchao.choose_qparams_affine` operator in Vulkan backend with comprehensive buffer storage support * Add block-wise quantization parameter computation in `choose_qparams_buffer.glsl` shader for configurable tensor block analysis * Extend quantization parameter infrastructure in `ChooseQParams.cpp` to handle affine transformations with configurable block sizes and multiple mapping types * Support three quantization mapping strategies: ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR for optimal parameter selection * Consolidated the logic for choosing scale and zero point between affine cases and regular quantized_decomposed cases. BE: Improved the documentation in the shader logic which is more detailed and clear # Motivation The existing Vulkan quantization infrastructure lacked support for the `torchao.choose_qparams_affine` operator, which is essential for computing optimal quantization parameters in dynamic quantization workflows. The `choose_qparams_affine` operator provides flexible block-wise parameter computation that analyzes statistical distributions within tensor blocks, enabling: * **Block-wise Parameter Computation**: Analyzes configurable tensor blocks to compute optimal scale and zero-point values, improving quantization accuracy for heterogeneous data distributions * **Multiple Mapping Types**: Supports ASYMMETRIC, SYMMETRIC, and SYMMETRIC_NO_CLIPPING_ERR quantization strategies for different precision-performance trade-offs # Operator Description The `choose_qparams_affine` operator computes optimal quantization parameters (scale and zero_point) from floating-point tensor blocks using statistical analysis of data distributions. Block-wise parameter computation divides tensors into blocks and analyzes each block independently to determine the best quantization mapping for subsequent quantization operations. The parameter calculation varies by mapping type: - **ASYMMETRIC**: `scale = (max - min) / (quant_max - quant_min)`, `zero_point = quant_min - round(min / scale)` - **SYMMETRIC**: `scale = max_abs / ((quant_max - quant_min) / 2)`, `zero_point = midpoint` - **SYMMETRIC_NO_CLIPPING_ERR**: `scale = max(abs(min)/abs(quant_min), max/quant_max)`, `zero_point = midpoint` **Storage Requirements**: Input tensors must be floating-point (kFloat) with width-packed layout. Output scale/zero_point tensors use buffer storage. NOTE: Texture storage implementation is not supported due to complexity of block-wise coordinate mapping in 3D texture space. 
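For reference, here is a minimal NumPy sketch of the three parameter formulas listed above under "Operator Description". This is only an illustration of the math, not the shader or the torchao reference implementation; the function name `choose_qparams_blockwise_sketch` is hypothetical, and it assumes evenly divisible block sizes and a signed quantized range (e.g. int8/int4).

```python
import numpy as np

def choose_qparams_blockwise_sketch(x, block_size, quant_min, quant_max,
                                    mapping_type="ASYMMETRIC", eps=1e-5):
    """Compute per-block (scale, zero_point) following the formulas above."""
    assert x.ndim == len(block_size)
    new_shape, block_axes = [], []
    for i, (dim, blk) in enumerate(zip(x.shape, block_size)):
        assert dim % blk == 0, "sketch assumes evenly divisible block sizes"
        new_shape += [dim // blk, blk]          # split each dim into (num_blocks, block)
        block_axes.append(2 * i + 1)            # reduce over the per-block axes
    xb = x.reshape(new_shape)

    # Per-block min/max, always including zero in the range.
    lo = np.minimum(xb.min(axis=tuple(block_axes)), 0.0)
    hi = np.maximum(xb.max(axis=tuple(block_axes)), 0.0)

    if mapping_type == "ASYMMETRIC":
        # scale = (max - min) / (quant_max - quant_min); zp = quant_min - round(min / scale)
        scale = np.maximum((hi - lo) / (quant_max - quant_min), eps)
        zero_point = np.clip(quant_min - np.round(lo / scale),
                             quant_min, quant_max).astype(np.int64)
    elif mapping_type == "SYMMETRIC":
        # scale = max_abs / ((quant_max - quant_min) / 2); zp = midpoint
        scale = np.maximum(np.maximum(np.abs(lo), hi) / ((quant_max - quant_min) / 2), eps)
        zero_point = np.full(scale.shape, (quant_max + quant_min + 1) // 2, dtype=np.int64)
    elif mapping_type == "SYMMETRIC_NO_CLIPPING_ERR":
        # scale = max(abs(min)/abs(quant_min), max/quant_max); zp = midpoint
        scale = np.maximum(np.maximum(np.abs(lo) / abs(quant_min), hi / quant_max), eps)
        zero_point = np.full(scale.shape, (quant_max + quant_min + 1) // 2, dtype=np.int64)
    else:
        raise ValueError(f"unsupported mapping_type: {mapping_type}")
    return scale, zero_point

# Example: one scale/zero_point per 1x1x3x3 block of a 1x2x6x6 tensor, int8 range.
x = np.random.randn(1, 2, 6, 6).astype(np.float32)
scale, zp = choose_qparams_blockwise_sketch(x, (1, 1, 3, 3), -128, 127, "SYMMETRIC")
print(scale.shape, zp.shape)  # (1, 2, 2, 2) -> 8 blocks
```

The reshape-and-reduce trick mirrors how the operator conceptually groups elements into blocks; the actual Vulkan shader instead walks block coordinates directly, as described in the sections below.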
This will likely be necessary for better efficiency in the future. # Block-wise Parameter Computation Implementation Block-wise parameter computation enables fine-grained quantization analysis by dividing tensors into blocks and computing separate scale/zero_point parameters for each block. The implementation uses several key data structures computed in `ChooseQParams.cpp`: * **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks) * **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()` * **`num_blocks_vec`**: Number of blocks per dimension calculated as `ceil(tensor_size_whcn / block_size_vec)` to handle non-divisible dimensions * **`block_stride_vec`**: Pre-computed linear strides for block grid indexing `{1, #W, #W*#H, #W*#H*#C}` to enable efficient block ID calculation * **`mapping_type`**: Integer encoding of quantization strategy (0=ASYMMETRIC, 1=SYMMETRIC, 2=SYMMETRIC_NO_CLIPPING_ERR) The block coordinate calculation uses: `block_coord = block_id_to_coord(block_id)` which converts linear block IDs back to 4D WHCN coordinates, then computes element ranges: `t0 = block_coord * blockSize` and `tEnd = t0 + blockSize` for nested loop iteration. # Shader Algorithm Overview ## Buffer Storage Implementation (`choose_qparams_buffer.glsl`) **Workgroup Configuration**: - **Global WG Size**: `{nBlocks, 1u, 1u}` where `nBlocks = total number of blocks` computed from `ceil(tensor_size / block_size)` for each dimension - **Local WG Size**: `{1u, 1u, 1u}` (single thread per block for simplicity, though could be optimized for larger blocks) **Block-wise Mode Algorithm**: The shader uses a sophisticated multi-level nested approach to process tensor blocks efficiently. Each thread is assigned multiple blocks using strided access: `for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE)` where `STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x`. For each assigned block, the algorithm performs several key steps: **1. Block Coordinate Conversion**: The `block_id_to_coord(block_id)` function converts linear block IDs to 4D WHCN coordinates using modular arithmetic. **2. Element Range Calculation**: Computes the inclusive start coordinate `t0 = bc * blockSize` and exclusive end coordinate `tEnd = t0 + blockSize` to define the block's element boundaries in tensor space. **3. Nested Loop Min/Max Scan**: Uses four nested loops to iterate through all elements within the block: `for (int n = t0.w; n < tEnd.w; ++n) for (int c = t0.z; c < tEnd.z; ++c) for (int h = t0.y; h < tEnd.y; ++h) for (int w = t0.x; w < tEnd.x; ++w)` Each element is accessed using `tidx_to_bufi(ivec4(w,h,c,n), t_in_strides)` to convert 4D tensor coordinates to linear buffer indices with proper stride handling. **4. 
Parameter Calculation**: Calls `calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp)` which implements the three mapping strategies: * **ASYMMETRIC (mapping_type=0)**: Maps full range [min, max] to [quant_min, quant_max] preserving data distribution * **SYMMETRIC (mapping_type=1)**: Centers around zero using `max_abs = max(abs(min), abs(max))` for balanced quantization * **SYMMETRIC_NO_CLIPPING_ERR (mapping_type=2)**: Computes separate scales for positive/negative ranges and uses the maximum to prevent clipping **Future Improvements**: Implement workgroup-level reduction for large blocks, optimize memory access patterns for better cache utilization, and explore texture storage implementation with simplified block alignment constraints. Differential Revision: [D78436638](https://our.internmc.facebook.com/intern/diff/D78436638/) cc SS-JIA manuelcandales cbilgin [ghstack-poisoned] --- backends/vulkan/op_registry.py | 59 +- .../graph/ops/glsl/choose_qparams.glslh | 100 ++-- .../graph/ops/glsl/choose_qparams_buffer.glsl | 359 ++++++++---- .../graph/ops/glsl/choose_qparams_buffer.yaml | 2 + .../ops/glsl/choose_qparams_texture.glsl | 280 +++++++--- .../ops/glsl/choose_qparams_texture.yaml | 2 + .../runtime/graph/ops/impl/ChooseQParams.cpp | 289 ++++++---- .../test/op_tests/quantize_affine_test.cpp | 520 ++++++++++++++++++ 8 files changed, 1239 insertions(+), 372 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 178cc9ea08b..33ed3150535 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -245,9 +245,9 @@ def register_ephemeral_op(features: OpFeatures): @update_features( [ - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, @@ -276,14 +276,32 @@ def register_quantization_op(features: OpFeatures): [ exir_ops.edge.torchao.quantize_affine.default, exir_ops.edge.torchao.dequantize_affine.default, + ] +) +def register_affine_quantization_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_axis_map=False, + valid_packed_dims={PackedDim.WIDTH}, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED + features.handles_own_prepacking = True + + return features + + +@update_features( + [ exir_ops.edge.torchao.choose_qparams_affine.default, ] ) -def register_torchao_quantization_op(features: OpFeatures): - # TorchAO quantization operators - default to per-tensor behavior - # Same features as standard quantization ops +def register_choose_qparams_affine_op(features: OpFeatures): + # Currently only created a rudimentary buffer implementation for choose_qparams_affine + # since the reduction logic for blocks in texture3d is not trivial to implement in vulkan. 
features.texture_impl = TextureImplFeatures( - uses_axis_map=True, + uses_axis_map=False, valid_packed_dims={ PackedDim.WIDTH, }, @@ -292,37 +310,6 @@ def register_torchao_quantization_op(features: OpFeatures): features.resize_fn = True features.optimal_storage = VkStorageType.BUFFER - def check_torchao_quantization_node(node: torch.fx.Node) -> bool: - # Only per-tensor quantization is supported by the Vulkan backend. - if len(node.args) < 2: - return False - - block_size = node.args[1] - - if not isinstance(block_size, (list, tuple)): - return False - - input_arg = node.args[0] - if not isinstance(input_arg, torch.fx.Node): - return False - - input_tensor = input_arg.meta.get("val", None) - if not isinstance(input_tensor, FakeTensor): - return False - - input_shape = list(input_tensor.shape) - - if len(block_size) != len(input_shape): - return False - - # Check if block_size matches input_shape exactly (per-tensor quantization) - for i in range(len(block_size)): - if block_size[i] != input_shape[i]: - return False - - return True - - features.check_node_fn = check_torchao_quantization_node return features diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index d6d27d2e3a3..cfe5baa9c1d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,59 +9,67 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// Calculate scale and zero point from min and max values -void calculate_scale_and_zero_point( - float min_val, - float max_val, - int qmin, - int qmax, - float eps_threshold, - out float scale_val, - out int zero_point_val) { - // ensure we have zero included in our range - min_val = min(min_val, 0.0); - max_val = max(max_val, 0.0); +// mapping_type : 0 = ASYM, 1 = SYM, 2 = SYM_NO_CLIP +void calc_scale_zp( + float lo, float hi, + int qmin, int qmax, + int mapping_type, + float eps, + out float scale, out int zp) { + // Handle case where lo and hi are +/-INF (no valid values found) + if (isinf(lo) || isinf(hi)) { + lo = 0.0; + hi = 0.0; + } - scale_val = (max_val - min_val) / float(qmax - qmin); + float minv = min(lo, 0.0); + float maxv = max(hi, 0.0); - // Handle zero or very small scale - if (scale_val == 0.0 || isinf(1.0 / scale_val)) { - scale_val = 0.1; - } + if (mapping_type == 0) { // asymmetric + scale = (maxv - minv) / float(qmax - qmin); + + // Handle zero or very small scale + if (scale == 0.0 || isinf(1.0/scale)) { + scale = eps; + } - // Cut off small scale using the provided eps threshold - if (scale_val < eps_threshold) { - float org_scale = scale_val; - scale_val = eps_threshold; + if (scale < eps) { + float org_scale = scale; + scale = eps; - // Adjust min and max based on new scale - if (min_val == 0.0) { - max_val = eps_threshold * float(qmax - qmin); - } else if (max_val == 0.0) { - min_val = -eps_threshold * float(qmax - qmin); - } else { - float amplifier = eps_threshold / org_scale; - min_val *= amplifier; - max_val *= amplifier; + // Adjust min and max based on new scale to maintain proper quantization range + if (minv == 0.0) { + maxv = eps * float(qmax - qmin); + } else if (maxv == 0.0) { + minv = -eps * float(qmax - qmin); + } else { + float amplifier = eps / org_scale; + minv *= amplifier; + maxv *= amplifier; + } + } + + // Calculate zero_point (matching reference implementation) + float initial_zero_point = float(qmin) - round(minv / scale); + zp = int(clamp(initial_zero_point, 
float(qmin), float(qmax))); + } else { // symmetric -- centred + float scale_sym; + if (mapping_type == 1) { // SYM + float M = max(abs(minv), abs(maxv)); + scale_sym = M / (float(qmax - qmin) * 0.5); + } else { // SYM_NO_CLIP + float smin = abs(minv) / max(abs(float(qmin)), 1.0); // Avoid division by zero + float smax = maxv / max(float(qmax), 1.0); // Avoid division by zero + scale_sym = max(smin, smax); } - } - // Calculate zero point - float zero_point_from_min = float(qmin) - min_val / scale_val; - float zero_point_from_max = float(qmax) - max_val / scale_val; - float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val); - float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val); - float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; + // Handle zero or very small scale + if (scale_sym == 0.0 || isinf(1.0/scale_sym)) { + scale_sym = eps; + } - // Nudge zero point to integer - if (initial_zero_point < float(qmin)) { - zero_point_val = qmin; - } else if (initial_zero_point > float(qmax)) { - zero_point_val = qmax; - } else { - zero_point_val = int(round(initial_zero_point)); + scale = max(scale_sym, eps); + zp = int((qmax + qmin + 1) >> 1); // mid-point – always fits } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index 48681a46c30..99a64c3589e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -31,12 +31,22 @@ $if MODE == "per_tensor": int quant_max; float eps; }; -$else: +$if MODE == "per_token": layout(push_constant) uniform restrict Block { int num_tokens; int quant_min; int quant_max; }; +$if MODE == "block_wise": + layout(push_constant) uniform BlockPC { + ivec4 blockSize; // WHCN (>=1) + ivec4 numBlocks; // #blocks along W,H,C,N + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP + int quant_min; + int quant_max; + float eps; + }; ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} ${layout_declare_ubo(B, "ivec4", "t_in_strides")} @@ -57,68 +67,133 @@ shared float shared_min[NWORKERS]; shared float shared_max[NWORKERS]; /* - * QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE) - * - * This shader computes quantization parameters (scale and zero_point) for converting - * floating-point tensors to n-bit integer representations while preserving the - * original data range as much as possible. - * - * ALGORITHM: - * 1. Find global min/max values across tensor elements using parallel reduction - * 2. Use tree reduction with shared memory for efficient min/max computation - * 3. Calculate scale = (max - min) / (quant_max - quant_min) - * 4. 
Calculate zero_point to map floating-point zero to integer value - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor) - * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) - * - Per-Token Mode: - * - Global WG Size: {num_tokens, 1, 1} (one workgroup per token) - * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) - * - * SUPPORTED CONFIGURATIONS: - * - Buffer Storage: Uses simple linear indexing through buffer elements - * - No axis mapping or packing considerations - processes elements sequentially - * - Works with any tensor layout since it accesses buffer data linearly - * - * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: - * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: - * - * Initial shared_min/shared_max arrays populated by each thread: - * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - * - * Stride 1 (compare pairs, keep min/max): - * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - * Active: | 0 | | 2 | | 4 | | 6 | | - * - * Stride 2 (compare pairs, keep min/max): - * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) - * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - * Active: | 0 | | | | 4 | | | | - * - * Stride 4 (final comparison): - * shared_min: | 0 | | | | | | | | (min(0,0) = 0) - * shared_max: | 10 | | | | | | | | (max(10,5) = 10) - * Active: | 0 | | | | | | | | - * - * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - * - * PER-TENSOR QUANTIZATION: - * - Single workgroup processes entire tensor with strided access - * - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...] - * - Tree reduction combines all thread results into global min/max - * - Output: Single scale and zero_point values - * - * PER-TOKEN QUANTIZATION: - * - Multiple workgroups, each processing one token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Each workgroup finds min/max within its assigned token - * - Output: Array of scale and zero_point values (one per token) - */ + Quantization Parameter Computation Shader (Buffer Storage) + This shader computes quantization parameters (scale and zero_point) for converting + floating-point tensors to n-bit integer representations while preserving the + original data range as much as possible. The computed parameters enable efficient + quantization by mapping the continuous floating-point range to discrete integer values. + + Important Considerations: + (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + + Workgroup Configuration: + - choose_qparams_per_tensor + This mode computes a single set of quantization parameters for the entire tensor. + Uses parallel reduction across all threads to find global min/max values. + + (*) global_wg_size: {1, 1, 1} (single workgroup processes entire tensor) + (*) local_wg_size: {64, 1, 1} (matches NWORKERS for shared memory) + + - choose_qparams_per_token + This mode computes separate quantization parameters for each token in the tensor. + Each workgroup processes one token independently to find token-specific min/max. 
+ + (*) global_wg_size: {num_tokens, 1, 1} (one workgroup per token) + (*) local_wg_size: {1, 1, 1} (single thread per token) + + - choose_qparams_block_wise + This mode computes quantization parameters for each block of elements, allowing + fine-grained control over quantization granularity within the tensor. Each block + is processed independently to find its own min/max values and compute corresponding + scale and zero_point parameters. + + (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) + (*) local_wg_size: {1, 1, 1} (single thread per block) + + Block-wise quantization supports multiple mapping types for scale/zero_point calculation: + + - mapping_type = 0 (ASYMMETRIC): + Uses asymmetric quantization where the full floating-point range [min, max] is + mapped to the quantized range [quant_min, quant_max]. This preserves the original + data distribution but may not center zero optimally. + + Calculation: + scale = (max - min) / (quant_max - quant_min) + zero_point = quant_min - round(min / scale) + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + scale = (10.2 - (-3.5)) / (7 - (-8)) = 13.7 / 15 = 0.913 + zero_point = -8 - round(-3.5 / 0.913) = -8 - (-4) = -4 + + - mapping_type = 1 (SYMMETRIC): + Uses symmetric quantization where the range is centered around zero. The scale + is computed based on the maximum absolute value, ensuring zero is exactly + representable in the quantized domain. + + Calculation: + max_abs = max(abs(min), abs(max)) + scale = max_abs / ((quant_max - quant_min) / 2) + zero_point = (quant_max + quant_min + 1) / 2 // midpoint + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + max_abs = max(3.5, 10.2) = 10.2 + scale = 10.2 / ((7 - (-8)) / 2) = 10.2 / 7.5 = 1.36 + zero_point = (-8 + 7 + 1) / 2 = 0 + + - mapping_type = 2 (SYMMETRIC_NO_CLIPPING_ERR): + A variant of symmetric quantization that minimizes clipping errors by computing + separate scales for positive and negative ranges, then using the maximum. This + reduces quantization error on the dominant range while ensuring no values are + clipped. + + Calculation: + smin = abs(min) / abs(quant_min) // scale for negative range + smax = max / quant_max // scale for positive range + scale = max(smin, smax) // use larger scale to avoid clipping + zero_point = (quant_max + quant_min + 1) / 2 // midpoint + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + smin = 3.5 / 8 = 0.4375 + smax = 10.2 / 7 = 1.457 + scale = max(0.4375, 1.457) = 1.457 // use smax to avoid clipping positives + zero_point = (-8 + 7 + 1) / 2 = 0 + + Tree Reduction Algorithm for Min/Max Finding: + The shader uses a parallel tree reduction algorithm to efficiently find minimum and + maximum values across multiple threads. This approach reduces the number of memory + accesses and synchronization points compared to sequential scanning. + + Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: + + Step 1 - Initial Population: + Each thread loads its assigned value into shared memory arrays. + shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + + Step 2 - Stride 1 (Compare Adjacent Pairs): + Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. 
+ shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + Active: | 0 | | 2 | | 4 | | 6 | | + + Step 3 - Stride 2 (Compare Pairs of Pairs): + Threads 0,4 compare with threads 2,6 respectively. + shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) + shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + Active: | 0 | | | | 4 | | | | + + Step 4 - Stride 4 (Final Comparison): + Thread 0 compares with thread 4 to get final result. + shared_min: | 0 | | | | | | | | (min(1,0) = 0) + shared_max: | 10 | | | | | | | | (max(10,5) = 10) + Active: | 0 | | | | | | | | + + Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + + The tree reduction completes in log_2(N) steps where N is the number of threads, + providing O(log N) time complexity instead of O(N) for sequential reduction. + + Quantization Parameter Calculation: + Once min/max values are determined, the shader computes: + - scale = (max - min) / (quant_max - quant_min) + - zero_point = quantization offset to map floating-point zero to integer range + + Mode-Specific Behavior: + - Per-Tensor: Single workgroup with strided access across entire tensor + - Per-Token: Multiple workgroups, each processing one token independently + - Block-Wise: Each thread processes assigned blocks using nested loops over block dimensions +*/ #ifdef per_tensor @@ -176,99 +251,141 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); + // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant + calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; } } -#else +#elif defined(per_token) void choose_qparams_per_token() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_workgroups = gl_NumWorkGroups.x; - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); uint token_size = total_elements / uint(num_tokens); - // Calculate how many tokens each workgroup should process - // This handles the case where we have more tokens than workgroups - uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; - - // Calculate which tokens this workgroup is responsible for - uint start_token = group_id * tokens_per_workgroup; - uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + const uint TOTAL_TOKENS = uint(num_tokens); - // Early exit if this workgroup has no tokens to process - if (start_token >= uint(num_tokens)) { - return; - } - - // Process each token assigned to this workgroup - for (uint token_id = start_token; token_id < end_token; token_id++) { + /* each invocation handles token-ids: id, id+STRIDE, id+2·STRIDE … */ + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + for (uint token_id = gl_GlobalInvocationID.x; token_id < TOTAL_TOKENS; token_id += STRIDE) { // Calculate the start and end indices for this token uint token_start = token_id * token_size; uint token_end = token_start + token_size; - // Each thread processes multiple elements within the token with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity + // Each thread processes the entire 
token + float lo = 1.0/0.0; // +INF + float hi = -1.0/0.0; // -INF bool found_valid = false; - // Process elements within this token only - for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + // Process all elements in this token + for (uint i = token_start; i < token_end; i++) { float val = t_in[i]; if (!isnan(val) && !isinf(val)) { if (!found_valid) { - thread_min = val; - thread_max = val; + lo = hi = val; found_valid = true; } else { - thread_min = min(thread_min, val); - thread_max = max(thread_max, val); + lo = min(lo, val); + hi = max(hi, val); } } } - // Intra-group reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); + if (!found_valid) { + // If no valid values were found, use default values + lo = 0.0; + hi = 0.0; + } - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; + // Calculate scale and zero point directly + float scale_val; + int zero_point_val; + // Use default values: mapping_type=0 (ASYMMETRIC), eps=1e-5 + calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; + // Write results + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } +} + +#elif defined(block_wise) + +ivec4 block_id_to_coord(uint bid) { + ivec4 bc; + bc.w = int(bid) / blockStride.w; + + int r = int(bid) - bc.w * blockStride.w; + bc.z = r / blockStride.z; + + r -= bc.z * blockStride.z; + bc.y = r / blockStride.y; + + r -= bc.y * blockStride.y; + bc.x = r; + return bc; +} + +void choose_qparams_block_wise() { + const uint TOTAL_BLOCKS = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); + + // each invocation handles block-ids: id, id+STRIDE, id+2·STRIDE + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE) { + // block -> WHCN coordinate + ivec4 bc = block_id_to_coord(block_id); + ivec4 blockStart = bc * blockSize; // first element (inclusive) + ivec4 blockEnd = blockStart + blockSize; // last element (exclusive) + + // min / max scan over the block + float lo = 1.0/0.0; // +INF + float hi = -1.0/0.0; // -INF + bool found_valid = false; + + // Calculate actual block dimensions + ivec4 actualBlockSize = blockEnd - blockStart; + int blockElements = actualBlockSize.x * actualBlockSize.y * actualBlockSize.z * actualBlockSize.w; + + // Linear iteration over block elements + for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { + // Convert linear index to 4D coordinates within block + int remaining = elemIdx; + int dn = remaining / (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); + remaining -= dn * (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); + int dc = remaining / (actualBlockSize.x * actualBlockSize.y); + remaining -= dc * (actualBlockSize.x * actualBlockSize.y); + int dh = remaining / actualBlockSize.x; + int dw = remaining - dh * actualBlockSize.x; + + ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); + uint idx = tidx_to_bufi(tidx, t_in_strides); + float v = t_in[idx]; + 
+ if (!isnan(v) && !isinf(v)) { + if (!found_valid) { + lo = hi = v; + found_valid = true; + } else { + lo = min(lo, v); + hi = max(hi, v); } } - barrier(); } - // Final calculation for this token - if (local_id == 0) { - float token_min = shared_min[0]; - float token_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); - - t_scale[token_id] = scale_val; - t_zero_point[token_id] = zero_point_val; + // Handle the case where no valid values were found in the block + if (!found_valid) { + lo = 0.0; + hi = 0.0; } - // Synchronize before processing next token - barrier(); + float scale; + int zp; + calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp); + + t_zero_point[block_id] = zp; + t_scale[block_id] = scale; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml index c37039f68e9..ee900750e16 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -10,3 +10,5 @@ choose_qparams_buffer: MODE: per_tensor - NAME: choose_qparams_per_token_asymmetric_buffer MODE: per_token + - NAME: choose_qparams_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 5076b2d68e9..62ea7099f8c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -22,8 +22,13 @@ ${define_required_extensions(IN_DTYPE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} -${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +$if MODE != "block_wise": + ${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} + ${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +$else: + ${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} $if MODE == "per_tensor": @@ -32,16 +37,33 @@ $if MODE == "per_tensor": int quant_max; float eps; }; -$else: +$if MODE == "per_token": layout(push_constant) uniform restrict Block { int num_tokens; int quant_min; int quant_max; }; +$if MODE == "block_wise": + layout(push_constant) uniform BlockPC { + ivec4 blockSize; // WHCN (>=1) + ivec4 numBlocks; // #blocks along W,H,C,N + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP + int quant_min; + int quant_max; + float eps; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_scale_limits")} -${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} +$if MODE != "block_wise": + ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} + ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} +$else: + ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} + #include "indexing_utils.h" #include "choose_qparams.glslh" @@ -54,73 +76,87 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; shared float 
shared_min[NWORKERS]; shared float shared_max[NWORKERS]; -/* - * QUANTIZATION PARAMETER COMPUTATION SHADER (TEXTURE STORAGE) - * - * This shader computes quantization parameters (scale and zero_point) for converting - * floating-point tensors to n-bit integer representations while preserving the - * original data range as much as possible. - * - * ALGORITHM: - * 1. Find global min/max values across tensor elements using parallel reduction - * 2. Use tree reduction with shared memory for efficient min/max computation - * 3. Calculate scale = (max - min) / (quant_max - quant_min) - * 4. Calculate zero_point to map floating-point zero to integer value - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: Default (typically {num_elements, 1, 1}) - * - Local WG Size: Default (typically {64, 1, 1}) - * - Per-Token Mode: - * - Global WG Size: Default (typically based on tensor dimensions) - * - Local WG Size: Default (typically {64, 1, 1}, or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with linear texel iteration - * - Assumes width-packed layout (packed_dim = 0) in current implementation - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - Note: Axis mapping support depends on indexing utilities - * - * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: - * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: - * - * Initial shared_min/shared_max arrays populated by each thread: - * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - * - * Stride 1 (compare pairs, keep min/max): - * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - * Active: | 0 | | 2 | | 4 | | 6 | | - * - * Stride 2 (compare pairs, keep min/max): - * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) - * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - * Active: | 0 | | | | 4 | | | | - * - * Stride 4 (final comparison): - * shared_min: | 0 | | | | | | | | (min(0,0) = 0) - * shared_max: | 10 | | | | | | | | (max(10,5) = 10) - * Active: | 0 | | | | | | | | - * - * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - * - * PER-TENSOR QUANTIZATION: - * - Single workgroup processes entire tensor - * - Each thread processes multiple texels with stride - * - Thread 0: texels [0, 64, 128, ...] -> elements [0-3, 256-259, 512-515, ...] - * - Thread 1: texels [1, 65, 129, ...] -> elements [4-7, 260-263, 516-519, ...] - * - Tree reduction combines all thread results into global min/max - * - Output: Single scale and zero_point values - * - * PER-TOKEN QUANTIZATION: - * - Multiple workgroups, each processing subset of tokens - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Each workgroup processes multiple tokens if num_tokens > num_workgroups - * - Within each token, threads process texels containing token elements - * - Output: Array of scale and zero_point values (one per token) - */ +/*/* + Quantization Parameter Computation Shader (Buffer Storage) + This shader computes quantization parameters (scale and zero_point) for converting + floating-point tensors to n-bit integer representations while preserving the + original data range as much as possible. 
The computed parameters enable efficient + quantization by mapping the continuous floating-point range to discrete integer values. + + Important Considerations: + (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + + Workgroup Configuration: + - choose_qparams_per_tensor + This mode computes a single set of quantization parameters for the entire tensor. + Uses parallel reduction across all threads to find global min/max values. + + (*) global_wg_size: default + (*) local_wg_size: default + + - choose_qparams_per_token + This mode computes separate quantization parameters for each token in the tensor. + Each workgroup processes one token independently to find token-specific min/max. + + (*) global_wg_size: default + (*) local_wg_size: {1, 1, 1} + + - choose_qparams_block_wise + This mode computes quantization parameters for each block of elements, allowing + fine-grained control over quantization granularity within the tensor. Each block + is processed independently to find its own min/max values and compute corresponding + scale and zero_point parameters. + + NOTE: This mode currently only supports buffer storage for the output. + + (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) + (*) local_wg_size: {1, 1, 1} (single thread per block) + + Tree Reduction Algorithm for Min/Max Finding: + The shader uses a parallel tree reduction algorithm to efficiently find minimum and + maximum values across multiple threads. This approach reduces the number of memory + accesses and synchronization points compared to sequential scanning. + + Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: + + Step 1 - Initial Population: + Each thread loads its assigned value into shared memory arrays. + shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + + Step 2 - Stride 1 (Compare Adjacent Pairs): + Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. + shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + Active: | 0 | | 2 | | 4 | | 6 | | + + Step 3 - Stride 2 (Compare Pairs of Pairs): + Threads 0,4 compare with threads 2,6 respectively. + shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) + shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + Active: | 0 | | | | 4 | | | | + + Step 4 - Stride 4 (Final Comparison): + Thread 0 compares with thread 4 to get final result. + shared_min: | 0 | | | | | | | | (min(1,0) = 0) + shared_max: | 10 | | | | | | | | (max(10,5) = 10) + Active: | 0 | | | | | | | | + + Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + + The tree reduction completes in log_2(N) steps where N is the number of threads, + providing O(log N) time complexity instead of O(N) for sequential reduction. 
+ + Quantization Parameter Calculation: + Once min/max values are determined, the shader computes: + - scale = (max - min) / (quant_max - quant_min) + - zero_point = quantization offset to map floating-point zero to integer range + + Mode-Specific Behavior: + - Per-Tensor: Single workgroup with strided access across entire tensor + - Per-Token: Multiple workgroups, each processing one token independently +*/ #ifdef per_tensor @@ -235,14 +271,14 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); + calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); } } -#else +#elif defined(per_token) void choose_qparams_per_token() { // Each token is processed by multiple workgroups for parallel reduction @@ -373,7 +409,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); + calc_scale_zp(token_min, token_max, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions @@ -392,6 +428,100 @@ void choose_qparams_per_token() { } } +#elif defined(block_wise) + +ivec4 block_id_to_coord(uint bid) { + ivec4 bc; + bc.w = int(bid) / blockStride.w; + + int r = int(bid) - bc.w * blockStride.w; + bc.z = r / blockStride.z; + + r -= bc.z * blockStride.z; + bc.y = r / blockStride.y; + + r -= bc.y * blockStride.y; + bc.x = r; + return bc; +} + +void choose_qparams_block_wise() { + const uint T = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + + // tensor full size in WHCN order + const ivec4 tensorSz = blockSize * numBlocks; + + // Process blocks with stride for better parallelization + for (uint blkIdx = gl_GlobalInvocationID.x; blkIdx < T; blkIdx += STRIDE) { + // block index in WHCN + const ivec4 b4d = block_id_to_coord(blkIdx); + const ivec4 blockStart = b4d * blockSize; + const ivec4 blockEnd = blockStart + blockSize; + + // scan all elements inside the block + float vmin = 3.402823e38; // +FLT_MAX + float vmax = -3.402823e38; // -FLT_MAX + bool found_valid = false; + + // Calculate total elements in block for linear iteration + const int blockElements = blockSize.x * blockSize.y * blockSize.z * blockSize.w; + + // Linear iteration over block elements (more cache-friendly) + for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { + // Convert linear index to 4D coordinates within block + int remaining = elemIdx; + int dn = remaining / (blockSize.x * blockSize.y * blockSize.z); + remaining -= dn * (blockSize.x * blockSize.y * blockSize.z); + int dc = remaining / (blockSize.x * blockSize.y); + remaining -= dc * (blockSize.x * blockSize.y); + int dh = remaining / blockSize.x; + int dw = remaining - dh * blockSize.x; + + ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); + + // skip padding when tensor size is not an exact multiple of block + if (any(greaterThanEqual(tidx, tensorSz))) { continue; } + + // tensor index -> (x,y,z,component) inside input texture + ivec4 posi = to_texture_elem_pos(tidx, tensorSz, 0); // 0 = W_DIM (width packed) + + // fetch 
texel and pick the element inside it + FVEC4_T texl = load_texel(t_in, posi.xyz); + float v; + if (posi.w == 0) v = texl.x; + else if (posi.w == 1) v = texl.y; + else if (posi.w == 2) v = texl.z; + else v = texl.w; + + if (!isnan(v) && !isinf(v)) { + if (!found_valid) { + vmin = vmax = v; + found_valid = true; + } else { + vmin = min(vmin, v); + vmax = max(vmax, v); + } + } + } + + // Handle case where no valid values were found + if (!found_valid) { + vmin = 0.0; + vmax = 0.0; + } + + // compute scale / zero‑point (same maths as buffer kernel) + float scale; + int zp; + calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); + + // Write the scalar values directly to buffer using linear index + t_scale[blkIdx] = scale; + t_zero_point[blkIdx] = zp; + } +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml index f3961b87a0f..a097ce0da48 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -10,3 +10,5 @@ choose_qparams_texture: MODE: per_tensor - NAME: choose_qparams_per_token_asymmetric_texture3d MODE: per_token + - NAME: choose_qparams_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index de269920eea..76d352334e3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -14,45 +14,6 @@ namespace vkcompute { -namespace { - -void resize_choose_qparams_tensor_output( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - const ValueRef scale_out = args.at(0).refs.at(0); - const ValueRef zero_point_out = args.at(0).refs.at(1); - - // Both scale and zero_point are scalar tensors for per-tensor quantization - // Since we use single workgroup approach, no extra buffer space needed - graph->virtual_resize(scale_out, {}); - graph->virtual_resize(zero_point_out, {}); -} - -void resize_choose_qparams_per_token_output( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - const ValueRef scale_out = args.at(0).refs.at(0); - const ValueRef zero_point_out = args.at(0).refs.at(1); - const ValueRef input = args.at(1).refs.at(0); - - // Calculate output sizes for scale and zero_point tensors - const auto input_sizes = graph->sizes_of(input); - std::vector output_sizes; - output_sizes.reserve(input_sizes.size() - 1); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - output_sizes.push_back(input_sizes[i]); - } - output_sizes.push_back(1); - - graph->virtual_resize(scale_out, output_sizes); - graph->virtual_resize(zero_point_out, output_sizes); -} - -// Custom workgroup size pickers for ChooseQParams operations utils::uvec3 choose_qparams_pick_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -135,15 +96,67 @@ utils::uvec3 choose_qparams_per_token_pick_local_wg_size( const ValueRef input = args.at(1).refs.at(0); if (graph->is_buffer_storage(input)) { - // For buffer storage, use 64 threads in X dimension to match NWORKERS - return {64u, 1u, 1u}; + return {1u, 1u, 1u}; } else { // For texture storage, use the default logic return graph->create_local_wg_size(global_workgroup_size); } } -} // namespace +utils::uvec3 
choose_qparams_block_wise_pick_global_wg_size( + ComputeGraph* g, + const vkapi::ShaderInfo&, + const std::vector& a, + const std::vector& r) { + const ValueRef input = a.at(2).refs.at(0); + const auto blkRef = r.at(0); + const auto inSz = g->sizes_of(input); + const auto blkList = g->get_int_list(blkRef); + + // Use same code as in add_choose_qparams_block_wise_node + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*blkList); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(inSz); + + // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) + utils::ivec4 nBlk = { + (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], + (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], + (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], + (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; + + uint32_t nBlocks = nBlk[0] * nBlk[1] * nBlk[2] * nBlk[3]; + + // For texture storage, use more threads to better utilize GPU parallelism + // Each thread can process multiple blocks with stride + if (g->is_buffer_storage(input)) { + return {nBlocks, 1u, 1u}; + } else { + // For texture storage, use more workgroups to better utilize GPU + // Aim for ~64-256 threads per workgroup for good occupancy + uint32_t preferred_threads_per_wg = 64; + uint32_t num_workgroups = + (nBlocks + preferred_threads_per_wg - 1) / preferred_threads_per_wg; + num_workgroups = std::max(1u, std::min(num_workgroups, nBlocks)); + return {num_workgroups * preferred_threads_per_wg, 1u, 1u}; + } +} + +utils::uvec3 choose_qparams_block_wise_pick_local_wg_size( + ComputeGraph* g, + const vkapi::ShaderInfo&, + const utils::uvec3& global_wg_size, + const std::vector& a, + const std::vector&) { + const ValueRef input = a.at(2).refs.at(0); + + if (g->is_buffer_storage(input)) { + return {1u, 1u, 1u}; + } else { + // For texture storage, use 64 threads per workgroup for better occupancy + uint32_t local_size = std::min(64u, global_wg_size[0]); + return {local_size, 1u, 1u}; + } +} void add_choose_qparams_tensor_node( ComputeGraph& graph, @@ -162,6 +175,7 @@ void add_choose_qparams_tensor_node( float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; + std::vector push_constants; if (graph.is_buffer_storage(input)) { param_ubos = { @@ -178,7 +192,6 @@ void add_choose_qparams_tensor_node( graph.logical_limits_ubo(zero_point_out)}; } - std::vector push_constants; push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), @@ -203,7 +216,7 @@ void add_choose_qparams_tensor_node( // Resize Args {}, // Resizing Logic - resize_choose_qparams_tensor_output)); + nullptr)); } void add_choose_qparams_per_token_asymmetric_node( @@ -227,6 +240,7 @@ void add_choose_qparams_per_token_asymmetric_node( int quant_max_val = 127; // Fixed for asymmetric quantization vkapi::ParamsBindList param_ubos; + std::vector push_constants; if (graph.is_buffer_storage(input)) { param_ubos = { @@ -243,7 +257,6 @@ void add_choose_qparams_per_token_asymmetric_node( graph.logical_limits_ubo(zero_point_out)}; } - std::vector push_constants; push_constants = { PushConstantDataInfo(&num_tokens_val, sizeof(int)), PushConstantDataInfo(&quant_min_val, sizeof(int)), @@ -268,7 +281,100 @@ void add_choose_qparams_per_token_asymmetric_node( // Resize Args {}, // Resizing Logic - resize_choose_qparams_per_token_output)); + nullptr)); +} + +void add_choose_qparams_block_wise_node( + ComputeGraph& graph, + ValueRef 
input, + ValueRef block_size, + int mapping_type, // 0 / 1 / 2 + ValueRef quant_min, + ValueRef quant_max, + ValueRef eps, + ValueRef scale_out, + ValueRef zp_out) { + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // For shader compatibility, we still need to convert to WHCN order + // but the output shape calculation is now handled correctly in resize + // function + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) + utils::ivec4 num_blocks_vec = { + (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], + (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], + (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], + (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + int qmin = static_cast(graph.get_int(quant_min)); + int qmax = static_cast(graph.get_int(quant_max)); + float eps_val = static_cast(graph.get_double(eps)); + + // Create push constants vector + std::vector push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&mapping_type, sizeof(int)), + PushConstantDataInfo(&qmin, sizeof(int)), + PushConstantDataInfo(&qmax, sizeof(int)), + PushConstantDataInfo(&eps_val, sizeof(float))}; + + std::string kernel_name("choose_qparams_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zp_out), + graph.strides_ubo(zp_out)}; + } else { + // For texture input, the shader uses buffer storage for outputs + // so we need buffer UBOs for the output tensors + param_ubos = { + graph.logical_limits_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zp_out), + graph.strides_ubo(zp_out)}; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_block_wise_pick_global_wg_size, + choose_qparams_block_wise_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zp_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {block_size}, + // Resizing Logic + nullptr)); } void choose_qparams_tensor_impl( @@ -278,9 +384,8 @@ void choose_qparams_tensor_impl( const ValueRef input = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) - const ValueRef dtype = - args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef eps = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; const 
ValueRef out_tuple_ref = args[arg_idx++]; ValueRef scale_out = kDummyValueRef; @@ -301,17 +406,11 @@ void choose_qparams_tensor_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -327,8 +426,7 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef dtype = - args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef dtype = args[arg_idx++]; const ValueRef out_tuple_ref = args[arg_idx++]; ValueRef scale_out = kDummyValueRef; @@ -349,17 +447,16 @@ void choose_qparams_per_token_asymmetric_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); @@ -370,9 +467,8 @@ void choose_qparams_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef mapping_type = args[arg_idx++]; + const ValueRef block_size = args[arg_idx++]; const ValueRef target_dtype = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; @@ -382,7 +478,6 @@ void choose_qparams_affine_impl( const ValueRef out_tuple_ref = args[arg_idx++]; // Suppress unused variable warnings - (void)mapping_type; (void)target_dtype; (void)scale_dtype; (void)zero_point_dtype; @@ -402,36 +497,42 @@ void choose_qparams_affine_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output 
 VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
- VK_CHECK_COND(
- graph.dtype_of(zero_point_out) == vkapi::kInt ||
- graph.dtype_of(zero_point_out) == vkapi::kFloat);
+ VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
+
+ // Check that texture storage is width packed
+ if (!graph.is_buffer_storage(input)) {
+ VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
+ }

- // Check if this is per-tensor quantization (only supported granularity)
- // block_size should equal input tensor dimensions for per-tensor quantization
 const auto input_sizes = graph.sizes_of(input);
 const auto block_size_list = graph.get_int_list(block_size);
 VK_CHECK_COND(block_size_list->size() == input_sizes.size());
- for (size_t i = 0; i < input_sizes.size(); i++) {
- VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
- }

- // Check that texture storage is width packed
- if (!graph.is_buffer_storage(input)) {
- VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
+ std::string mapping_type_str = graph.get_string(mapping_type);
+ int mapping_type_val = 0; // Default to ASYMMETRIC
+
+ if (mapping_type_str == "ASYMMETRIC") {
+ mapping_type_val = 0;
+ } else if (mapping_type_str == "SYMMETRIC") {
+ mapping_type_val = 1;
+ } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") {
+ mapping_type_val = 2;
 }

- // Default to per-tensor quantization parameter calculation for TorchAO affine
- // ops
- add_choose_qparams_tensor_node(
- graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
+ add_choose_qparams_block_wise_node(
+ graph,
+ input,
+ block_size,
+ mapping_type_val,
+ quant_min,
+ quant_max,
+ eps,
+ scale_out,
+ zero_point_out);
 }

 REGISTER_OPERATORS {
diff --git a/backends/vulkan/test/op_tests/quantize_affine_test.cpp b/backends/vulkan/test/op_tests/quantize_affine_test.cpp
index 8a54774d703..d2a971da82b 100644
--- a/backends/vulkan/test/op_tests/quantize_affine_test.cpp
+++ b/backends/vulkan/test/op_tests/quantize_affine_test.cpp
@@ -279,6 +279,134 @@ at::Tensor dequantize_affine_reference_impl(
 std::string("INT"));
 }

+std::tuple<at::Tensor, at::Tensor> choose_qparams_affine_reference_impl(
+ const at::Tensor& input_,
+ const std::string& mapping_type,
+ const std::vector<int64_t>& block_size,
+ int64_t quant_min,
+ int64_t quant_max,
+ double eps) {
+ const int64_t ndim = input_.dim();
+ _check_dims("input", block_size.size(), ndim);
+
+ VK_CHECK_COND(
+ input_.scalar_type() == at::kFloat || input_.scalar_type() == at::kHalf ||
+ input_.scalar_type() == at::kBFloat16,
+ "Unsupported input dtype: ",
+ input_.dtype());
+
+ at::Tensor input = input_.contiguous();
+
+ std::vector<int64_t> shape_for_reduction;
+ std::vector<int64_t> reduction_dims;
+ int64_t cur_dim = 0;
+
+ auto in_sizes = input.sizes();
+ for (int64_t i = 0; i < ndim; ++i) {
+ const int64_t blk = block_size[i];
+ const int64_t dim = in_sizes[i];
+
+ if (blk != dim && blk > 1) {
+ VK_CHECK_COND(
+ dim % blk == 0,
+ "Input size ",
+ dim,
+ " is not divisible by block_size ",
+ blk,
+ " at dimension ",
+ i);
+ shape_for_reduction.push_back(dim / blk);
+ shape_for_reduction.push_back(blk);
+ reduction_dims.push_back(cur_dim + 1);
+ cur_dim += 2;
+ } else {
+ shape_for_reduction.push_back(dim);
+ if (blk != 1) {
+ reduction_dims.push_back(cur_dim);
+ }
+ cur_dim += 1;
+ }
+ }
+
+ at::Tensor input_reshaped = input.view(shape_for_reduction);
+
+ std::vector<int64_t> shape_after_reduction = shape_for_reduction;
+ for (int64_t d : reduction_dims) {
+ shape_after_reduction[d] = 1;
+ }
+
+ at::Tensor min_val = input_reshaped.amin(reduction_dims,
/*keepdim=*/true); + at::Tensor max_val = input_reshaped.amax(reduction_dims, /*keepdim=*/true); + + at::Tensor scale, zero_point; + + if (mapping_type == "ASYMMETRIC") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate scale + scale = (max_val - min_val) / (quant_max - quant_min); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point + zero_point = at::round(quant_min - min_val / scale); + zero_point = at::clamp(zero_point, quant_min, quant_max); + } else if (mapping_type == "SYMMETRIC") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate max absolute value + at::Tensor abs_min = at::abs(min_val); + at::Tensor abs_max = at::abs(max_val); + at::Tensor M = at::maximum(abs_min, abs_max); + + // Calculate scale + scale = M / ((quant_max - quant_min) * 0.5); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point (mid-point) + zero_point = + at::full_like(scale, (quant_max + quant_min + 1) / 2, at::kInt); + } else if (mapping_type == "SYMMETRIC_NO_CLIPPING_ERR") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate scale based on min/max values + at::Tensor s_min = at::abs(min_val) / std::abs(quant_min); + at::Tensor s_max = max_val / quant_max; + scale = at::maximum(s_min, s_max); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point (mid-point) + zero_point = + at::full_like(scale, (quant_max + quant_min + 1) / 2, at::kInt); + } else { + VK_CHECK_COND( + false, + "Unsupported mapping_type: ", + mapping_type, + ". 
Expected ASYMMETRIC, SYMMETRIC, or SYMMETRIC_NO_CLIPPING_ERR");
+ }
+
+ std::vector<int64_t> output_shape;
+ for (size_t i = 0; i < shape_after_reduction.size(); ++i) {
+ if (shape_after_reduction[i] != 1 ||
+ std::find(reduction_dims.begin(), reduction_dims.end(), i) ==
+ reduction_dims.end()) {
+ output_shape.push_back(shape_after_reduction[i]);
+ }
+ }
+
+ // Reshape scale and zero_point to final output shape
+ scale = scale.view(output_shape);
+ zero_point = zero_point.view(output_shape);
+
+ return std::make_tuple(scale, zero_point);
+}
+
 void test_vulkan_quantize_affine_impl(
 const std::vector<int>& input_sizes,
 const std::vector<int64_t>& block_size,
@@ -857,3 +985,395 @@ TEST(VulkanDequantizeAffineTest, test_4d_dequantization) {
 at::kChar, // input dtype
 at::kFloat); // output dtype
 }
+
+void test_vulkan_choose_qparams_affine_impl(
+ const std::vector<int>& input_sizes,
+ const std::vector<int64_t>& block_size,
+ const std::string& mapping_type,
+ int64_t quant_min,
+ int64_t quant_max,
+ double eps,
+ at::ScalarType in_dtype = at::kFloat,
+ const vkcompute::utils::StorageType in_storage =
+ vkcompute::utils::kTexture3D,
+ const vkcompute::utils::StorageType out_storage =
+ vkcompute::utils::kBuffer) {
+ // Create input tensor with random values
+ std::vector<int64_t> input_sizes_int64(
+ input_sizes.begin(), input_sizes.end());
+ at::Tensor input =
+ at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
+
+ // Get reference output
+ auto reference_out = choose_qparams_affine_reference_impl(
+ input, mapping_type, block_size, quant_min, quant_max, eps);
+
+ at::Tensor reference_scale = std::get<0>(reference_out);
+ at::Tensor reference_zero_point = std::get<1>(reference_out);
+
+ reference_zero_point = reference_zero_point.to(at::kInt);
+
+ using namespace vkcompute;
+
+ GraphConfig config;
+ config.set_storage_type_override(in_storage);
+ ComputeGraph graph(config);
+
+ IOValueRef r_input = graph.add_input_tensor(
+ input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
+
+ // Create mapping_type as string
+ std::string mapping_type_copy = mapping_type;
+ const ValueRef r_mapping_type =
+ graph.add_string(std::move(mapping_type_copy));
+
+ // Create block_size as IntList
+ std::vector<int64_t> block_size_copy(block_size);
+ const ValueRef r_block_size =
+ graph.add_scalar_list(std::move(block_size_copy));
+
+ // Create target_dtype, quant_min, quant_max, eps
+ const ValueRef r_target_dtype =
+ graph.add_scalar(static_cast<int64_t>(at::kChar));
+ const ValueRef r_quant_min = graph.add_scalar(quant_min);
+ const ValueRef r_quant_max = graph.add_scalar(quant_max);
+ const ValueRef r_eps = graph.add_scalar(eps);
+
+ // Create scale_dtype and zero_point_dtype
+ const ValueRef r_scale_dtype =
+ graph.add_scalar(static_cast<int64_t>(at::kFloat));
+ const ValueRef r_zero_point_dtype =
+ graph.add_scalar(static_cast<int64_t>(at::kInt));
+
+ // Create output tuple
+ std::vector<ValueRef> out_tuple;
+
+ // Create scale and zero_point output tensors
+ const ValueRef r_scale_out = graph.add_tensor(
+ reference_scale.sizes().vec(), vkapi::kFloat, out_storage);
+ const ValueRef r_zero_point_out = graph.add_tensor(
+ reference_zero_point.sizes().vec(), vkapi::kInt, out_storage);
+
+ out_tuple.push_back(r_scale_out);
+ out_tuple.push_back(r_zero_point_out);
+
+ const ValueRef r_out_tuple = graph.add_value_list(std::move(out_tuple));
+
+ VK_GET_OP_FN("torchao.choose_qparams_affine.default")
+ (graph,
+ {
+ r_input.value,
+ r_mapping_type,
+ r_block_size,
+ r_target_dtype,
+ r_quant_min,
+ r_quant_max,
+ r_eps,
+ r_scale_dtype,
+ r_zero_point_dtype,
+ r_out_tuple,
+ });
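+
+ // For example, with input_sizes = {8, 6} and block_size = {2, 3} (as in the
+ // 2D test below), the reference impl reshapes the input to {4, 2, 2, 3} and
+ // reduces over dims {1, 3}, so r_scale_out and r_zero_point_out above are
+ // created with shape {4, 2}: one (scale, zero_point) pair per 2x3 block.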
+
+ ValueRef staging_scale = graph.set_output_tensor(r_scale_out);
+ ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point_out);
+
+ graph.prepare();
+ graph.prepack();
+ graph.encode_execute();
+
+ // Copy input data to GPU
+ graph.copy_into_staging(
+ r_input.staging, input.const_data_ptr(), input.numel());
+
+ // Execute the graph
+ graph.execute();
+
+ // Copy output data back to CPU
+ at::Tensor vk_scale = at::empty_like(reference_scale).contiguous();
+ at::Tensor vk_zero_point = at::empty_like(reference_zero_point).contiguous();
+
+ graph.copy_from_staging(
+ staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
+ graph.copy_from_staging(
+ staging_zero_point,
+ vk_zero_point.mutable_data_ptr(),
+ vk_zero_point.numel());
+
+ // Compare outputs
+ const bool scale_correct =
+ at::allclose(reference_scale, vk_scale, /*rtol=*/1e-3, /*atol=*/1e-3);
+
+ // For zero point, we need to compare as integers since the zero point should
+ // be an integer. First convert both tensors to int if they aren't already.
+ at::Tensor ref_zp_int = reference_zero_point.to(at::kInt);
+ at::Tensor vk_zp_int = vk_zero_point.to(at::kInt);
+ const bool zero_point_correct = at::equal(ref_zp_int, vk_zp_int);
+
+ if (!scale_correct || !zero_point_correct) {
+ std::cout << "\nFailed with parameters:" << std::endl;
+ std::cout << " input_sizes: [";
+ for (size_t i = 0; i < input_sizes.size(); i++) {
+ std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : "");
+ }
+ std::cout << "]" << std::endl;
+ std::cout << " block_size: [";
+ for (size_t i = 0; i < block_size.size(); i++) {
+ std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : "");
+ }
+ std::cout << "]" << std::endl;
+ std::cout << " mapping_type: " << mapping_type << std::endl;
+ std::cout << " quant_min: " << quant_min << std::endl;
+ std::cout << " quant_max: " << quant_max << std::endl;
+ std::cout << " eps: " << eps << std::endl;
+ std::cout << " storage type: "
+ << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture")
+ << std::endl;
+
+ if (!scale_correct || !zero_point_correct) {
+ std::cout << "input:" << std::endl;
+ std::cout << input << std::endl;
+
+ std::cout << "reference_scale:" << std::endl
+ << reference_scale << std::endl;
+ std::cout << "vulkan_scale:" << std::endl << vk_scale << std::endl;
+
+ std::cout << "reference_zero_point:" << std::endl
+ << reference_zero_point << std::endl;
+ std::cout << "vulkan_zero_point:" << std::endl
+ << vk_zero_point << std::endl;
+ }
+ }
+
+ ASSERT_TRUE(scale_correct);
+ ASSERT_TRUE(zero_point_correct);
+}
+
+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_choose_qparams_affine(
+ const std::vector<int>& input_sizes,
+ const std::vector<int64_t>& block_size,
+ const std::string& mapping_type,
+ int64_t quant_min,
+ int64_t quant_max,
+ double eps,
+ at::ScalarType in_dtype = at::kFloat) {
+ // Test with buffer storage for both input and output
+ test_vulkan_choose_qparams_affine_impl(
+ input_sizes,
+ block_size,
+ mapping_type,
+ quant_min,
+ quant_max,
+ eps,
+ in_dtype,
+ vkcompute::utils::kBuffer,
+ vkcompute::utils::kBuffer);
+
+ // Test with texture storage for input and buffer storage for output
+ // (shader always uses buffer storage for outputs)
+ test_vulkan_choose_qparams_affine_impl(
+ input_sizes,
+ block_size,
+ mapping_type,
+ quant_min,
+ quant_max,
+ eps,
+ in_dtype,
+ vkcompute::utils::kTexture3D,
+ vkcompute::utils::kBuffer);
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_1d_asymmetric) {
+ // 1D: 12 Tensor, block_size is 3
+ test_vulkan_choose_qparams_affine(
+ {12}, // input_sizes
+ {3}, // block_size
+ "ASYMMETRIC", // mapping_type
+ -128, // quant_min (char min)
+ 127, // quant_max (char max)
+ 1e-5, // eps
+ at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_2d_symmetric) {
+ // 2D: 8x6 Tensor, block_size is 2x3
+ test_vulkan_choose_qparams_affine(
+ {8, 6}, // input_sizes
+ {2, 3}, // block_size
+ "SYMMETRIC", // mapping_type
+ -128, // quant_min (char min)
+ 127, // quant_max (char max)
+ 1e-5, // eps
+ at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_3d_symmetric_no_clipping) {
+ // 3D: 6x4x6 Tensor, block_size is 3x2x2
+ test_vulkan_choose_qparams_affine(
+ {6, 4, 6}, // input_sizes
+ {3, 2, 2}, // block_size
+ "SYMMETRIC_NO_CLIPPING_ERR", // mapping_type
+ -128, // quant_min (char min)
+ 127, // quant_max (char max)
+ 1e-5, // eps
+ at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_4d_asymmetric) {
+ // 4D: 4x6x6x6 Tensor, block_size is 2x3x2x3
+ test_vulkan_choose_qparams_affine(
+ {4, 6, 6, 6}, // input_sizes (reduced from 8 to 4 to make test faster)
+ {2, 3, 2, 3}, // block_size
+ "ASYMMETRIC", // mapping_type
+ -128, // quant_min (char min)
+ 127, // quant_max (char max)
+ 1e-5, // eps
+ at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_per_tensor) {
+ // Per-tensor: block_size equals tensor size
+ test_vulkan_choose_qparams_affine(
+ {4, 6, 8}, // input_sizes
+ {4, 6, 8}, // block_size equals tensor size
+ "ASYMMETRIC", // mapping_type
+ -128, // quant_min (char min)
+ 127, // quant_max (char max)
+ 1e-5, // eps
+ at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_per_token) {
+ // Per-token: block_size is all 1s except last dimension
+ test_vulkan_choose_qparams_affine(
+ {4, 6, 8}, // input_sizes
+ {1, 1, 8}, // block_size is all 1s except last dimension
+ "ASYMMETRIC", // mapping_type
+ -128, // quant_min (char min)
+ 127, // quant_max (char max)
+
1e-5, // eps + at::kFloat); // input dtype +} + +// Additional tests for choose_qparams_affine + +TEST(VulkanChooseQParamsAffineTest, test_uint8_range) { + // Test with uint8 range (0-255) + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "ASYMMETRIC", // mapping_type + 0, // quant_min (uint8 min) + 255, // quant_max (uint8 max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_int16_range) { + // Test with int16 range (-32768 to 32767) + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "SYMMETRIC", // mapping_type + -32768, // quant_min (int16 min) + 32767, // quant_max (int16 max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_larger_eps) { + // Test with larger epsilon value + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "ASYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-2, // larger eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_per_channel_first_dim) { + // Per-channel quantization on first dimension + test_vulkan_choose_qparams_affine( + {8, 6, 4}, // input_sizes + {1, 6, 4}, // block_size (per-channel on dim 0) + "SYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_per_channel_middle_dim) { + // Per-channel quantization on middle dimension + test_vulkan_choose_qparams_affine( + {4, 8, 6}, // input_sizes + {4, 1, 6}, // block_size (per-channel on dim 1) + "SYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_mixed_block_sizes) { + // Mixed block sizes (some dimensions fully quantized, some partially) + test_vulkan_choose_qparams_affine( + {8, 6, 10}, // input_sizes + {4, 6, 2}, // block_size (mixed: partial, full, partial) + "ASYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_small_tensor) { + // Test with a small tensor + test_vulkan_choose_qparams_affine( + {2, 3}, // small input_sizes + {2, 3}, // block_size (full tensor) + "ASYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_asymmetric_narrow_range) { + // Test with a narrow quantization range + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "ASYMMETRIC", // mapping_type + -10, // quant_min (narrow range) + 10, // quant_max (narrow range) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_symmetric_narrow_range) { + // Test with a narrow quantization range with symmetric mapping + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "SYMMETRIC", // mapping_type + -10, // quant_min (narrow range) + 10, // quant_max (narrow range) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_symmetric_no_clipping_narrow_range) { + // Test with a narrow quantization range with symmetric no clipping mapping + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "SYMMETRIC_NO_CLIPPING_ERR", // mapping_type + -10, // quant_min (narrow range) + 10, // quant_max (narrow range) + 
1e-5, // eps + at::kFloat); // input dtype +} \ No newline at end of file
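
For reference, the per-block scale and zero_point arithmetic exercised by choose_qparams_affine_reference_impl above can be reproduced with plain scalar math. Below is a minimal standalone C++ sketch for a single block, using assumed block statistics and an int8 range; it mirrors only the ASYMMETRIC and SYMMETRIC branches and is not part of the patch.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // Assume one block whose values span [-1.0, 2.0], quantized to int8.
  const double min_v = -1.0, max_v = 2.0, eps = 1e-5;
  const int quant_min = -128, quant_max = 127;

  // ASYMMETRIC: extend the range to include zero, then scale = range / levels
  // and zero_point = round(quant_min - min / scale), clamped to [qmin, qmax].
  const double a_min = std::min(min_v, 0.0);
  const double a_max = std::max(max_v, 0.0);
  const double a_scale =
      std::max((a_max - a_min) / (quant_max - quant_min), eps);
  const double a_zp = std::clamp(
      std::round(quant_min - a_min / a_scale),
      static_cast<double>(quant_min),
      static_cast<double>(quant_max));

  // SYMMETRIC: scale covers max(|min|, |max|) on each side of the mid-point
  // zero_point (quant_max + quant_min + 1) / 2.
  const double m = std::max(std::abs(a_min), std::abs(a_max));
  const double s_scale = std::max(m / ((quant_max - quant_min) * 0.5), eps);
  const int s_zp = (quant_max + quant_min + 1) / 2;

  // Expected output: ASYMMETRIC scale ~= 0.011765, zp = -43;
  //                  SYMMETRIC scale ~= 0.015686, zp = 0.
  std::printf(
      "ASYMMETRIC scale=%f zp=%d | SYMMETRIC scale=%f zp=%d\n",
      a_scale, static_cast<int>(a_zp), s_scale, s_zp);
  return 0;
}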