diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh new file mode 100644 index 00000000000..7194bebda35 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef DEQUANTIZE_GLSLH +#define DEQUANTIZE_GLSLH + +OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { + return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); +} + +#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl new file mode 100644 index 00000000000..2a1f62719a0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * DEQUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer value from buffer + * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale + * 3. Store reconstructed floating-point value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access + * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering + * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping + * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Dequantization Process: + * Input: -103 (int8) + * Step 1: qvalue - zero_point = -103 - (-128) = 25 + * Step 2: result * scale = 25 * 0.1 = 2.5 + * Output: 2.5 (float) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: value = (qvalue - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for element at tensor index (w, z, y, x): + * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y + * - 3D tensor: token_id = z * sizes.y + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + OUT_T value = dequantize_val(qvalue, scale, zero_point); + + t_out[out_bufi] = value; +} + +#else + +void dequantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[out_bufi] = value; +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml new file mode 100644 index 00000000000..4e434935356 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -0,0 +1,18 @@ +dequantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: dequantize_per_tensor_buffer + MODE: per_tensor + - NAME: dequantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl new file mode 100644 index 00000000000..cfc61dd1816 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * DEQUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer texel (4 values) from 3D texture + * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale + * 3. Store reconstructed floating-point texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) for input/output textures + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * - Input/output textures: Must use standard axis mapping for per-token mode + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Texel Dequantization Process: + * Input Texel: [-103, -128, -123, -96] (int4) + * Per-component dequantization with scale=0.1, zero_point=-128: + * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 + * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 + * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 + * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 + * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: value[i] = (qvalue[i] - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for texel at position (x, y, z): + * - 3D tensor: token_id = z * texture_height + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Skip if out of bounds + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale, zero_point); + outtex[i] = value; + } + write_texel(t_out, pos, outtex); +} + +#else + +void dequantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + FVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml new file mode 100644 index 00000000000..fc8c18468ed --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -0,0 +1,18 @@ +dequantize_texture: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: dequantize_per_tensor_texture3d + MODE: per_tensor + - NAME: dequantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp new file mode 100644 index 00000000000..77a51ce24f9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_dequantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_dequantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void add_dequantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void dequantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + add_dequantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void dequantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_dequantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); + VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 7b155c8f98b..1ec0602a4f2 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -20,6 +20,7 @@ #include "test_utils.h" #include +#include #include #include @@ -481,6 +482,8 @@ void test_reference_dequantize_per_tensor( std::cout << " zero_point: " << zero_point << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -598,8 +601,15 @@ void test_vulkan_dequantize_per_tensor_impl( graph.copy_from_staging( staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - // Compare outputs - const bool output_correct = at::allclose(reference_out, vk_out); + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } if (!output_correct) { std::cout << "\n" << "Failed with parameters: " << std::endl; @@ -611,6 +621,8 @@ void test_vulkan_dequantize_per_tensor_impl( << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -623,7 +635,6 @@ void test_vulkan_dequantize_per_tensor_impl( ASSERT_TRUE(output_correct); } -// Test cases for dequantize_per_tensor TEST( VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_to_float) { @@ -689,6 +700,99 @@ TEST( at::kHalf); // output dtype } +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {3, 4}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_float) { + test_vulkan_dequantize_per_tensor( + {2, 4, 3, 12}, // input sizes + 0.0001, // scale + 100, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scale to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + test_vulkan_dequantize_per_tensor( + {7}, // input sizes + 1e-5, // scale (much smaller to avoid overflow) + 5, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + void test_reference_dequantize_per_token( const std::vector& input_sizes, const std::vector& scales, @@ -793,6 +897,8 @@ void test_reference_dequantize_per_token( std::cout << "" << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -894,9 +1000,15 @@ void test_vulkan_dequantize_per_token_impl( IOValueRef r_input = graph.add_input_tensor( input.sizes().vec(), from_at_scalartype(dtype), in_storage); IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); const ValueRef r_quant_min = graph.add_scalar(quant_min); const ValueRef r_quant_max = graph.add_scalar(quant_max); @@ -946,8 +1058,15 @@ void test_vulkan_dequantize_per_token_impl( graph.copy_from_staging( staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - // Compare outputs - const bool output_correct = at::allclose(reference_out, vk_out); + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } if (!output_correct) { std::cout << "\n" << "Failed with parameters: " << std::endl; @@ -967,6 +1086,8 @@ void test_vulkan_dequantize_per_token_impl( << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -979,7 +1100,6 @@ void test_vulkan_dequantize_per_token_impl( ASSERT_TRUE(output_correct); } -// Test cases for dequantize_per_token TEST( VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_uint8_to_float) { @@ -1059,3 +1179,112 @@ TEST( at::kInt, // input dtype at::kHalf); // output dtype } + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; + + test_vulkan_dequantize_per_token( + {2, 3, 6}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.0}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_float) { + std::vector scales = { + 0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0}; + std::vector zero_points = {100, -100, 50, -50, 12, -6, 4, -24}; + + test_vulkan_dequantize_per_token( + {2, 2, 2, 12}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.2}; + std::vector zero_points = {2, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scales to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + std::vector scales = {1e-5, 2e-5, 1.5e-5}; + std::vector zero_points = {20, -15, 1}; + + test_vulkan_dequantize_per_token( + {3, 6}, // input sizes (3 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +}