diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index 805a5c1f744..aa4829d9c90 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -210,6 +210,278 @@ def fuse_into_linear_qcnw_node( graph_module.graph.erase_node(dq_weight_node) +######################### +## linear_qta8a_qga4w ## +######################### + + +def _is_dequantize_affine_node(node: torch.fx.Node) -> bool: + """Check if a node is a dequantize_affine operation.""" + return ( + node.op == "call_function" + and node.target is not None + and hasattr(node.target, "__name__") + and "dequantize_affine" in getattr(node.target, "__name__", "") + ) + + +def _is_view_copy_node(node: torch.fx.Node) -> bool: + """Check if a node is a view_copy operation.""" + return ( + node.op == "call_function" + and node.target is not None + and hasattr(node.target, "__name__") + and "view_copy" in getattr(node.target, "__name__", "") + ) + + +def _validate_qta8a_qga4w_nodes( + input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument +) -> Optional[torch.fx.Node]: + """ + Validate input and weight nodes for QTA8A_QGA4W pattern. + Returns the actual input node (after handling view operations) or None if invalid. + """ + # Type checking - ensure we have torch.fx.Node objects + if not isinstance(weight_node, torch.fx.Node) or not isinstance( + input_node, torch.fx.Node + ): + return None + + # Input may be preprocessed with a view node + actual_input_node = input_node + if _is_view_copy_node(input_node): + actual_input_node = input_node.args[0] + if not isinstance(actual_input_node, torch.fx.Node): + return None + + # Check if input is dequantized with dequantize_affine (from dynamic quantization) + if not _is_dequantize_affine_node(actual_input_node): + return None + + # Check if weight is dequantized with dequantize_affine + if not _is_dequantize_affine_node(weight_node): + return None + + return actual_input_node + + +def _extract_weight_params( + program: ExportedProgram, weight_node: torch.fx.Node +) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]: + """Extract and validate weight parameters from dequantize_affine node.""" + # Get the original quantized weight and quantization parameters + if len(weight_node.args) < 4: + return None + + orig_weight = weight_node.args[0] + weight_scales = weight_node.args[2] + weight_zeros = weight_node.args[3] + + # Type checking + if not isinstance(orig_weight, torch.fx.Node) or not is_param_node( + program, orig_weight + ): + return None + if not isinstance(weight_scales, torch.fx.Node) or not is_param_node( + program, weight_scales + ): + return None + if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node( + program, weight_zeros + ): + return None + + return orig_weight, weight_scales, weight_zeros + + +def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool: + """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range).""" + quant_min = weight_tensor.min().item() + quant_max = weight_tensor.max().item() + return quant_min >= -8 and quant_max <= 7 + + +def _calculate_group_size( + orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor +) -> Optional[int]: + """Calculate and validate group size from weight and scales tensors.""" + out_features, in_features = orig_weight_tensor.shape + + if len(weight_scales_tensor.shape) != 2: + return None + + scales_out_features, num_groups = weight_scales_tensor.shape + + 
if scales_out_features != out_features: + return None + + group_size = in_features // num_groups + if in_features % group_size != 0: + return None + + return group_size + + +def matches_linear_qta8a_qga4w_pattern( + program: ExportedProgram, node: torch.fx.Node +) -> Optional[Tuple[int, int]]: + """ + Checks if the nodes surrounding a linear node matches the pattern for dynamic + activation + grouped weight quantized linear (QTA8A_QGA4W). + + This pattern involves: + 1. Dynamic quantization of input activations (8-bit) + 2. Grouped quantization of weights (4-bit with group size) + + The expected pattern from Int8DynActInt4WeightQuantizer is: + scale, zero_point = choose_qparams_affine(input) + quantized_input = quantize_affine(input, scale, zero_point) + dequantized_input = dequantize_affine(quantized_input, ...) + dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros) + output = linear(dequantized_input, dequantized_weight) + + If the pattern matches, return (group_size, weight_bits), otherwise None. + """ + if not utils.is_linear_node(node): + return None + + input_node = node.args[0] + weight_node = node.args[1] + + # Validate nodes and get actual input node + actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node) + if actual_input_node is None: + return None + + # Extract weight parameters + if not isinstance(weight_node, torch.fx.Node): + return None + weight_params = _extract_weight_params(program, weight_node) + if weight_params is None: + return None + + orig_weight, weight_scales, weight_zeros = weight_params + + # Get tensors to analyze the quantization scheme + orig_weight_tensor = get_param_tensor(program, orig_weight) + weight_scales_tensor = get_param_tensor(program, weight_scales) + weight_zeros_tensor = get_param_tensor(program, weight_zeros) + + if not isinstance(orig_weight_tensor, torch.Tensor): + return None + if not isinstance(weight_scales_tensor, torch.Tensor): + return None + if not isinstance(weight_zeros_tensor, torch.Tensor): + return None + + # Check if weight is quantized to 4 bits + if not _validate_4bit_quantization(orig_weight_tensor): + return None + + # Calculate group size + group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor) + if group_size is None: + return None + + # Verify this is 4-bit grouped quantization + weight_bits = 4 + + return group_size, weight_bits + + +def fuse_into_linear_qta8a_qga4w_node( + program: ExportedProgram, + graph_module: torch.fx.GraphModule, + linear_node: torch.fx.Node, + group_size: int, + weight_bits: int, +) -> None: + """ + Fuse the dynamic activation + grouped weight quantized linear pattern into + a single linear_qta8a_qga4w operator. + + The pattern: + dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...) + dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...) 
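For reference, the group-size check above is plain shape arithmetic. The standalone sketch below (illustrative shapes, not taken from this diff) shows how group_size falls out of the weight and per-group scale shapes:

    import torch

    # Hypothetical 4-bit quantized weight: [out_features, in_features]
    orig_weight = torch.randint(-8, 8, (128, 512), dtype=torch.int8)
    # Hypothetical per-group scales: [out_features, num_groups]
    weight_scales = torch.rand(128, 4)

    out_features, in_features = orig_weight.shape          # 128, 512
    scales_out_features, num_groups = weight_scales.shape  # 128, 4
    assert scales_out_features == out_features

    group_size = in_features // num_groups                 # 512 // 4 = 128
    assert in_features % group_size == 0                   # groups must tile in_features evenly
    print(group_size)                                      # -> 128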
+ output = linear(dequantized_input, dequantized_weight) + + Becomes: + output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point, + weight, group_size, weight_scales, weight_zeros) + """ + dq_input_node = linear_node.args[0] + dq_weight_node = linear_node.args[1] + + assert isinstance(dq_input_node, torch.fx.Node) + + input_view_node = None + # Input may be preprocessed with a view node + if ( + dq_input_node.op == "call_function" + and dq_input_node.target is not None + and hasattr(dq_input_node.target, "__name__") + and "view_copy" in getattr(dq_input_node.target, "__name__", "") + ): + input_view_node = dq_input_node + dq_input_node = dq_input_node.args[0] + assert isinstance(dq_input_node, torch.fx.Node) + + assert isinstance(dq_input_node, torch.fx.Node) + assert isinstance(dq_weight_node, torch.fx.Node) + + # Get the quantized input and quantization parameters from the input dequantize_affine node + # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype) + quantized_input = dq_input_node.args[0] + input_scale = dq_input_node.args[2] # scale is the 3rd argument + input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None + + # Get the weight and its quantization parameters from dequantize_affine + # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype) + orig_weight = dq_weight_node.args[0] + weight_scales = dq_weight_node.args[2] + weight_zeros = dq_weight_node.args[3] + + # Pack the 4-bit weight tensor for efficient storage + assert isinstance(orig_weight, torch.fx.Node) + orig_weight_tensor = get_param_tensor(program, orig_weight) + assert isinstance(orig_weight_tensor, torch.Tensor) + packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor) + utils.update_program_state_dict( + program, + orig_weight.name, + packed_weight_tensor, + ) + # Update the metadata to reflect the new packed shape + orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8) + + # Create the linear_qta8a_qga4w node + with graph_module.graph.inserting_before(linear_node): + linear_qta8a_qga4w_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.linear_qta8a_qga4w.default, + ( + quantized_input, # quantized input (int8) + input_scale, # mat1_scale + input_zero_point, # mat1_zero_point + orig_weight, # mat2_data (packed 4-bit weights) + group_size, # group_size (int) + weight_scales, # weight_scales + weight_zeros, # weight_zeros + ), + ) + + # Replace the linear node with the new fused node + linear_node.replace_all_uses_with(linear_qta8a_qga4w_node) + + # Erase nodes in the correct order (users first, then dependencies) + graph_module.graph.erase_node(linear_node) + if input_view_node is not None: + graph_module.graph.erase_node(input_view_node) + graph_module.graph.erase_node(dq_weight_node) + graph_module.graph.erase_node(dq_input_node) + + class FuseQuantizedOpsTransform(ExportPass): def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() @@ -217,12 +489,23 @@ def __init__(self, exported_program: ExportedProgram) -> None: def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in graph_module.graph.nodes: + # Check for linear_qcnw pattern (weight-only quantization) qcnw_details = matches_linear_qcnw_pattern(self.program, node) if qcnw_details is not None: qcnw_method, qcnw_nbits = qcnw_details fuse_into_linear_qcnw_node( self.program, graph_module, node, qcnw_method, qcnw_nbits ) + 
continue + + # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization) + qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node) + if qta8a_qga4w_details is not None: + group_size, weight_bits = qta8a_qga4w_details + fuse_into_linear_qta8a_qga4w_node( + self.program, graph_module, node, group_size, weight_bits + ) + continue graph_module.recompile() dead_code_elimination_pass(graph_module) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index af6fcbfbb14..c9b884e5b86 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -231,6 +231,95 @@ def linear_qcs4w( lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd") linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) +######################## +## linear_qta8a_qga4w ## +######################## + + +def linear_qta8a_qga4w( + x_quantized: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weights_4bit: torch.Tensor, + group_size: int, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, +): + """ + Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W). + + Args: + x_quantized: Already quantized input tensor (int8, per-token quantized) + input_scale: Scale for per-token quantization of input (shape: [batch_size]) + input_zero_point: Zero point for per-token quantization of input (shape: [batch_size]) + weights_4bit: Packed 4-bit quantized weights + group_size: Group size for weight quantization (int) + weight_scales: Per-group scales for weights + weight_zeros: Per-group zero points for weights + """ + original_x_shape = x_quantized.shape + feature_dim = original_x_shape[-1] + + # Reshape for processing + x_quantized_2d = x_quantized.reshape(-1, feature_dim) + + # Unpack 4-bit weights + unpacked_weights_shape = weights_4bit.shape + out_features = unpacked_weights_shape[0] + in_features = unpacked_weights_shape[1] + + weights_unpacked = torch.empty( + (out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device + ) + + weights_unpacked[:, ::2] = weights_4bit >> 4 + weights_unpacked[:, 1::2] = weights_4bit & 0x0F + + # Convert to signed 4-bit range [-8, 7] + weights_unpacked = torch.where( + weights_unpacked > 7, weights_unpacked - 16, weights_unpacked + ) + + # Dequantize weights using grouped quantization + actual_in_features = in_features * 2 + num_groups = actual_in_features // group_size + + # Reshape weights for grouped dequantization + weights_grouped = weights_unpacked.view(out_features, num_groups, group_size) + + # Expand scales and zeros to match grouped weights + scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size) + zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size) + + # Dequantize: (quantized - zero_point) * scale + dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded + dq_weights = dq_weights_grouped.view(out_features, actual_in_features) + + # Dequantize input (per-token) + # For per-token quantization, each token (row) has its own scale and zero_point + x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token( + x_quantized_2d, + input_scale, + input_zero_point, + -128, + 127, + torch.int8, + torch.float32, + ) + + # Perform linear operation + out = torch.nn.functional.linear(x_dequantized, dq_weights) + out_shape = original_x_shape[:-1] + (out_features,) + return out.reshape(out_shape) + + +name = "linear_qta8a_qga4w" +lib.define( + f"{name}(Tensor self, 
Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor" +) +lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd") +linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 19594002cf2..33ed3150535 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -245,9 +245,9 @@ def register_ephemeral_op(features: OpFeatures): @update_features( [ - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, @@ -276,14 +276,32 @@ def register_quantization_op(features: OpFeatures): [ exir_ops.edge.torchao.quantize_affine.default, exir_ops.edge.torchao.dequantize_affine.default, + ] +) +def register_affine_quantization_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_axis_map=False, + valid_packed_dims={PackedDim.WIDTH}, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED + features.handles_own_prepacking = True + + return features + + +@update_features( + [ exir_ops.edge.torchao.choose_qparams_affine.default, ] ) -def register_torchao_quantization_op(features: OpFeatures): - # TorchAO quantization operators - default to per-tensor behavior - # Same features as standard quantization ops +def register_choose_qparams_affine_op(features: OpFeatures): + # Currently only created a rudimentary buffer implementation for choose_qparams_affine + # since the reduction logic for blocks in texture3d is not trivial to implement in vulkan. features.texture_impl = TextureImplFeatures( - uses_axis_map=True, + uses_axis_map=False, valid_packed_dims={ PackedDim.WIDTH, }, @@ -292,37 +310,6 @@ def register_torchao_quantization_op(features: OpFeatures): features.resize_fn = True features.optimal_storage = VkStorageType.BUFFER - def check_torchao_quantization_node(node: torch.fx.Node) -> bool: - # Only per-tensor quantization is supported by the Vulkan backend. 
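The nibble layout implied by the unpacking code above (even source column in the high nibble, odd column in the low nibble) can be sanity-checked in isolation. pack_4bit_weight_tensor itself is imported from elsewhere in the backend and is not shown in this diff, so the packer below is only a stand-in that matches the unpack logic; keeping the packed buffer unsigned (uint8) is what lets the >> 4 shift return the raw high nibble.

    import torch

    def pack_4bit(w: torch.Tensor) -> torch.Tensor:
        # w: int8 values in [-8, 7], shape [out_features, in_features], in_features even
        nibbles = (w.to(torch.int16) & 0xF).to(torch.uint8)  # two's-complement nibble, [0, 15]
        hi = nibbles[:, ::2]   # even columns -> high nibble
        lo = nibbles[:, 1::2]  # odd columns  -> low nibble
        return (hi << 4) | lo  # shape [out_features, in_features // 2], uint8

    def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
        out_features, half = packed.shape
        w = torch.empty((out_features, half * 2), dtype=torch.int8)
        w[:, ::2] = (packed >> 4).to(torch.int8)
        w[:, 1::2] = (packed & 0x0F).to(torch.int8)
        return torch.where(w > 7, w - 16, w)  # re-center [0, 15] -> [-8, 7]

    w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
    assert torch.equal(unpack_4bit(pack_4bit(w)), w)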
- if len(node.args) < 2: - return False - - block_size = node.args[1] - - if not isinstance(block_size, (list, tuple)): - return False - - input_arg = node.args[0] - if not isinstance(input_arg, torch.fx.Node): - return False - - input_tensor = input_arg.meta.get("val", None) - if not isinstance(input_tensor, FakeTensor): - return False - - input_shape = list(input_tensor.shape) - - if len(block_size) != len(input_shape): - return False - - # Check if block_size matches input_shape exactly (per-tensor quantization) - for i in range(len(block_size)): - if block_size[i] != input_shape[i]: - return False - - return True - - features.check_node_fn = check_torchao_quantization_node return features @@ -487,7 +474,12 @@ def register_int8_mm_op(features: OpFeatures): return features -@update_features(exir_ops.edge.et_vk.linear_weight_int4.default) +@update_features( + [ + exir_ops.edge.et_vk.linear_weight_int4.default, + exir_ops.edge.et_vk.linear_qta8a_qga4w.default, + ] +) def register_int4_mm_op(features: OpFeatures): features.buffer_impl = True features.texture_impl = TextureImplFeatures( diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index 6e11c36bfb0..40212c35c27 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -14,11 +14,12 @@ import torch from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import ( _convert_scalars_to_attrs, + bits_to_range, OP_TO_ANNOTATOR, propagate_annotation, ) from torch.fx import Node -from torchao.quantization.pt2e import PerChannelMinMaxObserver +from torchao.quantization.pt2e import PerChannelMinMaxObserver, PlaceholderObserver from torchao.quantization.pt2e.quantizer import ( QuantizationConfig, QuantizationSpec, @@ -28,50 +29,86 @@ __all__ = [ "VulkanQuantizer", - "get_linear_weight_qcs_qspec", - "get_linear_weight_only_qcs_xnn_qconfig", + "get_symmetric_quantization_config", ] -def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec: +@functools.lru_cache +def get_symmetric_quantization_config( + is_dynamic: bool = False, + weight_bits: int = 8, + act_bits: int = 8, + act_qmin: Optional[int] = None, + act_qmax: Optional[int] = None, + weight_qmin: Optional[int] = None, + weight_qmax: Optional[int] = None, +) -> QuantizationConfig: """ - Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization - of weight tensors of linear layers to the number of bits specified by quant_bits. + Return a QuantizationConfig for Vulkan quantizer. + + Args: + is_dynamic: If False, weight-only quantization. 
If True, dynamic quantization (activation + weight) + weight_bits: Number of bits for weight quantization (4 or 8) + act_bits: Number of bits for activation quantization (8) + act_qmin: Minimum quantization value for activations (auto-calculated if None) + act_qmax: Maximum quantization value for activations (auto-calculated if None) + weight_qmin: Minimum quantization value for weights (auto-calculated if None) + weight_qmax: Maximum quantization value for weights (auto-calculated if None) """ - weight_observer = PerChannelMinMaxObserver - assert quant_bits in { + assert weight_bits in { 8, 4, - }, f"Unsupported weight quantization bits: {quant_bits}" + }, f"Unsupported weight quantization bits: {weight_bits}" + + assert act_bits in { + 8, + }, f"Unsupported activation quantization bits: {act_bits}" - quant_min = -(2 ** (quant_bits - 1)) - quant_max = 2 ** (quant_bits - 1) - 1 - qscheme = torch.per_channel_symmetric + # Auto-calculate weight ranges if not provided + if weight_qmin is None or weight_qmax is None: + weight_range = bits_to_range(weight_bits) + weight_qmin = weight_qmin if weight_qmin is not None else weight_range[0] + weight_qmax = weight_qmax if weight_qmax is not None else weight_range[1] - return QuantizationSpec( + # Weight quantization: per-channel symmetric for Vulkan + weight_quantization_spec = QuantizationSpec( dtype=torch.int8, - quant_min=quant_min, - quant_max=quant_max, - qscheme=qscheme, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=torch.per_channel_symmetric, ch_axis=0, is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, ) - -@functools.lru_cache -def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig: - """ - Return a XNNPACKQuantizer QuantizationConfig class instance that specifies - quantizing the weight tensors of linear layers using per-channel symmetric (qcs) - quantization to the number of bits specified by quant_bits. 
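Assuming this patch is applied, the config factory can be exercised as below; the attribute names follow the QuantizationConfig / QuantizationSpec fields used in this file, and the import path mirrors the import style already used here.

    import torch
    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        get_symmetric_quantization_config,
    )

    # Weight-only (static) 8-bit config: no activation specs are attached.
    wo_cfg = get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
    assert wo_cfg.input_activation is None
    assert wo_cfg.weight.qscheme == torch.per_channel_symmetric

    # Dynamic 8-bit activation + 4-bit weight config (the linear_qta8a_qga4w recipe).
    dyn_cfg = get_symmetric_quantization_config(is_dynamic=True, weight_bits=4)
    assert dyn_cfg.input_activation is not None and dyn_cfg.input_activation.is_dynamic
    assert (dyn_cfg.weight.quant_min, dyn_cfg.weight.quant_max) == (-8, 7)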
- """ - weight_qspec = get_linear_weight_qcs_qspec(quant_bits) + # Configure activation quantization based on is_dynamic + if not is_dynamic: + # Weight-only quantization: no activation quantization + act_quantization_spec = None + output_activation_spec = None + else: + # Dynamic quantization: per-token input quantization, no output quantization + # Auto-calculate activation ranges if not provided + if act_qmin is None or act_qmax is None: + act_range = bits_to_range(act_bits) + act_qmin = act_qmin if act_qmin is not None else act_range[0] + act_qmax = act_qmax if act_qmax is not None else act_range[1] + + act_observer_or_fake_quant_ctr = PlaceholderObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=True, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr, + ) + output_activation_spec = None return QuantizationConfig( - input_activation=None, - output_activation=None, - weight=weight_qspec, + input_activation=act_quantization_spec, + output_activation=output_activation_spec, + weight=weight_quantization_spec, bias=None, is_qat=False, ) @@ -99,12 +136,11 @@ def transform_for_annotation( return _convert_scalars_to_attrs(model) def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - # currently only support static quant on Vulkan - model = self._annotate_for_static_quantization_config(model) + model = self._annotate_for_quantization_config(model) propagate_annotation(model) return model - def _annotate_all_static_patterns( + def _annotate_all_patterns( self, model: torch.fx.GraphModule, quantization_config: Optional[QuantizationConfig], @@ -117,10 +153,10 @@ def _annotate_all_static_patterns( OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) return model - def _annotate_for_static_quantization_config( + def _annotate_for_quantization_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - self._annotate_all_static_patterns( + self._annotate_all_patterns( model, self.global_config, ) diff --git a/backends/vulkan/quantizer/vulkan_quantizer_utils.py b/backends/vulkan/quantizer/vulkan_quantizer_utils.py index 7fa549b57cb..c0b6ab39e84 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer_utils.py +++ b/backends/vulkan/quantizer/vulkan_quantizer_utils.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import torch from torch.fx import Node @@ -27,9 +27,26 @@ "OP_TO_ANNOTATOR", "propagate_annotation", "_convert_scalars_to_attrs", + "bits_to_range", ] +def bits_to_range(bits: int) -> Tuple[int, int]: + """ + Calculate quantization range for given number of bits. 
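For reference, the range arithmetic works out as follows; the inline copy below simply mirrors the helper defined here so the check runs standalone.

    def _bits_to_range(bits: int):
        # qmin = -2**(bits - 1), qmax = 2**(bits - 1) - 1
        return (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)

    assert _bits_to_range(8) == (-128, 127)  # int8 activations / weights
    assert _bits_to_range(4) == (-8, 7)      # int4 grouped weights (qga4w)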
+ + Args: + bits: Number of quantization bits + + Returns: + Tuple of (qmin, qmax) for the given bit width + """ + return ( + -(2 ** (bits - 1)), + (2 ** (bits - 1) - 1), + ) + + AnnotatorType = Callable[ [ torch.fx.GraphModule, diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index d6d27d2e3a3..cfe5baa9c1d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,59 +9,67 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// Calculate scale and zero point from min and max values -void calculate_scale_and_zero_point( - float min_val, - float max_val, - int qmin, - int qmax, - float eps_threshold, - out float scale_val, - out int zero_point_val) { - // ensure we have zero included in our range - min_val = min(min_val, 0.0); - max_val = max(max_val, 0.0); +// mapping_type : 0 = ASYM, 1 = SYM, 2 = SYM_NO_CLIP +void calc_scale_zp( + float lo, float hi, + int qmin, int qmax, + int mapping_type, + float eps, + out float scale, out int zp) { + // Handle case where lo and hi are +/-INF (no valid values found) + if (isinf(lo) || isinf(hi)) { + lo = 0.0; + hi = 0.0; + } - scale_val = (max_val - min_val) / float(qmax - qmin); + float minv = min(lo, 0.0); + float maxv = max(hi, 0.0); - // Handle zero or very small scale - if (scale_val == 0.0 || isinf(1.0 / scale_val)) { - scale_val = 0.1; - } + if (mapping_type == 0) { // asymmetric + scale = (maxv - minv) / float(qmax - qmin); + + // Handle zero or very small scale + if (scale == 0.0 || isinf(1.0/scale)) { + scale = eps; + } - // Cut off small scale using the provided eps threshold - if (scale_val < eps_threshold) { - float org_scale = scale_val; - scale_val = eps_threshold; + if (scale < eps) { + float org_scale = scale; + scale = eps; - // Adjust min and max based on new scale - if (min_val == 0.0) { - max_val = eps_threshold * float(qmax - qmin); - } else if (max_val == 0.0) { - min_val = -eps_threshold * float(qmax - qmin); - } else { - float amplifier = eps_threshold / org_scale; - min_val *= amplifier; - max_val *= amplifier; + // Adjust min and max based on new scale to maintain proper quantization range + if (minv == 0.0) { + maxv = eps * float(qmax - qmin); + } else if (maxv == 0.0) { + minv = -eps * float(qmax - qmin); + } else { + float amplifier = eps / org_scale; + minv *= amplifier; + maxv *= amplifier; + } + } + + // Calculate zero_point (matching reference implementation) + float initial_zero_point = float(qmin) - round(minv / scale); + zp = int(clamp(initial_zero_point, float(qmin), float(qmax))); + } else { // symmetric -- centred + float scale_sym; + if (mapping_type == 1) { // SYM + float M = max(abs(minv), abs(maxv)); + scale_sym = M / (float(qmax - qmin) * 0.5); + } else { // SYM_NO_CLIP + float smin = abs(minv) / max(abs(float(qmin)), 1.0); // Avoid division by zero + float smax = maxv / max(float(qmax), 1.0); // Avoid division by zero + scale_sym = max(smin, smax); } - } - // Calculate zero point - float zero_point_from_min = float(qmin) - min_val / scale_val; - float zero_point_from_max = float(qmax) - max_val / scale_val; - float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val); - float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val); - float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error - ? 
zero_point_from_min - : zero_point_from_max; + // Handle zero or very small scale + if (scale_sym == 0.0 || isinf(1.0/scale_sym)) { + scale_sym = eps; + } - // Nudge zero point to integer - if (initial_zero_point < float(qmin)) { - zero_point_val = qmin; - } else if (initial_zero_point > float(qmax)) { - zero_point_val = qmax; - } else { - zero_point_val = int(round(initial_zero_point)); + scale = max(scale_sym, eps); + zp = int((qmax + qmin + 1) >> 1); // mid-point – always fits } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index 48681a46c30..99a64c3589e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -31,12 +31,22 @@ $if MODE == "per_tensor": int quant_max; float eps; }; -$else: +$if MODE == "per_token": layout(push_constant) uniform restrict Block { int num_tokens; int quant_min; int quant_max; }; +$if MODE == "block_wise": + layout(push_constant) uniform BlockPC { + ivec4 blockSize; // WHCN (>=1) + ivec4 numBlocks; // #blocks along W,H,C,N + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP + int quant_min; + int quant_max; + float eps; + }; ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} ${layout_declare_ubo(B, "ivec4", "t_in_strides")} @@ -57,68 +67,133 @@ shared float shared_min[NWORKERS]; shared float shared_max[NWORKERS]; /* - * QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE) - * - * This shader computes quantization parameters (scale and zero_point) for converting - * floating-point tensors to n-bit integer representations while preserving the - * original data range as much as possible. - * - * ALGORITHM: - * 1. Find global min/max values across tensor elements using parallel reduction - * 2. Use tree reduction with shared memory for efficient min/max computation - * 3. Calculate scale = (max - min) / (quant_max - quant_min) - * 4. 
Calculate zero_point to map floating-point zero to integer value - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor) - * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) - * - Per-Token Mode: - * - Global WG Size: {num_tokens, 1, 1} (one workgroup per token) - * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) - * - * SUPPORTED CONFIGURATIONS: - * - Buffer Storage: Uses simple linear indexing through buffer elements - * - No axis mapping or packing considerations - processes elements sequentially - * - Works with any tensor layout since it accesses buffer data linearly - * - * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: - * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: - * - * Initial shared_min/shared_max arrays populated by each thread: - * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - * - * Stride 1 (compare pairs, keep min/max): - * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - * Active: | 0 | | 2 | | 4 | | 6 | | - * - * Stride 2 (compare pairs, keep min/max): - * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) - * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - * Active: | 0 | | | | 4 | | | | - * - * Stride 4 (final comparison): - * shared_min: | 0 | | | | | | | | (min(0,0) = 0) - * shared_max: | 10 | | | | | | | | (max(10,5) = 10) - * Active: | 0 | | | | | | | | - * - * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - * - * PER-TENSOR QUANTIZATION: - * - Single workgroup processes entire tensor with strided access - * - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...] - * - Tree reduction combines all thread results into global min/max - * - Output: Single scale and zero_point values - * - * PER-TOKEN QUANTIZATION: - * - Multiple workgroups, each processing one token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Each workgroup finds min/max within its assigned token - * - Output: Array of scale and zero_point values (one per token) - */ + Quantization Parameter Computation Shader (Buffer Storage) + This shader computes quantization parameters (scale and zero_point) for converting + floating-point tensors to n-bit integer representations while preserving the + original data range as much as possible. The computed parameters enable efficient + quantization by mapping the continuous floating-point range to discrete integer values. + + Important Considerations: + (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + + Workgroup Configuration: + - choose_qparams_per_tensor + This mode computes a single set of quantization parameters for the entire tensor. + Uses parallel reduction across all threads to find global min/max values. + + (*) global_wg_size: {1, 1, 1} (single workgroup processes entire tensor) + (*) local_wg_size: {64, 1, 1} (matches NWORKERS for shared memory) + + - choose_qparams_per_token + This mode computes separate quantization parameters for each token in the tensor. + Each workgroup processes one token independently to find token-specific min/max. 
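The BlockPC push constants are straightforward to derive on the host; the sketch below is illustrative (the C++ code that fills them is not part of this diff) and mirrors the block_id_to_coord helper used by the block_wise kernels. Sizes are in WHCN order to match the shader comments.

    def block_grid(tensor_sizes, block_sizes):
        # numBlocks per dim (ceil division to cover padding), WHCN order
        num_blocks = [(t + b - 1) // b for t, b in zip(tensor_sizes, block_sizes)]
        # blockStride = {1, #W, #W * #H, #W * #H * #C}
        block_stride = [1,
                        num_blocks[0],
                        num_blocks[0] * num_blocks[1],
                        num_blocks[0] * num_blocks[1] * num_blocks[2]]
        return num_blocks, block_stride

    def block_id_to_coord(bid, block_stride):
        # Mirrors the GLSL helper: linear block id -> [w, h, c, n] block coordinates
        n = bid // block_stride[3]; r = bid - n * block_stride[3]
        c = r // block_stride[2];   r -= c * block_stride[2]
        h = r // block_stride[1];   r -= h * block_stride[1]
        return [r, h, c, n]

    # Example: a [W=4, H=9, C=6, N=1] tensor with blockSize = [2, 3, 3, 1]
    num_blocks, block_stride = block_grid([4, 9, 6, 1], [2, 3, 3, 1])
    assert num_blocks == [2, 3, 2, 1] and block_stride == [1, 2, 6, 12]
    for bid in range(12):  # every block id round-trips through the coordinate mapping
        bc = block_id_to_coord(bid, block_stride)
        assert bid == sum(bc[i] * block_stride[i] for i in range(4))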
+ + (*) global_wg_size: {num_tokens, 1, 1} (one workgroup per token) + (*) local_wg_size: {1, 1, 1} (single thread per token) + + - choose_qparams_block_wise + This mode computes quantization parameters for each block of elements, allowing + fine-grained control over quantization granularity within the tensor. Each block + is processed independently to find its own min/max values and compute corresponding + scale and zero_point parameters. + + (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) + (*) local_wg_size: {1, 1, 1} (single thread per block) + + Block-wise quantization supports multiple mapping types for scale/zero_point calculation: + + - mapping_type = 0 (ASYMMETRIC): + Uses asymmetric quantization where the full floating-point range [min, max] is + mapped to the quantized range [quant_min, quant_max]. This preserves the original + data distribution but may not center zero optimally. + + Calculation: + scale = (max - min) / (quant_max - quant_min) + zero_point = quant_min - round(min / scale) + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + scale = (10.2 - (-3.5)) / (7 - (-8)) = 13.7 / 15 = 0.913 + zero_point = -8 - round(-3.5 / 0.913) = -8 - (-4) = -4 + + - mapping_type = 1 (SYMMETRIC): + Uses symmetric quantization where the range is centered around zero. The scale + is computed based on the maximum absolute value, ensuring zero is exactly + representable in the quantized domain. + + Calculation: + max_abs = max(abs(min), abs(max)) + scale = max_abs / ((quant_max - quant_min) / 2) + zero_point = (quant_max + quant_min + 1) / 2 // midpoint + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + max_abs = max(3.5, 10.2) = 10.2 + scale = 10.2 / ((7 - (-8)) / 2) = 10.2 / 7.5 = 1.36 + zero_point = (-8 + 7 + 1) / 2 = 0 + + - mapping_type = 2 (SYMMETRIC_NO_CLIPPING_ERR): + A variant of symmetric quantization that minimizes clipping errors by computing + separate scales for positive and negative ranges, then using the maximum. This + reduces quantization error on the dominant range while ensuring no values are + clipped. + + Calculation: + smin = abs(min) / abs(quant_min) // scale for negative range + smax = max / quant_max // scale for positive range + scale = max(smin, smax) // use larger scale to avoid clipping + zero_point = (quant_max + quant_min + 1) / 2 // midpoint + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + smin = 3.5 / 8 = 0.4375 + smax = 10.2 / 7 = 1.457 + scale = max(0.4375, 1.457) = 1.457 // use smax to avoid clipping positives + zero_point = (-8 + 7 + 1) / 2 = 0 + + Tree Reduction Algorithm for Min/Max Finding: + The shader uses a parallel tree reduction algorithm to efficiently find minimum and + maximum values across multiple threads. This approach reduces the number of memory + accesses and synchronization points compared to sequential scanning. + + Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: + + Step 1 - Initial Population: + Each thread loads its assigned value into shared memory arrays. + shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + + Step 2 - Stride 1 (Compare Adjacent Pairs): + Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. 
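The three mapping types above can be restated compactly in Python for eyeballing the worked int4 examples; this is a simplified mirror of calc_scale_zp (the eps re-amplification in the asymmetric branch is elided).

    def choose_qparams_block(lo, hi, qmin, qmax, mapping_type, eps=1e-5):
        lo, hi = min(lo, 0.0), max(hi, 0.0)        # the range must include zero
        if mapping_type == 0:                      # ASYMMETRIC
            scale = max((hi - lo) / (qmax - qmin), eps)
            zp = max(qmin, min(qmax, qmin - round(lo / scale)))
        elif mapping_type == 1:                    # SYMMETRIC
            scale = max(max(abs(lo), abs(hi)) / ((qmax - qmin) / 2), eps)
            zp = (qmax + qmin + 1) // 2            # mid-point of the integer range
        else:                                      # SYMMETRIC_NO_CLIPPING_ERR
            scale = max(max(abs(lo) / abs(qmin), hi / qmax), eps)
            zp = (qmax + qmin + 1) // 2
        return scale, zp

    # Range [-3.5, 10.2] mapped to int4 [-8, 7], as in the examples above:
    print(choose_qparams_block(-3.5, 10.2, -8, 7, 0))  # ~ (0.913, -4)
    print(choose_qparams_block(-3.5, 10.2, -8, 7, 1))  # ~ (1.36, 0)
    print(choose_qparams_block(-3.5, 10.2, -8, 7, 2))  # ~ (1.457, 0)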
+ shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + Active: | 0 | | 2 | | 4 | | 6 | | + + Step 3 - Stride 2 (Compare Pairs of Pairs): + Threads 0,4 compare with threads 2,6 respectively. + shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) + shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + Active: | 0 | | | | 4 | | | | + + Step 4 - Stride 4 (Final Comparison): + Thread 0 compares with thread 4 to get final result. + shared_min: | 0 | | | | | | | | (min(1,0) = 0) + shared_max: | 10 | | | | | | | | (max(10,5) = 10) + Active: | 0 | | | | | | | | + + Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + + The tree reduction completes in log_2(N) steps where N is the number of threads, + providing O(log N) time complexity instead of O(N) for sequential reduction. + + Quantization Parameter Calculation: + Once min/max values are determined, the shader computes: + - scale = (max - min) / (quant_max - quant_min) + - zero_point = quantization offset to map floating-point zero to integer range + + Mode-Specific Behavior: + - Per-Tensor: Single workgroup with strided access across entire tensor + - Per-Token: Multiple workgroups, each processing one token independently + - Block-Wise: Each thread processes assigned blocks using nested loops over block dimensions +*/ #ifdef per_tensor @@ -176,99 +251,141 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); + // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant + calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; } } -#else +#elif defined(per_token) void choose_qparams_per_token() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_workgroups = gl_NumWorkGroups.x; - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); uint token_size = total_elements / uint(num_tokens); - // Calculate how many tokens each workgroup should process - // This handles the case where we have more tokens than workgroups - uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; - - // Calculate which tokens this workgroup is responsible for - uint start_token = group_id * tokens_per_workgroup; - uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + const uint TOTAL_TOKENS = uint(num_tokens); - // Early exit if this workgroup has no tokens to process - if (start_token >= uint(num_tokens)) { - return; - } - - // Process each token assigned to this workgroup - for (uint token_id = start_token; token_id < end_token; token_id++) { + /* each invocation handles token-ids: id, id+STRIDE, id+2·STRIDE … */ + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + for (uint token_id = gl_GlobalInvocationID.x; token_id < TOTAL_TOKENS; token_id += STRIDE) { // Calculate the start and end indices for this token uint token_start = token_id * token_size; uint token_end = token_start + token_size; - // Each thread processes multiple elements within the token with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity + // Each thread processes the entire 
token + float lo = 1.0/0.0; // +INF + float hi = -1.0/0.0; // -INF bool found_valid = false; - // Process elements within this token only - for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + // Process all elements in this token + for (uint i = token_start; i < token_end; i++) { float val = t_in[i]; if (!isnan(val) && !isinf(val)) { if (!found_valid) { - thread_min = val; - thread_max = val; + lo = hi = val; found_valid = true; } else { - thread_min = min(thread_min, val); - thread_max = max(thread_max, val); + lo = min(lo, val); + hi = max(hi, val); } } } - // Intra-group reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); + if (!found_valid) { + // If no valid values were found, use default values + lo = 0.0; + hi = 0.0; + } - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; + // Calculate scale and zero point directly + float scale_val; + int zero_point_val; + // Use default values: mapping_type=0 (ASYMMETRIC), eps=1e-5 + calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; + // Write results + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } +} + +#elif defined(block_wise) + +ivec4 block_id_to_coord(uint bid) { + ivec4 bc; + bc.w = int(bid) / blockStride.w; + + int r = int(bid) - bc.w * blockStride.w; + bc.z = r / blockStride.z; + + r -= bc.z * blockStride.z; + bc.y = r / blockStride.y; + + r -= bc.y * blockStride.y; + bc.x = r; + return bc; +} + +void choose_qparams_block_wise() { + const uint TOTAL_BLOCKS = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); + + // each invocation handles block-ids: id, id+STRIDE, id+2·STRIDE + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE) { + // block -> WHCN coordinate + ivec4 bc = block_id_to_coord(block_id); + ivec4 blockStart = bc * blockSize; // first element (inclusive) + ivec4 blockEnd = blockStart + blockSize; // last element (exclusive) + + // min / max scan over the block + float lo = 1.0/0.0; // +INF + float hi = -1.0/0.0; // -INF + bool found_valid = false; + + // Calculate actual block dimensions + ivec4 actualBlockSize = blockEnd - blockStart; + int blockElements = actualBlockSize.x * actualBlockSize.y * actualBlockSize.z * actualBlockSize.w; + + // Linear iteration over block elements + for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { + // Convert linear index to 4D coordinates within block + int remaining = elemIdx; + int dn = remaining / (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); + remaining -= dn * (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); + int dc = remaining / (actualBlockSize.x * actualBlockSize.y); + remaining -= dc * (actualBlockSize.x * actualBlockSize.y); + int dh = remaining / actualBlockSize.x; + int dw = remaining - dh * actualBlockSize.x; + + ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); + uint idx = tidx_to_bufi(tidx, t_in_strides); + float v = t_in[idx]; + 
+ if (!isnan(v) && !isinf(v)) { + if (!found_valid) { + lo = hi = v; + found_valid = true; + } else { + lo = min(lo, v); + hi = max(hi, v); } } - barrier(); } - // Final calculation for this token - if (local_id == 0) { - float token_min = shared_min[0]; - float token_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); - - t_scale[token_id] = scale_val; - t_zero_point[token_id] = zero_point_val; + // Handle the case where no valid values were found in the block + if (!found_valid) { + lo = 0.0; + hi = 0.0; } - // Synchronize before processing next token - barrier(); + float scale; + int zp; + calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp); + + t_zero_point[block_id] = zp; + t_scale[block_id] = scale; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml index c37039f68e9..ee900750e16 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -10,3 +10,5 @@ choose_qparams_buffer: MODE: per_tensor - NAME: choose_qparams_per_token_asymmetric_buffer MODE: per_token + - NAME: choose_qparams_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 5076b2d68e9..62ea7099f8c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -22,8 +22,13 @@ ${define_required_extensions(IN_DTYPE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} -${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +$if MODE != "block_wise": + ${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} + ${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +$else: + ${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} $if MODE == "per_tensor": @@ -32,16 +37,33 @@ $if MODE == "per_tensor": int quant_max; float eps; }; -$else: +$if MODE == "per_token": layout(push_constant) uniform restrict Block { int num_tokens; int quant_min; int quant_max; }; +$if MODE == "block_wise": + layout(push_constant) uniform BlockPC { + ivec4 blockSize; // WHCN (>=1) + ivec4 numBlocks; // #blocks along W,H,C,N + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP + int quant_min; + int quant_max; + float eps; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_scale_limits")} -${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} +$if MODE != "block_wise": + ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} + ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} +$else: + ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} + #include "indexing_utils.h" #include "choose_qparams.glslh" @@ -54,73 +76,87 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; shared float 
shared_min[NWORKERS]; shared float shared_max[NWORKERS]; -/* - * QUANTIZATION PARAMETER COMPUTATION SHADER (TEXTURE STORAGE) - * - * This shader computes quantization parameters (scale and zero_point) for converting - * floating-point tensors to n-bit integer representations while preserving the - * original data range as much as possible. - * - * ALGORITHM: - * 1. Find global min/max values across tensor elements using parallel reduction - * 2. Use tree reduction with shared memory for efficient min/max computation - * 3. Calculate scale = (max - min) / (quant_max - quant_min) - * 4. Calculate zero_point to map floating-point zero to integer value - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: Default (typically {num_elements, 1, 1}) - * - Local WG Size: Default (typically {64, 1, 1}) - * - Per-Token Mode: - * - Global WG Size: Default (typically based on tensor dimensions) - * - Local WG Size: Default (typically {64, 1, 1}, or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with linear texel iteration - * - Assumes width-packed layout (packed_dim = 0) in current implementation - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - Note: Axis mapping support depends on indexing utilities - * - * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: - * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: - * - * Initial shared_min/shared_max arrays populated by each thread: - * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - * - * Stride 1 (compare pairs, keep min/max): - * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - * Active: | 0 | | 2 | | 4 | | 6 | | - * - * Stride 2 (compare pairs, keep min/max): - * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) - * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - * Active: | 0 | | | | 4 | | | | - * - * Stride 4 (final comparison): - * shared_min: | 0 | | | | | | | | (min(0,0) = 0) - * shared_max: | 10 | | | | | | | | (max(10,5) = 10) - * Active: | 0 | | | | | | | | - * - * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - * - * PER-TENSOR QUANTIZATION: - * - Single workgroup processes entire tensor - * - Each thread processes multiple texels with stride - * - Thread 0: texels [0, 64, 128, ...] -> elements [0-3, 256-259, 512-515, ...] - * - Thread 1: texels [1, 65, 129, ...] -> elements [4-7, 260-263, 516-519, ...] - * - Tree reduction combines all thread results into global min/max - * - Output: Single scale and zero_point values - * - * PER-TOKEN QUANTIZATION: - * - Multiple workgroups, each processing subset of tokens - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Each workgroup processes multiple tokens if num_tokens > num_workgroups - * - Within each token, threads process texels containing token elements - * - Output: Array of scale and zero_point values (one per token) - */ +/*/* + Quantization Parameter Computation Shader (Buffer Storage) + This shader computes quantization parameters (scale and zero_point) for converting + floating-point tensors to n-bit integer representations while preserving the + original data range as much as possible. 
The computed parameters enable efficient + quantization by mapping the continuous floating-point range to discrete integer values. + + Important Considerations: + (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + + Workgroup Configuration: + - choose_qparams_per_tensor + This mode computes a single set of quantization parameters for the entire tensor. + Uses parallel reduction across all threads to find global min/max values. + + (*) global_wg_size: default + (*) local_wg_size: default + + - choose_qparams_per_token + This mode computes separate quantization parameters for each token in the tensor. + Each workgroup processes one token independently to find token-specific min/max. + + (*) global_wg_size: default + (*) local_wg_size: {1, 1, 1} + + - choose_qparams_block_wise + This mode computes quantization parameters for each block of elements, allowing + fine-grained control over quantization granularity within the tensor. Each block + is processed independently to find its own min/max values and compute corresponding + scale and zero_point parameters. + + NOTE: This mode currently only supports buffer storage for the output. + + (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) + (*) local_wg_size: {1, 1, 1} (single thread per block) + + Tree Reduction Algorithm for Min/Max Finding: + The shader uses a parallel tree reduction algorithm to efficiently find minimum and + maximum values across multiple threads. This approach reduces the number of memory + accesses and synchronization points compared to sequential scanning. + + Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: + + Step 1 - Initial Population: + Each thread loads its assigned value into shared memory arrays. + shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + + Step 2 - Stride 1 (Compare Adjacent Pairs): + Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. + shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + Active: | 0 | | 2 | | 4 | | 6 | | + + Step 3 - Stride 2 (Compare Pairs of Pairs): + Threads 0,4 compare with threads 2,6 respectively. + shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) + shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + Active: | 0 | | | | 4 | | | | + + Step 4 - Stride 4 (Final Comparison): + Thread 0 compares with thread 4 to get final result. + shared_min: | 0 | | | | | | | | (min(1,0) = 0) + shared_max: | 10 | | | | | | | | (max(10,5) = 10) + Active: | 0 | | | | | | | | + + Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + + The tree reduction completes in log_2(N) steps where N is the number of threads, + providing O(log N) time complexity instead of O(N) for sequential reduction. 
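A minimal Python model of the reduction makes the walkthrough above easy to verify; it follows the descending-stride loop used by the per-tensor kernels (the ASCII walkthrough pairs threads in the opposite order but converges to the same min/max).

    def tree_min_max(values):
        shared_min, shared_max = list(values), list(values)
        stride = len(values) // 2
        while stride > 0:                      # log2(N) passes
            for i in range(stride):            # "active" threads 0 .. stride-1
                shared_min[i] = min(shared_min[i], shared_min[i + stride])
                shared_max[i] = max(shared_max[i], shared_max[i + stride])
            stride //= 2
        return shared_min[0], shared_max[0]

    # The 8-thread example from the comment:
    assert tree_min_max([10, 1, 8, 1, 0, 2, 3, 5]) == (0, 10)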
+ + Quantization Parameter Calculation: + Once min/max values are determined, the shader computes: + - scale = (max - min) / (quant_max - quant_min) + - zero_point = quantization offset to map floating-point zero to integer range + + Mode-Specific Behavior: + - Per-Tensor: Single workgroup with strided access across entire tensor + - Per-Token: Multiple workgroups, each processing one token independently +*/ #ifdef per_tensor @@ -235,14 +271,14 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); + calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); } } -#else +#elif defined(per_token) void choose_qparams_per_token() { // Each token is processed by multiple workgroups for parallel reduction @@ -373,7 +409,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); + calc_scale_zp(token_min, token_max, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions @@ -392,6 +428,100 @@ void choose_qparams_per_token() { } } +#elif defined(block_wise) + +ivec4 block_id_to_coord(uint bid) { + ivec4 bc; + bc.w = int(bid) / blockStride.w; + + int r = int(bid) - bc.w * blockStride.w; + bc.z = r / blockStride.z; + + r -= bc.z * blockStride.z; + bc.y = r / blockStride.y; + + r -= bc.y * blockStride.y; + bc.x = r; + return bc; +} + +void choose_qparams_block_wise() { + const uint T = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + + // tensor full size in WHCN order + const ivec4 tensorSz = blockSize * numBlocks; + + // Process blocks with stride for better parallelization + for (uint blkIdx = gl_GlobalInvocationID.x; blkIdx < T; blkIdx += STRIDE) { + // block index in WHCN + const ivec4 b4d = block_id_to_coord(blkIdx); + const ivec4 blockStart = b4d * blockSize; + const ivec4 blockEnd = blockStart + blockSize; + + // scan all elements inside the block + float vmin = 3.402823e38; // +FLT_MAX + float vmax = -3.402823e38; // -FLT_MAX + bool found_valid = false; + + // Calculate total elements in block for linear iteration + const int blockElements = blockSize.x * blockSize.y * blockSize.z * blockSize.w; + + // Linear iteration over block elements (more cache-friendly) + for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { + // Convert linear index to 4D coordinates within block + int remaining = elemIdx; + int dn = remaining / (blockSize.x * blockSize.y * blockSize.z); + remaining -= dn * (blockSize.x * blockSize.y * blockSize.z); + int dc = remaining / (blockSize.x * blockSize.y); + remaining -= dc * (blockSize.x * blockSize.y); + int dh = remaining / blockSize.x; + int dw = remaining - dh * blockSize.x; + + ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); + + // skip padding when tensor size is not an exact multiple of block + if (any(greaterThanEqual(tidx, tensorSz))) { continue; } + + // tensor index -> (x,y,z,component) inside input texture + ivec4 posi = to_texture_elem_pos(tidx, tensorSz, 0); // 0 = W_DIM (width packed) + + // fetch 
texel and pick the element inside it + FVEC4_T texl = load_texel(t_in, posi.xyz); + float v; + if (posi.w == 0) v = texl.x; + else if (posi.w == 1) v = texl.y; + else if (posi.w == 2) v = texl.z; + else v = texl.w; + + if (!isnan(v) && !isinf(v)) { + if (!found_valid) { + vmin = vmax = v; + found_valid = true; + } else { + vmin = min(vmin, v); + vmax = max(vmax, v); + } + } + } + + // Handle case where no valid values were found + if (!found_valid) { + vmin = 0.0; + vmax = 0.0; + } + + // compute scale / zero‑point (same maths as buffer kernel) + float scale; + int zp; + calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); + + // Write the scalar values directly to buffer using linear index + t_scale[blkIdx] = scale; + t_zero_point[blkIdx] = zp; + } +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml index f3961b87a0f..a097ce0da48 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -10,3 +10,5 @@ choose_qparams_texture: MODE: per_tensor - NAME: choose_qparams_per_token_asymmetric_texture3d MODE: per_token + - NAME: choose_qparams_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl index 94072dfbfea..43e62eadeee 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -53,6 +53,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -71,68 +82,60 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); /* - * DEQUANTIZATION SHADER (BUFFER STORAGE) - * - * This shader converts n-bit integer tensor values back to floating-point representations - * using pre-computed quantization parameters (scale and zero_point). The dequantization - * reconstructs the original floating-point values from their discrete integer representations - * with minimal precision loss. - * - * ALGORITHM: - * 1. Load quantized integer value from buffer - * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale - * 3. 
Store reconstructed floating-point value to output buffer - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access - * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering - * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping - * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) - * - * DEQUANTIZATION FORMULA VISUALIZATION: - * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: - * - * Integer Domain: Floating Point Domain: - * quant_min ──────────────► min_val - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * quant_max ──────────────► max_val - * - * Dequantization Process: - * Input: -103 (int8) - * Step 1: qvalue - zero_point = -103 - (-128) = 25 - * Step 2: result * scale = 25 * 0.1 = 2.5 - * Output: 2.5 (float) - * - * PER-TENSOR DEQUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All elements use same dequantization parameters - * - Parameters passed as push constants for efficiency - * - Formula: value = (qvalue - zero_point) * scale - * - * PER-TOKEN DEQUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates its token_id from tensor coordinates - * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id] - * - * Token ID calculation for element at tensor index (w, z, y, x): - * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y - * - 3D tensor: token_id = z * sizes.y + y - * - 2D tensor: token_id = y - * - 1D tensor: token_id = 0 - */ + Dequantization Shader (Buffer Storage) + This shader converts n-bit integer tensor values back to floating-point representations + using pre-computed quantization parameters (scale and zero_point). The dequantization + reconstructs the original floating-point values from their discrete integer representations + with minimal precision loss. + + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - dequantize_per_tensor + This mode reverses the uniform quantization applied across the entire tensor by using the + single scale and zero_point values to convert quantized integer values back to their original + floating-point representation. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_per_token + This mode reverses the quantization applied individually to each token (or element) in the + input by using separate scale and zero_point values for each token. 
For a tensor of shape + [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting + quantized values back to their original floating-point representation for each group of H + elements independently. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_per_channel + This mode reverses the quantization applied separately to each channel of the input tensor + by using distinct scale and zero_point values for each channel. For a tensor of shape + [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C + channels, converting quantized values back to their original floating-point representation + independently for each channel. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_block_wise + This mode reverses the block-wise quantization applied to groups of elements by using separate + scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the + inverse affine transformation per block to convert quantized values back to their original + floating-point representation. For example, if the tensor shape is [6, 9, 4] and + blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements, + and dequantization is performed independently on each block. + + (*) global_wg_size: default + (*) local_wg_size: default + + Dequantization Formula: + value = (qvalue - zero_point) * scale +*/ #ifdef per_tensor @@ -187,7 +190,7 @@ void dequantize_per_token() { t_out[out_bufi] = value; } -#else // per_channel +#elif defined(per_channel) void dequantize_per_channel() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -226,6 +229,29 @@ void dequantize_per_channel() { t_out[out_bufi] = value; } +#else // block_wise + +void dequantize_block_wise() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + const ivec4 bcoord = out_tidx / blockSize; + + const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + + t_out[out_bufi] = value; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index b9a53217452..999c59d3b79 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -19,3 +19,5 @@ dequantize_buffer: MODE: per_token - NAME: dequantize_per_channel_buffer MODE: per_channel + - NAME: dequantize_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index 5c978c61846..20bf6c87e26 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -56,6 +56,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, 
bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} @@ -201,7 +212,7 @@ void dequantize_per_token() { write_texel(t_out, pos, outtex); } -#else // per_channel +#elif defined(per_channel) void dequantize_per_channel() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -292,6 +303,39 @@ void dequantize_per_channel() { write_texel(t_out, pos, outtex); } +#else // block_wise + +void dequantize_block_wise() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) + return; + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); + int foldedZ = pos.z; + + int C_total = numBlocks.z * blockSize.z; + + [[unroll]] for (int i = 0; i < 4; ++i) { + ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); + + ivec4 bcoord = tidx / blockSize; + int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index 88ccc6e3274..9b624762192 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -19,3 +19,5 @@ dequantize_texture: MODE: per_token - NAME: dequantize_per_channel_texture3d MODE: per_channel + - NAME: dequantize_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl index 9834a539667..9a342d8e057 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -53,6 +53,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -71,64 +82,54 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); /* - * QUANTIZATION SHADER (BUFFER STORAGE) - * - * This shader converts floating-point tensor values to n-bit integer representations - * using pre-computed quantization parameters (scale and zero_point). The quantization - * maps floating-point values to a discrete integer range while preserving the - * original data distribution as much as possible. - * - * ALGORITHM: - * 1. Load floating-point input value from buffer - * 2. Apply quantization formula: qvalue = round(value / scale) + zero_point - * 3. 
Clamp result to [quant_min, quant_max] range - * 4. Store quantized integer value to output buffer - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Per-Tensor Config: Uses linear buffer indexing with stride-based tensor access - * - and supports any tensor layout through stride calculations and dimension ordering - * - Per-Token Config: Assumes width-packed layout (packed_dim = 0) - * - since that is how token index is calculated - * - * QUANTIZATION FORMULA VISUALIZATION: - * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: - * - * Floating Point Domain: Integer Domain: - * min_val ────────────────► quant_min - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * max_val ────────────────► quant_max - * - * Quantization Process: - * Input: 2.5 (float) - * Step 1: value / scale = 2.5 / 0.1 = 25.0 - * Step 2: round(25.0) + zero_point = 25 + (-128) = -103 - * Step 3: clamp(-103, -128, 127) = -103 - * Output: -103 (int8) - * - * PER-TENSOR QUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All elements use same quantization parameters - * - Parameters passed as push constants for efficiency - * - Formula: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max) - * - * PER-TOKEN QUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates its token_id from tensor coordinates - * - Formula: qvalue = clamp(round(value / scale[token_id]) + zero_point[token_id], quant_min, quant_max) - */ + Quantization Shader (Buffer Storage) + This shader converts floating-point tensor values to n-bit integer representations + using pre-computed quantization parameters (scale and zero_point). The quantization + maps floating-point values to a discrete integer range while preserving the original + data distribution as much as possible. + + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - quantize_per_tensor + This mode applies uniform quantization across the entire tensor using a single scale + and zero_point value. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_token + This mode applies quantization individually to each token (or element) in the input, + using separate scale and zero_point values for each token. For instance if we have + a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_channel + This mode applies quantization separately to each channel of the input tensor, using + distinct scale and zero_point values for each channel. 
For example, if the tensor shape + is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing + each channel to be quantized independently. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_block_wise + This mode applies quantization in blocks or groups of elements, allowing different scale + and zero_point values for each block. It is equivalent to quantize_affine, where quantization + parameters are affine transformations applied per block. For example, if the tensor shape + is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. + + (*) global_wg_size: default + (*) local_wg_size: default + + Quantization Formula: + qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). +*/ #ifdef per_tensor @@ -183,7 +184,7 @@ void quantize_per_token() { t_out[out_bufi] = qvalue; } -#else // per_channel +#elif defined(per_channel) void quantize_per_channel() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -222,6 +223,29 @@ void quantize_per_channel() { t_out[out_bufi] = qvalue; } +#else // block_wise + +void quantize_block_wise() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + const ivec4 bcoord = out_tidx / blockSize; + + const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + const OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]); + + t_out[out_bufi] = qvalue; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 1dd8e6e2ffe..5b479c2f90f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -19,3 +19,5 @@ quantize_buffer: MODE: per_token - NAME: quantize_per_channel_buffer MODE: per_channel + - NAME: quantize_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl index 148fa85eb2b..69f219ef329 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -58,6 +58,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict BlockPC { + ivec4 blockSize; // WHCN + ivec4 numBlocks; // (#W,#H,#C,#N) + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} @@ -70,68 +81,58 @@ ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* - * QUANTIZATION SHADER (TEXTURE STORAGE) - * - * This shader converts floating-point tensor values to n-bit integer representations - * using pre-computed quantization parameters (scale and zero_point). 
The quantization - * maps floating-point values to a discrete integer range while preserving the - * original data distribution as much as possible. - * - * ALGORITHM: - * 1. Load floating-point texel (4 values) from 3D texture - * 2. Apply quantization formula to each component: qvalue = round(value / scale) + zero_point - * 3. Clamp each result to [quant_min, quant_max] range - * 4. Store quantized integer texel to output texture - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with texel-based processing - * - Assumes width-packed layout (packed_dim = 0) in current implementation - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - For per-token mode: scale/zero_point tensors must use buffer storage - * - * QUANTIZATION FORMULA VISUALIZATION: - * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: - * - * Floating Point Domain: Integer Domain: - * min_val ────────────────► quant_min - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * max_val ────────────────► quant_max - * - * Texel Quantization Process: - * Input Texel: [2.5, -1.0, 0.5, 3.2] (float4) - * Per-component quantization with scale=0.1, zero_point=-128: - * Component 0: round(2.5 / 0.1) + (-128) = 25 + (-128) = -103 - * Component 1: round(-1.0 / 0.1) + (-128) = -10 + (-128) = -138 → clamp to -128 - * Component 2: round(0.5 / 0.1) + (-128) = 5 + (-128) = -123 - * Component 3: round(3.2 / 0.1) + (-128) = 32 + (-128) = -96 - * Output Texel: [-103, -128, -123, -96] (int4) - * - * PER-TENSOR QUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All texel components use same quantization parameters - * - Parameters passed as push constants for efficiency - * - Each thread processes one texel (4 elements) independently - * - Formula: qvalue[i] = clamp(round(value[i] / scale) + zero_point, quant_min, quant_max) - * - * PER-TOKEN QUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates token_id from its 3D texture position - * - Scale/zero_point buffers accessed directly (not as textures) - * - Formula: qvalue[i] = clamp(round(value[i] / scale[token_id]) + zero_point[token_id], quant_min, quant_max) - */ + Quantization Shader (Texture Storage) + This shader converts floating-point tensor values to n-bit integer representations + using pre-computed quantization parameters (scale and zero_point). The quantization + maps floating-point values to a discrete integer range while preserving the original + data distribution as much as possible. 
+ + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - quantize_per_tensor + This mode applies uniform quantization across the entire tensor using a single scale + and zero_point value. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_token + This mode applies quantization individually to each token (or element) in the input, + using separate scale and zero_point values for each token. For instance if we have + a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. + + (*) global_wg_size: default + (*) local_wg_size: default + + - quantize_per_channel + This mode applies quantization separately to each channel of the input tensor, using + distinct scale and zero_point values for each channel. For example, if the tensor shape + is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing + each channel to be quantized independently. + + (*) global_wg_size: default + (*) local_wg_size: Default with special handling for batch dimension. When quantizing along + the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, + uses standard workgroup size derived from global workgroup dimensions. + + - quantize_block_wise + This mode applies quantization in blocks or groups of elements, allowing different scale + and zero_point values for each block. It is equivalent to quantize_affine, where quantization + parameters are affine transformations applied per block. For example, if the tensor shape + is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. + + (*) global_wg_size: default + (*) local_wg_size: Default with special handling for batch dimension. When quantizing along + the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, + uses standard workgroup size derived from global workgroup dimensions. + + Quantization Formula: + qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). 
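+
+  Worked Example (illustrative only; it simply applies the blockStride layout documented for the
+  push constants to the [6, 9, 4] example above, assuming the usual reversal to WHCN with N padded to 1):
+  In WHCN order the tensor is [W=4, H=9, C=6, N=1] and blockSize = [2, 3, 3, 1], so
+  numBlocks = [2, 3, 2, 1] and blockStride = {1, 2, 6, 12} (12 blocks of 18 elements each).
+  The element at WHCN index (w=3, h=4, c=5, n=0) has bcoord = tidx / blockSize = (1, 1, 1, 0),
+  so block_id = 1*1 + 1*2 + 1*6 + 0*12 = 9, which indexes into t_scale / t_zero_point.
+  If that block has scale = 0.1 and zero_point = -128, an input of 2.5 maps to
+  clamp(round(2.5 / 0.1) + (-128), quant_min, quant_max) = -103 for an int8 range.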
+*/ #ifdef per_tensor @@ -192,7 +193,7 @@ void quantize_per_token() { write_texel(t_out, pos, outtex); } -#else // per_channel +#elif defined(per_channel) void quantize_per_channel() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -270,6 +271,36 @@ void quantize_per_channel() { write_texel(t_out, pos, outtex); } +#else // block_wise + +void quantize_block_wise() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) + return; + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); + int foldedZ = pos.z; + + int C_total = numBlocks.z * blockSize.z; + + [[unroll]] for (int i = 0; i < 4; ++i) { + ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); + + ivec4 bcoord = tidx / blockSize; + int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]); + outtex[i] = qvalue; + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 47e532be8b9..2e40ac90794 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -19,3 +19,5 @@ quantize_texture: MODE: per_token - NAME: quantize_per_channel_texture3d MODE: per_channel + - NAME: quantize_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index de269920eea..76d352334e3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -14,45 +14,6 @@ namespace vkcompute { -namespace { - -void resize_choose_qparams_tensor_output( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - const ValueRef scale_out = args.at(0).refs.at(0); - const ValueRef zero_point_out = args.at(0).refs.at(1); - - // Both scale and zero_point are scalar tensors for per-tensor quantization - // Since we use single workgroup approach, no extra buffer space needed - graph->virtual_resize(scale_out, {}); - graph->virtual_resize(zero_point_out, {}); -} - -void resize_choose_qparams_per_token_output( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - const ValueRef scale_out = args.at(0).refs.at(0); - const ValueRef zero_point_out = args.at(0).refs.at(1); - const ValueRef input = args.at(1).refs.at(0); - - // Calculate output sizes for scale and zero_point tensors - const auto input_sizes = graph->sizes_of(input); - std::vector output_sizes; - output_sizes.reserve(input_sizes.size() - 1); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - output_sizes.push_back(input_sizes[i]); - } - output_sizes.push_back(1); - - graph->virtual_resize(scale_out, output_sizes); - graph->virtual_resize(zero_point_out, output_sizes); -} - -// Custom workgroup size pickers for ChooseQParams operations utils::uvec3 choose_qparams_pick_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -135,15 +96,67 @@ utils::uvec3 choose_qparams_per_token_pick_local_wg_size( const ValueRef input = args.at(1).refs.at(0); if 
(graph->is_buffer_storage(input)) { - // For buffer storage, use 64 threads in X dimension to match NWORKERS - return {64u, 1u, 1u}; + return {1u, 1u, 1u}; } else { // For texture storage, use the default logic return graph->create_local_wg_size(global_workgroup_size); } } -} // namespace +utils::uvec3 choose_qparams_block_wise_pick_global_wg_size( + ComputeGraph* g, + const vkapi::ShaderInfo&, + const std::vector& a, + const std::vector& r) { + const ValueRef input = a.at(2).refs.at(0); + const auto blkRef = r.at(0); + const auto inSz = g->sizes_of(input); + const auto blkList = g->get_int_list(blkRef); + + // Use same code as in add_choose_qparams_block_wise_node + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*blkList); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(inSz); + + // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) + utils::ivec4 nBlk = { + (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], + (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], + (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], + (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; + + uint32_t nBlocks = nBlk[0] * nBlk[1] * nBlk[2] * nBlk[3]; + + // For texture storage, use more threads to better utilize GPU parallelism + // Each thread can process multiple blocks with stride + if (g->is_buffer_storage(input)) { + return {nBlocks, 1u, 1u}; + } else { + // For texture storage, use more workgroups to better utilize GPU + // Aim for ~64-256 threads per workgroup for good occupancy + uint32_t preferred_threads_per_wg = 64; + uint32_t num_workgroups = + (nBlocks + preferred_threads_per_wg - 1) / preferred_threads_per_wg; + num_workgroups = std::max(1u, std::min(num_workgroups, nBlocks)); + return {num_workgroups * preferred_threads_per_wg, 1u, 1u}; + } +} + +utils::uvec3 choose_qparams_block_wise_pick_local_wg_size( + ComputeGraph* g, + const vkapi::ShaderInfo&, + const utils::uvec3& global_wg_size, + const std::vector& a, + const std::vector&) { + const ValueRef input = a.at(2).refs.at(0); + + if (g->is_buffer_storage(input)) { + return {1u, 1u, 1u}; + } else { + // For texture storage, use 64 threads per workgroup for better occupancy + uint32_t local_size = std::min(64u, global_wg_size[0]); + return {local_size, 1u, 1u}; + } +} void add_choose_qparams_tensor_node( ComputeGraph& graph, @@ -162,6 +175,7 @@ void add_choose_qparams_tensor_node( float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; + std::vector push_constants; if (graph.is_buffer_storage(input)) { param_ubos = { @@ -178,7 +192,6 @@ void add_choose_qparams_tensor_node( graph.logical_limits_ubo(zero_point_out)}; } - std::vector push_constants; push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), @@ -203,7 +216,7 @@ void add_choose_qparams_tensor_node( // Resize Args {}, // Resizing Logic - resize_choose_qparams_tensor_output)); + nullptr)); } void add_choose_qparams_per_token_asymmetric_node( @@ -227,6 +240,7 @@ void add_choose_qparams_per_token_asymmetric_node( int quant_max_val = 127; // Fixed for asymmetric quantization vkapi::ParamsBindList param_ubos; + std::vector push_constants; if (graph.is_buffer_storage(input)) { param_ubos = { @@ -243,7 +257,6 @@ void add_choose_qparams_per_token_asymmetric_node( graph.logical_limits_ubo(zero_point_out)}; } - std::vector push_constants; push_constants = { PushConstantDataInfo(&num_tokens_val, sizeof(int)), 
      PushConstantDataInfo(&quant_min_val, sizeof(int)),
@@ -268,7 +281,100 @@
       // Resize Args
       {},
       // Resizing Logic
-      resize_choose_qparams_per_token_output));
+      nullptr));
+}
+
+void add_choose_qparams_block_wise_node(
+    ComputeGraph& graph,
+    ValueRef input,
+    ValueRef block_size,
+    int mapping_type, // 0 / 1 / 2
+    ValueRef quant_min,
+    ValueRef quant_max,
+    ValueRef eps,
+    ValueRef scale_out,
+    ValueRef zp_out) {
+  const auto input_sizes = graph.sizes_of(input);
+  const auto block_size_list = graph.get_int_list(block_size);
+
+  // For shader compatibility, we still need to convert to WHCN order
+  // but the output shape calculation is now handled correctly in resize
+  // function
+  utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list);
+  utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes);
+
+  // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order)
+  utils::ivec4 num_blocks_vec = {
+      (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0],
+      (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1],
+      (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2],
+      (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]};
+
+  // Calculate blockStride: pre-computed linear strides for the block grid
+  utils::ivec4 block_stride_vec = {
+      1,
+      num_blocks_vec[0],
+      num_blocks_vec[0] * num_blocks_vec[1],
+      num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]};
+
+  int qmin = static_cast<int>(graph.get_int(quant_min));
+  int qmax = static_cast<int>(graph.get_int(quant_max));
+  float eps_val = static_cast<float>(graph.get_double(eps));
+
+  // Create push constants vector
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)),
+      PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)),
+      PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)),
+      PushConstantDataInfo(&mapping_type, sizeof(int)),
+      PushConstantDataInfo(&qmin, sizeof(int)),
+      PushConstantDataInfo(&qmax, sizeof(int)),
+      PushConstantDataInfo(&eps_val, sizeof(float))};
+
+  std::string kernel_name("choose_qparams_block_wise");
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(input));
+
+  vkapi::ParamsBindList param_ubos;
+
+  if (graph.is_buffer_storage(input)) {
+    param_ubos = {
+        graph.sizes_ubo(input),
+        graph.strides_ubo(input),
+        graph.sizes_ubo(scale_out),
+        graph.strides_ubo(scale_out),
+        graph.sizes_ubo(zp_out),
+        graph.strides_ubo(zp_out)};
+  } else {
+    // For texture input, the shader uses buffer storage for outputs
+    // so we need buffer UBOs for the output tensors
+    param_ubos = {
+        graph.logical_limits_ubo(input),
+        graph.sizes_ubo(scale_out),
+        graph.strides_ubo(scale_out),
+        graph.sizes_ubo(zp_out),
+        graph.strides_ubo(zp_out)};
+  }
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      choose_qparams_block_wise_pick_global_wg_size,
+      choose_qparams_block_wise_pick_local_wg_size,
+      // Inputs and Outputs
+      {{scale_out, vkapi::kWrite},
+       {zp_out, vkapi::kWrite},
+       {input, vkapi::kRead}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize Args
+      {block_size},
+      // Resizing Logic
+      nullptr));
 }
 
 void choose_qparams_tensor_impl(
@@ -278,9 +384,8 @@
   const ValueRef input = args[arg_idx++];
   const ValueRef quant_min = args[arg_idx++];
const ValueRef quant_max = args[arg_idx++]; - const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) - const ValueRef dtype = - args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef eps = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; const ValueRef out_tuple_ref = args[arg_idx++]; ValueRef scale_out = kDummyValueRef; @@ -301,17 +406,11 @@ void choose_qparams_tensor_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -327,8 +426,7 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef dtype = - args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef dtype = args[arg_idx++]; const ValueRef out_tuple_ref = args[arg_idx++]; ValueRef scale_out = kDummyValueRef; @@ -349,17 +447,16 @@ void choose_qparams_per_token_asymmetric_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); @@ -370,9 +467,8 @@ void choose_qparams_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef mapping_type = args[arg_idx++]; + const ValueRef block_size = args[arg_idx++]; const ValueRef target_dtype = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; @@ -382,7 +478,6 @@ void choose_qparams_affine_impl( const ValueRef out_tuple_ref = args[arg_idx++]; // Suppress unused variable warnings - (void)mapping_type; (void)target_dtype; (void)scale_dtype; (void)zero_point_dtype; @@ -402,36 +497,42 @@ void choose_qparams_affine_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat 
|| - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); - } - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + std::string mapping_type_str = graph.get_string(mapping_type); + int mapping_type_val = 0; // Default to ASYMMETRIC + + if (mapping_type_str == "ASYMMETRIC") { + mapping_type_val = 0; + } else if (mapping_type_str == "SYMMETRIC") { + mapping_type_val = 1; + } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") { + mapping_type_val = 2; } - // Default to per-tensor quantization parameter calculation for TorchAO affine - // ops - add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); + add_choose_qparams_block_wise_node( + graph, + input, + block_size, + mapping_type_val, + quant_min, + quant_max, + eps, + scale_out, + zero_point_out); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 7edb9b2f70d..61fd76145a4 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -17,38 +17,59 @@ namespace vkcompute { -void resize_dequantize_output( +void resize_dequantize_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { (void)extra_args; - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(in)); + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + out->virtual_resize(in->sizes()); } -utils::uvec3 dequantize_per_channel_global_wg_size( +utils::uvec3 dequantize_per_channel_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { + (void)args; (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + const ValueRef input = args.at(1).refs.at(0); - return global_wg_size; + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. 
We need to ensure that we dispatch the correct + // number of workgroups in the Z dimension to cover all batch-channel + // combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. + const auto input_sizes = graph->sizes_of(input); + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { + local_wg_size[2] = 1; + } + + return local_wg_size; } -utils::uvec3 dequantize_per_channel_local_wg_size( +utils::uvec3 dequantize_block_wise_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { - (void)args; + (void)shader; (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); utils::uvec3 local_wg_size = @@ -56,16 +77,17 @@ utils::uvec3 dequantize_per_channel_local_wg_size( // WORKAROUND: The CommandBuffer::dispatch function divides // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel dequantization along the batch - // axis, we need to ensure that we dispatch the correct number of workgroups - // in the Z dimension to cover all batch-channel combinations. + // workgroups to dispatch. We need to ensure that we dispatch the correct + // number of workgroups in the Z dimension to cover all batch-channel + // combinations. // // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], // local_wg_size[2]) might reduce the number of workgroups dispatched. To // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, // we set local_wg_size[2] = 1. 
const auto input_sizes = graph->sizes_of(input); - if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { local_wg_size[2] = 1; } @@ -131,7 +153,7 @@ void add_dequantize_per_tensor_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); } void add_dequantize_per_token_node( @@ -161,25 +183,18 @@ void add_dequantize_per_token_node( graph.sizes_ubo(input), graph.strides_ubo(input), graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.strides_ubo(output)}; } else { param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; } + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + vkapi::SpecVarList spec_vars = { graph.hashed_layout_of(output), graph.hashed_layout_of(input), @@ -203,7 +218,7 @@ void add_dequantize_per_token_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); } void add_dequantize_per_channel_node( @@ -252,27 +267,19 @@ void add_dequantize_per_channel_node( graph.sizes_ubo(input), graph.strides_ubo(input), graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.strides_ubo(output)}; } else { param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; } + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + vkapi::SpecVarList spec_vars = { graph.hashed_layout_of(output), graph.hashed_layout_of(input), @@ -281,7 +288,7 @@ void add_dequantize_per_channel_node( graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - dequantize_per_channel_global_wg_size, + default_pick_global_wg_size, dequantize_per_channel_local_wg_size, // Inputs and Outputs {{output, vkapi::kWrite}, @@ -296,7 +303,94 @@ void add_dequantize_per_channel_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); +} + +void add_dequantize_block_wise_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& block_size, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string 
kernel_name("dequantize_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // Convert dimensions to WHCN order for shader + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) + utils::ivec4 num_blocks_vec = { + tensor_size_whcn[0] / block_size_vec[0], + tensor_size_whcn[1] / block_size_vec[1], + tensor_size_whcn[2] / block_size_vec[2], + tensor_size_whcn[3] / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + dequantize_block_wise_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_node)); } void dequantize_per_tensor_impl( @@ -308,31 +402,39 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is an integer type VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating 
point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } add_dequantize_per_tensor_node( graph, input, scale, zero_point, quant_min, quant_max, output); @@ -347,12 +449,11 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; @@ -366,15 +467,8 @@ void dequantize_per_token_impl( VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); - // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -430,12 +524,11 @@ void dequantize_per_channel_impl( const ValueRef axis = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; @@ -449,15 +542,8 @@ void dequantize_per_channel_impl( VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); - // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -513,8 +599,7 @@ void dequantize_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef block_size = args[arg_idx++]; const ValueRef scale = args[arg_idx++]; const 
ValueRef zero_point = args[arg_idx++]; const ValueRef input_dtype = args[arg_idx++]; @@ -529,33 +614,61 @@ void dequantize_affine_impl( // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is an integer type VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization + // Verify block_size is valid (each dimension must divide evenly into input + // size) const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); + for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); + if ((*block_size_list)[i] > 1) { + VK_CHECK_COND( + input_sizes[i] % (*block_size_list)[i] == 0, + "Input size at dimension ", + i, + " (", + input_sizes[i], + ") must be divisible by block_size at dimension ", + i, + " (", + (*block_size_list)[i], + ")"); + } } - // Default to per-tensor dequantization for TorchAO affine ops - add_dequantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); + add_dequantize_block_wise_node( + graph, + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + output); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index d786981e1fc..92719505a0f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -17,40 +17,60 @@ namespace vkcompute { -void resize_quantize_output( +void resize_quantize_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { (void)extra_args; - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(in)); + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + out->virtual_resize(in->sizes()); } -utils::uvec3 quantize_per_channel_global_wg_size( +utils::uvec3 quantize_per_channel_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { (void)shader; + (void)args; (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - 
utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + const ValueRef input = args.at(1).refs.at(0); + + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. For per-channel quantization along the batch axis, + // we need to ensure that we dispatch the correct number of workgroups in the + // Z dimension to cover all batch-channel combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. + const auto input_sizes = graph->sizes_of(input); + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { + local_wg_size[2] = 1; + } - return global_wg_size; + return local_wg_size; } -utils::uvec3 quantize_per_channel_local_wg_size( +utils::uvec3 quantize_block_wise_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { (void)shader; - (void)args; (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); utils::uvec3 local_wg_size = @@ -67,7 +87,8 @@ utils::uvec3 quantize_per_channel_local_wg_size( // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, // we set local_wg_size[2] = 1. const auto input_sizes = graph->sizes_of(input); - if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { local_wg_size[2] = 1; } @@ -133,7 +154,7 @@ void add_quantize_per_tensor_node( // Resize Args {}, // Resizing Logic - resize_quantize_output)); + resize_quantize_node)); } void add_quantize_per_token_node( @@ -205,7 +226,7 @@ void add_quantize_per_token_node( // Resize Args {}, // Resizing Logic - resize_quantize_output)); + resize_quantize_node)); } void add_quantize_per_channel_node( @@ -283,7 +304,7 @@ void add_quantize_per_channel_node( graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - quantize_per_channel_global_wg_size, + default_pick_global_wg_size, quantize_per_channel_local_wg_size, // Inputs and Outputs {{output, vkapi::kWrite}, @@ -298,7 +319,94 @@ void add_quantize_per_channel_node( // Resize Args {}, // Resizing Logic - resize_quantize_output)); + resize_quantize_node)); +} + +void add_quantize_block_wise_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& block_size, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // Convert PyTorch dimensions to WHCN order for shader + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 
tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) + utils::ivec4 num_blocks_vec = { + tensor_size_whcn[0] / block_size_vec[0], + tensor_size_whcn[1] / block_size_vec[1], + tensor_size_whcn[2] / block_size_vec[2], + tensor_size_whcn[3] / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + quantize_block_wise_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_node)); } void quantize_per_tensor_impl( @@ -310,7 +418,7 @@ void quantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warning - dtype is inferred from output @@ -339,7 +447,7 @@ void quantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warning - dtype is inferred from output @@ -412,7 +520,7 @@ void quantize_per_channel_impl( const ValueRef axis = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warning - dtype is inferred from output @@ -485,8 +593,7 @@ void quantize_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef block_size = args[arg_idx++]; const ValueRef scale = args[arg_idx++]; const ValueRef zero_point = args[arg_idx++]; const ValueRef output_dtype = args[arg_idx++]; @@ -499,6 +606,8 @@ void quantize_affine_impl( // Check 
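The block_size / num_blocks / block_stride push constants computed in add_quantize_block_wise_node let the shader map an element position (already in WHCN order) to the linear index of its scale and zero_point. A Python sketch of that mapping, assuming a PyTorch-shaped [8, 6] input with [2, 3] blocks (WHCN block_size [3, 2, 1, 1]); the element coordinate is an illustrative value:

def block_linear_index(pos_whcn, block_size_whcn, num_blocks_whcn):
    # block_stride is the prefix product of num_blocks, exactly as pushed above.
    block_stride = [
        1,
        num_blocks_whcn[0],
        num_blocks_whcn[0] * num_blocks_whcn[1],
        num_blocks_whcn[0] * num_blocks_whcn[1] * num_blocks_whcn[2],
    ]
    return sum(
        (p // b) * s for p, b, s in zip(pos_whcn, block_size_whcn, block_stride)
    )

# WHCN sizes [6, 8, 1, 1] with block_size [3, 2, 1, 1] give num_blocks [2, 4, 1, 1].
# The element at (w=4, h=5) falls in block (1, 2), i.e. linear block index 5.
print(block_linear_index([4, 5, 0, 0], [3, 2, 1, 1], [2, 4, 1, 1]))  # 5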
tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is a floating point type @@ -507,18 +616,51 @@ void quantize_affine_impl( graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Verify block_size is valid (each dimension must divide evenly into input + // size) const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); + for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); + if ((*block_size_list)[i] > 1) { + VK_CHECK_COND( + input_sizes[i] % (*block_size_list)[i] == 0, + "Input size at dimension ", + i, + " (", + input_sizes[i], + ") must be divisible by block_size at dimension ", + i, + " (", + (*block_size_list)[i], + ")"); + } } - // Default to per-tensor quantization for TorchAO affine ops - add_quantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); + add_quantize_block_wise_node( + graph, + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + output); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp index a47c58b7ef6..728d38c3e2d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp @@ -68,11 +68,10 @@ void check_linear_qta8a_qga4w_args( const auto mat1_scale_sizes = graph.sizes_of(mat1_scale); const auto mat1_zero_point_sizes = graph.sizes_of(mat1_zero_point); - VK_CHECK_COND(mat1_scale_sizes.size() == 1); - VK_CHECK_COND(mat1_zero_point_sizes.size() == 1); - - VK_CHECK_COND(mat1_scale_sizes[0] == input_num_tokens); - VK_CHECK_COND(mat1_zero_point_sizes[0] == input_num_tokens); + VK_CHECK_COND( + utils::val_at(-1, mat1_scale_sizes) == input_num_tokens); + VK_CHECK_COND( + utils::val_at(-1, mat1_zero_point_sizes) == input_num_tokens); // Verify weight scales and zeros have the same shape const auto weight_scales_sizes = graph.sizes_of(weight_scales); diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index 7f535a0001b..ef429ff21fa 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -35,6 +35,7 @@ python_unittest( "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/backends/vulkan/quantizer:vulkan_quantizer", "//executorch/backends/vulkan:vulkan_preprocess", + "//pytorch/ao:torchao", # @manual ] ) diff --git 
a/backends/vulkan/test/op_tests/quantize_affine_test.cpp b/backends/vulkan/test/op_tests/quantize_affine_test.cpp new file mode 100644 index 00000000000..d2a971da82b --- /dev/null +++ b/backends/vulkan/test/op_tests/quantize_affine_test.cpp @@ -0,0 +1,1379 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include "test_utils.h" + +#include +#include +#include + +static inline void +_check_dims(c10::string_view name, int64_t expected, int64_t actual) { + VK_CHECK_COND( + expected == actual, + name, + " has rank ", + actual, + " but block_size has length ", + expected); +} + +at::Tensor quantize_affine_reference_impl( + const at::Tensor& input_, + const std::vector& block_size, + const at::Tensor& scale, + const c10::optional& zero_point_opt, + int64_t quant_min, + int64_t quant_max, + at::ScalarType out_dtype, + c10::optional zero_point_domain_opt = std::string("INT")) { + constexpr float kEps = 1e-7f; + + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kFloat || input_.scalar_type() == at::kHalf || + input_.scalar_type() == at::kBFloat16, + "Unsupported input dtype: ", + input_.dtype()); + + auto zero_point_domain = + zero_point_domain_opt.has_value() ? *zero_point_domain_opt : "INT"; + + bool has_zp = zero_point_opt.has_value(); + VK_CHECK_COND( + has_zp || zero_point_domain == "NONE" || zero_point_domain == "", + "zero_point must be supplied unless zero_point_domain is NONE or null"); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor scale_b = + scale.view(shape_after_reduction).to(input_reshaped.scalar_type()); + + at::Tensor zp_b; + if (has_zp) { + zp_b = (*zero_point_opt).view(shape_after_reduction).toType(at::kFloat); + } + + scale_b = scale_b.clamp_min(kEps); + at::Tensor inv_scale = 1.0f / scale_b; + + at::Tensor q; + if (zero_point_domain == "INT") { + VK_CHECK_COND(has_zp, "INT zero_point_domain requires zero_point tensor"); + q = at::round(input_reshaped * inv_scale) + zp_b; + } else if (zero_point_domain == "NONE" || zero_point_domain.empty()) { + VK_CHECK_COND( + !has_zp, "zero_point must be None when domain is NONE / null"); + q = at::round(input_reshaped * inv_scale); + } else { + VK_CHECK_COND( + has_zp && zero_point_domain == "FLOAT", + "zero_point_domain must be INT, FLOAT, NONE or null"); + const float mid_point = (quant_max + quant_min + 1) * 0.5f; + at::Tensor min_val = zp_b - scale_b * 
mid_point; + q = at::round((input_reshaped - min_val) / scale_b); + } + + q = at::clamp(q, (double)quant_min, (double)quant_max); + + q = q.view(in_sizes).to(out_dtype); + + return q; +} + +at::Tensor dequantize_affine_reference_impl( + const at::Tensor& input_, + const std::vector& block_size, + const at::Tensor& scale, + const c10::optional& zero_point_opt, + int64_t quant_min, + int64_t quant_max, + at::ScalarType out_dtype, + c10::optional zero_point_domain_opt = std::string("INT")) { + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kByte || input_.scalar_type() == at::kChar || + input_.scalar_type() == at::kShort || + input_.scalar_type() == at::kInt, + "Unsupported input dtype: ", + input_.dtype()); + + VK_CHECK_COND( + out_dtype == at::kFloat || out_dtype == at::kHalf || + out_dtype == at::kBFloat16, + "Unsupported output dtype: ", + out_dtype); + + auto zero_point_domain = + zero_point_domain_opt.has_value() ? *zero_point_domain_opt : "INT"; + + bool has_zp = zero_point_opt.has_value(); + VK_CHECK_COND( + has_zp || zero_point_domain == "NONE" || zero_point_domain == "", + "zero_point must be supplied unless zero_point_domain is NONE or null"); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor scale_b = scale.view(shape_after_reduction).to(out_dtype); + + at::Tensor zp_b; + if (has_zp) { + zp_b = (*zero_point_opt).view(shape_after_reduction).to(out_dtype); + } + + at::Tensor input_fp = input_reshaped.to(out_dtype); + at::Tensor dq; + + if (zero_point_domain == "INT") { + VK_CHECK_COND(has_zp, "INT zero_point_domain requires zero_point tensor"); + dq = (input_fp - zp_b) * scale_b; + } else if (zero_point_domain == "NONE" || zero_point_domain.empty()) { + VK_CHECK_COND( + !has_zp, "zero_point must be None when domain is NONE / null"); + dq = input_fp * scale_b; + } else { + VK_CHECK_COND( + has_zp && zero_point_domain == "FLOAT", + "zero_point_domain must be INT, FLOAT, NONE or null"); + const float mid_point = (quant_max + quant_min + 1) * 0.5f; + at::Tensor min_val = zp_b - scale_b * mid_point; + dq = input_fp * scale_b + min_val; + } + + dq = dq.view(in_sizes); + + return dq; +} + +// Wrapper function to maintain compatibility with existing test code (above is +// a good reference for how the python implementation works) +at::Tensor quantize_affine_reference_impl( + const at::Tensor& input, + const std::vector& block_size, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + return quantize_affine_reference_impl( + input, + block_size, + scale, + 
c10::optional(zero_point), + quant_min, + quant_max, + dtype, + std::string("INT")); +} + +// Wrapper function for dequantize_affine +at::Tensor dequantize_affine_reference_impl( + const at::Tensor& input, + const std::vector& block_size, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + return dequantize_affine_reference_impl( + input, + block_size, + scale, + c10::optional(zero_point), + quant_min, + quant_max, + dtype, + std::string("INT")); +} + +std::tuple choose_qparams_affine_reference_impl( + const at::Tensor& input_, + const std::string& mapping_type, + const std::vector& block_size, + int64_t quant_min, + int64_t quant_max, + double eps) { + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kFloat || input_.scalar_type() == at::kHalf || + input_.scalar_type() == at::kBFloat16, + "Unsupported input dtype: ", + input_.dtype()); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor min_val = input_reshaped.amin(reduction_dims, /*keepdim=*/true); + at::Tensor max_val = input_reshaped.amax(reduction_dims, /*keepdim=*/true); + + at::Tensor scale, zero_point; + + if (mapping_type == "ASYMMETRIC") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate scale + scale = (max_val - min_val) / (quant_max - quant_min); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point + zero_point = at::round(quant_min - min_val / scale); + zero_point = at::clamp(zero_point, quant_min, quant_max); + } else if (mapping_type == "SYMMETRIC") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate max absolute value + at::Tensor abs_min = at::abs(min_val); + at::Tensor abs_max = at::abs(max_val); + at::Tensor M = at::maximum(abs_min, abs_max); + + // Calculate scale + scale = M / ((quant_max - quant_min) * 0.5); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point (mid-point) + zero_point = + at::full_like(scale, (quant_max + quant_min + 1) / 2, at::kInt); + } else if (mapping_type == "SYMMETRIC_NO_CLIPPING_ERR") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate scale based on min/max values + at::Tensor s_min = at::abs(min_val) / std::abs(quant_min); + at::Tensor s_max = max_val / quant_max; + 
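For the ASYMMETRIC branch of the choose_qparams reference above, the formulas reduce to a few lines of arithmetic. A worked sketch for a single block with hypothetical statistics min = -1.0 and max = 3.0 and an int8 target range:

qmin, qmax, eps = -128, 127, 1e-5
lo, hi = min(-1.0, 0.0), max(3.0, 0.0)       # the range is forced to include 0
scale = max((hi - lo) / (qmax - qmin), eps)  # 4 / 255 ~= 0.0157
zero_point = round(qmin - lo / scale)        # round(-128 + 63.75) = -64
zero_point = max(qmin, min(qmax, zero_point))
print(scale, zero_point)                     # 0.01568..., -64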
scale = at::maximum(s_min, s_max); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point (mid-point) + zero_point = + at::full_like(scale, (quant_max + quant_min + 1) / 2, at::kInt); + } else { + VK_CHECK_COND( + false, + "Unsupported mapping_type: ", + mapping_type, + ". Expected ASYMMETRIC, SYMMETRIC, or SYMMETRIC_NO_CLIPPING_ERR"); + } + + std::vector output_shape; + for (size_t i = 0; i < shape_after_reduction.size(); ++i) { + if (shape_after_reduction[i] != 1 || + std::find(reduction_dims.begin(), reduction_dims.end(), i) == + reduction_dims.end()) { + output_shape.push_back(shape_after_reduction[i]); + } + } + + // Reshape scale and zero_point to final output shape + scale = scale.view(output_shape); + zero_point = zero_point.view(output_shape); + + return std::make_tuple(scale, zero_point); +} + +void test_vulkan_quantize_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); + + // Get reference output + at::Tensor reference_out = quantize_affine_reference_impl( + input, + block_size, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_output_dtype = + graph.add_scalar(static_cast(dtype)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("torchao.quantize_affine.default") + (graph, + { + r_input.value, + r_block_size, + r_scale.value, + r_zero_point.value, + r_output_dtype, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Copy scale tensor to GPU + graph.copy_into_staging( + r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); + + // Copy zero_point tensor to GPU + graph.copy_into_staging( + 
r_zero_point.staging, + zero_point_tensor.const_data_ptr(), + zero_point_tensor.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); + if (!output_correct) { + std::cout << "\nFailed with parameters:" << std::endl; + std::cout << " input_sizes: ["; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " block_size: ["; + for (size_t i = 0; i < block_size.size(); i++) { + std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " scales: ["; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << scales[i] << (i < scales.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " zero_points: ["; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << zero_points[i] << (i < zero_points.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl << input << std::endl; + std::cout << "reference:" << std::endl << reference_int << std::endl; + std::cout << "vulkan:" << std::endl << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_affine( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +TEST(VulkanQuantizeAffineTest, test_1d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 1D: 1x1x1x12 Tensor, block_size is 3 + test_vulkan_quantize_affine( + {12}, // input_sizes + {3}, // block_size + {0.1f, 0.2f, 0.15f, 0.25f}, // scales (4 blocks) + {10, -20, 5, 30}, // zero_points (4 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizeAffineTest, test_2d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 2D: 1x1x8x6 Tensor, block_size is 1x1x2x3 (8/2=4, 6/3=2, so 4*2=8 blocks) + test_vulkan_quantize_affine( 
+ {8, 6}, // input_sizes + {2, 3}, // block_size (1/1=1, 1/1=1, 8/2=4, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f}, // scales (8 blocks) + {-10, 15, 0, 25, -5, 20, 10, -15}, // zero_points (8 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizeAffineTest, test_3d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 3D: 1x6x4x6 Tensor, block_size is 3x2x2 (6/3=2, 4/2=2, 6/2=3, so 2*2*3=12 + // blocks) + test_vulkan_quantize_affine( + {6, 4, 6}, // input_sizes (changed 7->6 so divisible by 3) + {3, + 2, + 2}, // block_size (6 divisible by 3, 4 divisible by 2, 6 divisible by 2) + {0.1f, + 0.2f, + 0.15f, + 0.25f, + 0.3f, + 0.05f, + 0.4f, + 0.35f, + 0.12f, + 0.18f, + 0.22f, + 0.28f}, // scales (12 blocks) + {-15, 10, 5, -25, 20, -10, 15, -5, 8, -12, 18, -8}, // zero_points (12 + // blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizeAffineTest, test_4d_quantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 4D: 8x6x6x6 Tensor, block_size is 2x3x2x3 (8/2=4, 6/3=2, 6/2=3, 6/3=2, so + // 4*2*3*2=48 blocks) + test_vulkan_quantize_affine( + {8, 6, 6, 6}, // input_sizes + {2, 3, 2, 3}, // block_size (8/2=4, 6/3=2, 6/2=3, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f, 0.12f, 0.18f, + 0.22f, 0.28f, 0.32f, 0.08f, 0.45f, 0.38f, 0.14f, 0.24f, 0.16f, 0.26f, + 0.34f, 0.06f, 0.44f, 0.36f, 0.11f, 0.21f, 0.13f, 0.23f, 0.31f, 0.07f, + 0.41f, 0.37f, 0.19f, 0.29f, 0.17f, 0.27f, 0.33f, 0.09f, 0.43f, 0.39f, + 0.10f, 0.20f, 0.14f, 0.24f, 0.30f, 0.04f, 0.40f, 0.34f}, // scales (48 + // blocks) + {-20, 10, 5, -15, 25, -10, 15, -5, 8, -12, 18, -8, 22, + -18, 12, -22, -25, 15, 0, -20, 30, -5, 20, -10, 5, -25, + 10, -15, 35, -15, 25, -35, -30, 20, -5, -25, 40, 0, 30, + -40, 10, -30, 15, -10, 45, -20, 35, -45}, // zero_points (48 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kFloat, // input dtype + at::kChar); // output dtype +} + +void test_vulkan_dequantize_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kChar, + at::ScalarType out_dtype = at::kFloat, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + // Create input tensor with random integer values within quant_min and + // quant_max + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = at::randint( + quant_min, + quant_max + 1, + input_sizes_int64, + at::device(at::kCPU).dtype(in_dtype)); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); + + // Get reference output + at::Tensor reference_out = dequantize_affine_reference_impl( + input, + block_size, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + out_dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph 
graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Create block_size as IntList instead of Tensor + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + // Create input_dtype scalar + const ValueRef r_input_dtype = + graph.add_scalar(static_cast(in_dtype)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + const ValueRef r_output_dtype = + graph.add_scalar(static_cast(out_dtype)); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + // Match the argument order in dequantize_affine_impl in Dequantize.cpp: + // input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, + // output_dtype, output + VK_GET_OP_FN("torchao.dequantize_affine.default") + (graph, + { + r_input.value, + r_block_size, + r_scale.value, + r_zero_point.value, + r_input_dtype, + r_quant_min, + r_quant_max, + r_output_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Copy scale tensor to GPU + graph.copy_into_staging( + r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); + + // Copy zero_point tensor to GPU + graph.copy_into_staging( + r_zero_point.staging, + zero_point_tensor.const_data_ptr(), + zero_point_tensor.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); + if (!output_correct) { + std::cout << "\nFailed with parameters:" << std::endl; + std::cout << " input_sizes: ["; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " block_size: ["; + for (size_t i = 0; i < block_size.size(); i++) { + std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " scales: ["; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << scales[i] << (i < scales.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " zero_points: ["; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << zero_points[i] << (i < zero_points.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl << input << std::endl; + std::cout << "reference:" << std::endl << reference_out << std::endl; + std::cout << "vulkan:" << std::endl << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_affine( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kChar, + at::ScalarType out_dtype = at::kFloat) { + // Test with buffer storage + test_vulkan_dequantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +TEST(VulkanDequantizeAffineTest, test_1d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 1D: 1x1x1x12 Tensor, block_size is 3 + test_vulkan_dequantize_affine( + {12}, // input_sizes + {3}, // block_size + {0.1f, 0.2f, 0.15f, 0.25f}, // scales (4 blocks) + {10, -20, 5, 30}, // zero_points (4 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_2d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 2D: 1x1x8x6 Tensor, block_size is 1x1x2x3 (8/2=4, 6/3=2, so 4*2=8 blocks) + test_vulkan_dequantize_affine( + {8, 6}, // input_sizes + {2, 3}, // block_size (1/1=1, 1/1=1, 8/2=4, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f}, // scales (8 blocks) + {-10, 15, 0, 25, -5, 20, 10, -15}, // zero_points (8 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_3d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 3D: 1x6x4x6 Tensor, block_size is 3x2x2 (6/3=2, 4/2=2, 6/2=3, so 2*2*3=12 + // blocks) + test_vulkan_dequantize_affine( + {6, 4, 6}, // input_sizes (changed 7->6 so divisible by 3) + {3, + 2, + 2}, // block_size (6 divisible by 3, 4 divisible by 2, 6 divisible by 2) + {0.1f, + 0.2f, + 0.15f, + 0.25f, + 0.3f, + 0.05f, + 0.4f, + 0.35f, + 0.12f, + 0.18f, + 0.22f, + 0.28f}, // scales (12 blocks) + {-15, 10, 5, -25, 20, -10, 15, -5, 8, -12, 18, -8}, // zero_points (12 + // blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_4d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 4D: 8x6x6x6 Tensor, block_size is 2x3x2x3 (8/2=4, 6/3=2, 6/2=3, 6/3=2, so + // 4*2*3*2=48 blocks) + test_vulkan_dequantize_affine( + {8, 6, 6, 6}, // input_sizes + {2, 3, 2, 3}, // block_size (8/2=4, 6/3=2, 6/2=3, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f, 0.12f, 0.18f, + 0.22f, 0.28f, 0.32f, 0.08f, 0.45f, 0.38f, 0.14f, 
0.24f, 0.16f, 0.26f, + 0.34f, 0.06f, 0.44f, 0.36f, 0.11f, 0.21f, 0.13f, 0.23f, 0.31f, 0.07f, + 0.41f, 0.37f, 0.19f, 0.29f, 0.17f, 0.27f, 0.33f, 0.09f, 0.43f, 0.39f, + 0.10f, 0.20f, 0.14f, 0.24f, 0.30f, 0.04f, 0.40f, 0.34f}, // scales (48 + // blocks) + {-20, 10, 5, -15, 25, -10, 15, -5, 8, -12, 18, -8, 22, + -18, 12, -22, -25, 15, 0, -20, 30, -5, 20, -10, 5, -25, + 10, -15, 35, -15, 25, -35, -30, 20, -5, -25, 40, 0, 30, + -40, 10, -30, 15, -10, 45, -20, 35, -45}, // zero_points (48 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +void test_vulkan_choose_qparams_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::string& mapping_type, + int64_t quant_min, + int64_t quant_max, + double eps, + at::ScalarType in_dtype = at::kFloat, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kBuffer) { + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Get reference output + auto reference_out = choose_qparams_affine_reference_impl( + input, mapping_type, block_size, quant_min, quant_max, eps); + + at::Tensor reference_scale = std::get<0>(reference_out); + at::Tensor reference_zero_point = std::get<1>(reference_out); + + reference_zero_point = reference_zero_point.to(at::kInt); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Create mapping_type as string + std::string mapping_type_copy = mapping_type; + const ValueRef r_mapping_type = + graph.add_string(std::move(mapping_type_copy)); + + // Create block_size as IntList + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + // Create target_dtype, quant_min, quant_max, eps + const ValueRef r_target_dtype = + graph.add_scalar(static_cast(at::kChar)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + const ValueRef r_eps = graph.add_scalar(eps); + + // Create scale_dtype and zero_point_dtype + const ValueRef r_scale_dtype = + graph.add_scalar(static_cast(at::kFloat)); + const ValueRef r_zero_point_dtype = + graph.add_scalar(static_cast(at::kInt)); + + // Create output tuple + std::vector out_tuple; + + // Create scale and zero_point output tensors + const ValueRef r_scale_out = graph.add_tensor( + reference_scale.sizes().vec(), vkapi::kFloat, out_storage); + const ValueRef r_zero_point_out = graph.add_tensor( + reference_zero_point.sizes().vec(), vkapi::kInt, out_storage); + + out_tuple.push_back(r_scale_out); + out_tuple.push_back(r_zero_point_out); + + const ValueRef r_out_tuple = graph.add_value_list(std::move(out_tuple)); + + VK_GET_OP_FN("torchao.choose_qparams_affine.default") + (graph, + { + r_input.value, + r_mapping_type, + r_block_size, + r_target_dtype, + r_quant_min, + r_quant_max, + r_eps, + r_scale_dtype, + r_zero_point_dtype, + r_out_tuple, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale_out); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point_out); + + 
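The scale / zero_point outputs wired up above carry one element per block; the shape passed to add_tensor comes from the reference implementation, which drops any axis that is reduced away entirely. A small sketch of that shape rule, equivalent to the reshaping logic in choose_qparams_affine_reference_impl and shown for a few of the test configurations:

def qparam_shape(input_sizes, block_size):
    # One entry per block along each axis; an axis covered by a single block
    # (block_size == dim with dim > 1) is dropped from the qparam shape.
    shape = []
    for dim, blk in zip(input_sizes, block_size):
        n = dim // blk
        if n > 1 or blk == 1:
            shape.append(n)
    return shape

print(qparam_shape([8, 6], [2, 3]))        # [4, 2]  (block-wise 2D test)
print(qparam_shape([4, 6, 8], [1, 1, 8]))  # [4, 6]  (per-token test)
print(qparam_shape([4, 6, 8], [4, 6, 8]))  # []      (per-tensor test)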
graph.prepare(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_scale = at::empty_like(reference_scale).contiguous(); + at::Tensor vk_zero_point = at::empty_like(reference_zero_point).contiguous(); + + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Compare outputs + const bool scale_correct = + at::allclose(reference_scale, vk_scale, /*rtol=*/1e-3, /*atol=*/1e-3); + + // For zero point, we need to compare as integers since zero point should be + // an integer First convert both tensors to int if they aren't already + at::Tensor ref_zp_int = reference_zero_point.to(at::kInt); + at::Tensor vk_zp_int = vk_zero_point.to(at::kInt); + const bool zero_point_correct = at::equal(ref_zp_int, vk_zp_int); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\nFailed with parameters:" << std::endl; + std::cout << " input_sizes: ["; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " block_size: ["; + for (size_t i = 0; i < block_size.size(); i++) { + std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " mapping_type: " << mapping_type << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " eps: " << eps << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + if (!scale_correct || !zero_point_correct) { + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + + std::cout << "reference_scale:" << std::endl + << reference_scale << std::endl; + std::cout << "vulkan_scale:" << std::endl << vk_scale << std::endl; + + std::cout << "reference_zero_point:" << std::endl + << reference_zero_point << std::endl; + std::cout << "vulkan_zero_point:" << std::endl + << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct); + ASSERT_TRUE(zero_point_correct); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_affine( + const std::vector& input_sizes, + const std::vector& block_size, + const std::string& mapping_type, + int64_t quant_min, + int64_t quant_max, + double eps, + at::ScalarType in_dtype = at::kFloat) { + // Test with buffer storage for both input and output + test_vulkan_choose_qparams_affine_impl( + input_sizes, + block_size, + mapping_type, + quant_min, + quant_max, + eps, + in_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage for input and buffer storage for output + // (shader always uses buffer storage for outputs) + test_vulkan_choose_qparams_affine_impl( + input_sizes, + block_size, + mapping_type, + quant_min, + quant_max, + eps, + in_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kBuffer); +} + +TEST(VulkanChooseQParamsAffineTest, test_1d_asymmetric) { + // 1D: 12 Tensor, block_size is 3 + test_vulkan_choose_qparams_affine( + {12}, // input_sizes + {3}, // block_size + "ASYMMETRIC", // mapping_type + -128, // quant_min (char min) + 127, // quant_max (char max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_2d_symmetric) { + // 2D: 8x6 Tensor, block_size is 2x3 + test_vulkan_choose_qparams_affine( + {8, 6}, // input_sizes + {2, 3}, // block_size + "SYMMETRIC", // mapping_type + -128, // quant_min (char min) + 127, // quant_max (char max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_3d_symmetric_no_clipping) { + // 3D: 6x4x6 Tensor, block_size is 3x2x2 + test_vulkan_choose_qparams_affine( + {6, 4, 6}, // input_sizes + {3, 2, 2}, // block_size + "SYMMETRIC_NO_CLIPPING_ERR", // mapping_type + -128, // quant_min (char min) + 127, // quant_max (char max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_4d_asymmetric) { + // 4D: 4x6x6x6 Tensor, block_size is 2x3x2x3 + test_vulkan_choose_qparams_affine( + {4, 6, 6, 6}, // input_sizes (reduced from 8 to 4 to make test faster) + {2, 3, 2, 3}, // block_size + "ASYMMETRIC", // mapping_type + -128, // quant_min (char min) + 127, // quant_max (char max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_per_tensor) { + // Per-tensor: block_size equals tensor size + test_vulkan_choose_qparams_affine( + {4, 6, 8}, // input_sizes + {4, 6, 8}, // block_size equals tensor size + "ASYMMETRIC", // mapping_type + -128, // quant_min (char min) + 127, // quant_max (char max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_per_token) { + // Per-token: block_size is all 1s except last dimension + test_vulkan_choose_qparams_affine( + {4, 6, 8}, // input_sizes + {1, 1, 8}, // block_size is all 1s except last dimension + "ASYMMETRIC", // mapping_type + -128, // quant_min (char min) + 127, // quant_max (char max) + 
1e-5, // eps + at::kFloat); // input dtype +} + +// Additional tests for choose_qparams_affine + +TEST(VulkanChooseQParamsAffineTest, test_uint8_range) { + // Test with uint8 range (0-255) + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "ASYMMETRIC", // mapping_type + 0, // quant_min (uint8 min) + 255, // quant_max (uint8 max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_int16_range) { + // Test with int16 range (-32768 to 32767) + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "SYMMETRIC", // mapping_type + -32768, // quant_min (int16 min) + 32767, // quant_max (int16 max) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_larger_eps) { + // Test with larger epsilon value + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "ASYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-2, // larger eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_per_channel_first_dim) { + // Per-channel quantization on first dimension + test_vulkan_choose_qparams_affine( + {8, 6, 4}, // input_sizes + {1, 6, 4}, // block_size (per-channel on dim 0) + "SYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_per_channel_middle_dim) { + // Per-channel quantization on middle dimension + test_vulkan_choose_qparams_affine( + {4, 8, 6}, // input_sizes + {4, 1, 6}, // block_size (per-channel on dim 1) + "SYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_mixed_block_sizes) { + // Mixed block sizes (some dimensions fully quantized, some partially) + test_vulkan_choose_qparams_affine( + {8, 6, 10}, // input_sizes + {4, 6, 2}, // block_size (mixed: partial, full, partial) + "ASYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_small_tensor) { + // Test with a small tensor + test_vulkan_choose_qparams_affine( + {2, 3}, // small input_sizes + {2, 3}, // block_size (full tensor) + "ASYMMETRIC", // mapping_type + -128, // quant_min + 127, // quant_max + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_asymmetric_narrow_range) { + // Test with a narrow quantization range + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "ASYMMETRIC", // mapping_type + -10, // quant_min (narrow range) + 10, // quant_max (narrow range) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_symmetric_narrow_range) { + // Test with a narrow quantization range with symmetric mapping + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "SYMMETRIC", // mapping_type + -10, // quant_min (narrow range) + 10, // quant_max (narrow range) + 1e-5, // eps + at::kFloat); // input dtype +} + +TEST(VulkanChooseQParamsAffineTest, test_symmetric_no_clipping_narrow_range) { + // Test with a narrow quantization range with symmetric no clipping mapping + test_vulkan_choose_qparams_affine( + {6, 8}, // input_sizes + {2, 4}, // block_size + "SYMMETRIC_NO_CLIPPING_ERR", // mapping_type + -10, // quant_min (narrow range) + 10, // quant_max (narrow range) + 
1e-5, // eps + at::kFloat); // input dtype +} \ No newline at end of file diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 9eac90ac33d..b9386f92772 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -216,3 +216,9 @@ def define_common_targets(is_fbcode = False): ":test_utils", ] ) + define_test_targets( + "quantize_affine_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index ff9e2d85a96..4f54bc638ba 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -7,7 +7,7 @@ from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_linear_weight_only_qcs_xnn_qconfig, + get_symmetric_quantization_config, VulkanQuantizer, ) @@ -16,6 +16,7 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) +from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightQuantizer from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import Quantizer @@ -101,7 +102,9 @@ def test_fuse_int8pack_mm(self): sample_inputs = model.get_sample_inputs() quantizer = VulkanQuantizer() - quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8)) + quantizer.set_global( + get_symmetric_quantization_config(is_dynamic=False, weight_bits=8) + ) edge_manager = quantize_and_lower_module( model, @@ -129,7 +132,9 @@ def test_fuse_linear_qcs4w(self): sample_inputs = model.get_sample_inputs() quantizer = VulkanQuantizer() - quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4)) + quantizer.set_global( + get_symmetric_quantization_config(is_dynamic=False, weight_bits=4) + ) edge_manager = quantize_and_lower_module( model, @@ -149,3 +154,56 @@ def test_fuse_linear_qcs4w(self): self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1) self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) + + def test_fuse_linear_qta8a_qga4w(self): + """Test fusion of dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).""" + K = 256 + N = 256 + model = SingleLinearModule(K, N) + sample_inputs = model.get_sample_inputs() + + # Use source transform quantizer for dynamic activation + grouped weight quantization + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=128, # Group size for 4-bit weights + padding_allowed=False, + precision=torch.float32, + scales_precision=torch.float32, + device=torch.device("cpu"), + ) + + # Apply source transform quantization + quantized_model = quantizer.quantize(model) + + # Export the quantized model + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, + _check_ir_validity=False, + ) + + program = torch.export.export_for_training( + quantized_model, sample_inputs, strict=True + ).module() + + program = torch.export.export(program, sample_inputs) + + edge_manager = to_edge( + program, + compile_config=edge_compile_config, + ) + + ep = edge_manager._edge_programs["forward"] + edge_manager.transform( + [ + AddmmToLinearTransform(), + FuseQuantizedOpsTransform(ep), + ] + ) + + gm = ep.graph_module + + # Check that the linear_qta8a_qga4w operator was created + self.assertEqual(op_node_count(gm, "linear_qta8a_qga4w.default"), 1) + # Check that the original quantization/dequantization nodes 
were removed + self.assertEqual(op_node_count(gm, "quantize_per_token.default"), 0) + self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) + self.assertEqual(op_node_count(gm, "linear.default"), 0) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index d71c0a35776..9086b2d0792 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -38,6 +38,14 @@ "dequantize_affine.default", } +_Q_OPS = { + "quantize_per_tensor.tensor", + "quantize_per_tensor.default", + "quantize_per_channel.default", + "quantize_per_token.default", + "quantize_affine.default", +} + ## ## Node type determination ## @@ -50,6 +58,13 @@ def is_dequant_node(node: torch.fx.Node) -> bool: return node_name in _DQ_OPS +def is_quant_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name in _Q_OPS + + def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index b94feb5a1ae..d87c722363f 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -12,7 +12,7 @@ import torch from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, + get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack, XNNPACKQuantizer, ) @@ -127,11 +127,11 @@ def check_embedding_byte_registered(): "At the moment only per channel weight quantization is supported." ) if quant_params.quantize_linear.is_qc4: - operator_config_dynamic = get_symmetric_quantization_config( + operator_config_dynamic = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7 ) else: - operator_config_dynamic = get_symmetric_quantization_config( + operator_config_dynamic = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=True ) dynamic_quantizer.set_global(operator_config_dynamic) @@ -247,13 +247,13 @@ def get_coreml_quantizer(pt2e_quantize: str): raise NotImplementedError("4-bit Core ML quantizer is still under development") elif pt2e_quantize == "coreml_baseline_8a_c8w": - config = get_symmetric_quantization_config( + config = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=False ) quantizer = XNNPACKQuantizer().set_global(config) elif pt2e_quantize == "coreml_baseline_8a_c4w": - config = get_symmetric_quantization_config( + config = get_symmetric_quantization_config_xnnpack( is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7 ) quantizer = XNNPACKQuantizer().set_global(config) @@ -266,12 +266,14 @@ def get_coreml_quantizer(pt2e_quantize: str): def get_vulkan_quantizer(pt2e_quantize: str): from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_linear_weight_only_qcs_xnn_qconfig, + get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan, VulkanQuantizer, ) if pt2e_quantize == "vulkan_8w": - config = get_linear_weight_only_qcs_xnn_qconfig(8) + config = get_symmetric_quantization_config_vulkan( + is_dynamic=False, weight_bits=8 + ) else: raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")
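End-to-end, the renamed Vulkan config entry point is used the same way as before. A minimal usage sketch of the PT2E flow with the new name; this mirrors the pattern in test_vulkan_passes.py and get_vulkan_quantizer rather than adding anything new, and the tiny model plus sample inputs are placeholders for illustration:

import torch
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    VulkanQuantizer,
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256)

    def forward(self, x):
        return self.linear(x)

model = TinyLinear().eval()
example_inputs = (torch.randn(1, 256),)

quantizer = VulkanQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
)

exported = torch.export.export_for_training(
    model, example_inputs, strict=True
).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)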