diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
index 2a1f62719a0..faafa3fd266 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
@@ -42,6 +42,16 @@ $if MODE == "per_token":
     int quant_min;
     int quant_max;
   };
+$if MODE == "per_channel":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
+  layout(push_constant) uniform restrict Block {
+    int axis;
+    int num_channels;
+    int quant_min;
+    int quant_max;
+  };
 
 ${layout_declare_ubo(B, "int", "out_numel")}
 ${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
@@ -141,7 +151,7 @@ void dequantize_per_tensor() {
   t_out[out_bufi] = value;
 }
 
-#else
+#elif defined(per_token)
 
 void dequantize_per_token() {
   const int out_bufi = int(gl_GlobalInvocationID.x);
@@ -176,6 +186,45 @@ void dequantize_per_token() {
   t_out[out_bufi] = value;
 }
 
+#else // per_channel
+
+void dequantize_per_channel() {
+  const int out_bufi = int(gl_GlobalInvocationID.x);
+
+  if (out_bufi >= out_numel) {
+    return;
+  }
+
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
+  const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
+
+  IN_T qvalue = t_in[in_bufi];
+
+  // Select the channel index based on the dequantization axis, which has
+  // already been converted to the WHCN coordinate system:
+  //   axis 0 -> W dimension (out_tidx.x)
+  //   axis 1 -> H dimension (out_tidx.y)
+  //   axis 2 -> C dimension (out_tidx.z)
+  //   axis 3 -> N dimension (out_tidx.w)
+  int channel_idx = 0;
+
+  if (axis == 0) {
+    channel_idx = out_tidx.x;
+  } else if (axis == 1) {
+    channel_idx = out_tidx.y;
+  } else if (axis == 2) {
+    channel_idx = out_tidx.z;
+  } else if (axis == 3) {
+    channel_idx = out_tidx.w;
+  }
+
+  channel_idx = min(channel_idx, num_channels - 1);
+
+  OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]);
+
+  t_out[out_bufi] = value;
+}
+
 #endif
 
 void main() {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml
index fb0d2ee61bf..b9a53217452 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml
@@ -17,3 +17,5 @@ dequantize_buffer:
       MODE: per_tensor
     - NAME: dequantize_per_token_buffer
       MODE: per_token
+    - NAME: dequantize_per_channel_buffer
+      MODE: per_channel
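For intuition, here is a minimal CPU sketch of what the buffer-mode shader above computes per element, assuming dequantize_val applies the standard affine mapping (q - zero_point) * scale used by the quantized_decomposed ops; the function name, contiguous NCHW layout, and int8 input type are illustrative assumptions, not part of the patch.

    #include <cstdint>
    #include <vector>

    // Reference per-channel dequantization over a contiguous NCHW buffer.
    std::vector<float> dequantize_per_channel_ref(
        const std::vector<int8_t>& q,      // quantized values
        const std::vector<float>& scale,   // one entry per channel along `axis`
        const std::vector<int32_t>& zp,    // one entry per channel along `axis`
        const std::vector<int64_t>& sizes, // e.g. {N, C, H, W}
        int axis) {
      // Strides of a contiguous tensor, innermost dimension last.
      std::vector<int64_t> strides(sizes.size(), 1);
      for (int d = static_cast<int>(sizes.size()) - 2; d >= 0; --d) {
        strides[d] = strides[d + 1] * sizes[d + 1];
      }
      std::vector<float> out(q.size());
      for (int64_t i = 0; i < static_cast<int64_t>(q.size()); ++i) {
        // Recover the index along `axis` from the flat index; the shader does
        // the same via bufi_to_tidx before selecting channel_idx.
        const int64_t c = (i / strides[axis]) % sizes[axis];
        out[i] = (static_cast<int32_t>(q[i]) - zp[c]) * scale[c];
      }
      return out;
    }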
diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
index 801f4a2f6a2..ef3f5cca869 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
@@ -45,6 +45,16 @@ $if MODE == "per_token":
     int quant_min;
     int quant_max;
   };
+$if MODE == "per_channel":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
+  layout(push_constant) uniform restrict Block {
+    int axis;
+    int num_channels;
+    int quant_min;
+    int quant_max;
+  };
 
 ${layout_declare_ubo(B, "ivec3", "t_in_limits")}
 ${layout_declare_ubo(B, "ivec3", "t_out_limits")}
@@ -147,7 +157,7 @@ void dequantize_per_tensor() {
   write_texel(t_out, pos, outtex);
 }
 
-#else
+#elif defined(per_token)
 
 void dequantize_per_token() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -189,6 +199,97 @@ void dequantize_per_token() {
   write_texel(t_out, pos, outtex);
 }
 
+#else // per_channel
+
+void dequantize_per_channel() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, t_in_limits))) {
+    return;
+  }
+
+  IVEC4_T intex = load_texel(t_in, pos);
+  FVEC4_T outtex;
+
+  // Select the channel index based on the dequantization axis, which has
+  // already been converted to the WHCN coordinate system:
+  //   axis 0 -> W dimension (pos.x)
+  //   axis 1 -> H dimension (pos.y)
+  //   axis 2 -> C dimension (pos.z)
+  //   axis 3 -> N dimension (folded into pos.z in texture storage)
+
+  if (axis == 0) {
+    // Width dimension - each texel component has a different channel index
+    [[unroll]] for (int i = 0; i < 4; ++i) {
+      IN_T qvalue = IN_T(intex[i]);
+      int channel_idx = pos.x * 4 + i;
+      channel_idx = min(channel_idx, num_channels - 1);
+
+      float scale_val = t_scale[channel_idx];
+      int zero_point_val = t_zero_point[channel_idx];
+      OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
+      $if OUT_DTYPE == "double":
+        outtex[i] = float(value);
+      $else:
+        outtex[i] = value;
+    }
+  } else if (axis == 1) {
+    int channel_idx = pos.y;
+    channel_idx = min(channel_idx, num_channels - 1);
+    float scale_val = t_scale[channel_idx];
+    int zero_point_val = t_zero_point[channel_idx];
+
+    [[unroll]] for (int i = 0; i < 4; ++i) {
+      IN_T qvalue = IN_T(intex[i]);
+      OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
+      $if OUT_DTYPE == "double":
+        outtex[i] = float(value);
+      $else:
+        outtex[i] = value;
+    }
+  } else if (axis == 2) {
+    // Channel dimension - for 4D tensors the Z coordinate holds the folded
+    // batch*channel index, so extract the channel component from it
+    int folded_idx = pos.z;
+    int channel_idx = folded_idx % num_channels;
+
+    float scale_val = t_scale[channel_idx];
+    int zero_point_val = t_zero_point[channel_idx];
+
+    [[unroll]] for (int i = 0; i < 4; ++i) {
+      IN_T qvalue = IN_T(intex[i]);
+      OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
+      $if OUT_DTYPE == "double":
+        outtex[i] = float(value);
+      $else:
+        outtex[i] = value;
+    }
+  } else if (axis == 3) {
+    // Batch dimension - for 4D tensors the Z coordinate holds the folded
+    // batch*channel index. Here num_channels is the size of the C dimension
+    // of the N(C)HW tensor, so dividing the folded index by it recovers the
+    // batch index, which selects the per-batch scale and zero point
+    int folded_idx = pos.z;
+    int channel_idx = folded_idx / num_channels;
+
+    float scale_val = t_scale[channel_idx];
+    int zero_point_val = t_zero_point[channel_idx];
+
+    [[unroll]] for (int i = 0; i < 4; ++i) {
+      IN_T qvalue = IN_T(intex[i]);
+      OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
+      $if OUT_DTYPE == "double":
+        outtex[i] = float(value);
+      $else:
+        outtex[i] = value;
+    }
+  }
+
+  write_texel(t_out, pos, outtex);
+}
+
 #endif
 
 void main() {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml
index 7d19a543a03..88ccc6e3274 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml
@@ -17,3 +17,5 @@ dequantize_texture:
       MODE: per_tensor
     - NAME: dequantize_per_token_texture3d
       MODE: per_token
+    - NAME: dequantize_per_channel_texture3d
+      MODE: per_channel
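The texture variant has to cope with two layout quirks that the buffer variant does not: along W, four consecutive elements share one texel, and for 4D tensors the N and C dimensions are folded into the texture's Z extent (z = n * C + c). The index arithmetic is small enough to check on the CPU; this standalone sketch uses illustrative names (pos_z, C) rather than identifiers from the patch.

    #include <cstdio>
    #include <initializer_list>

    int main() {
      const int C = 17; // size of the C dimension (num_channels in the shader)
      // Unfolding the Z coordinate of a texel into (batch, channel):
      for (int pos_z : {0, 16, 17, 35}) {
        int channel = pos_z % C; // what the shader uses when axis == 2 (C)
        int batch = pos_z / C;   // what the shader uses when axis == 3 (N)
        std::printf("z=%2d -> n=%d, c=%2d\n", pos_z, batch, channel);
      }
      // Width packing: the texel at x holds elements 4*x .. 4*x+3, so texel
      // component i maps to channel index 4*pos_x + i when axis == 0 (W).
      const int pos_x = 2, i = 3;
      std::printf("x=%d, component=%d -> w=%d\n", pos_x, i, 4 * pos_x + i);
      return 0;
    }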
diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
index 3838da9a151..8845d6f6254 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
@@ -13,11 +13,10 @@
 #include
 #include
 #include
+#include
 
 namespace vkcompute {
 
-namespace {
-
 void resize_dequantize_output(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
@@ -28,7 +27,50 @@ void resize_dequantize_output(
   graph->virtual_resize(out, graph->sizes_of(in));
 }
 
-} // namespace
+utils::uvec3 dequantize_per_channel_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+  const ValueRef out = args.at(0).refs.at(0);
+
+  utils::uvec3 global_wg_size = graph->create_global_wg_size(out);
+
+  return global_wg_size;
+}
+
+utils::uvec3 dequantize_per_channel_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+
+  const ValueRef input = args.at(1).refs.at(0);
+
+  utils::uvec3 local_wg_size =
+      graph->create_local_wg_size(global_workgroup_size);
+
+  // WORKAROUND: The CommandBuffer::dispatch function divides
+  // global_workgroup_size by local_workgroup_size to get the number of
+  // workgroups to dispatch. For per-channel dequantization along the batch
+  // axis, we need to ensure that we dispatch the correct number of workgroups
+  // in the Z dimension to cover all batch-channel combinations.
+  //
+  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
+  // local_wg_size[2]) might reduce the number of workgroups dispatched. To
+  // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
+  // we set local_wg_size[2] = 1.
+  const auto input_sizes = graph->sizes_of(input);
+  if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
+    local_wg_size[2] = 1;
+  }
+
+  return local_wg_size;
+}
 
 void add_dequantize_per_tensor_node(
     ComputeGraph& graph,
     const ValueRef& input,
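The workaround above is about the group-count arithmetic applied at dispatch time. A standalone illustration of that arithmetic follows; div_up here is an illustrative ceiling-division helper, not the runtime's actual implementation. With ceiling division the grid never launches fewer threads than the global size, but pinning local_wg_size[2] to 1 makes the thread count in Z exactly equal to global_workgroup_size[2], which is the invariant the batch-axis path relies on.

    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    // Ceiling division, as described in the WORKAROUND comment above.
    uint32_t div_up(uint32_t n, uint32_t d) {
      return (n + d - 1) / d;
    }

    int main() {
      const uint32_t global_z = 6; // e.g. a folded N*C texel extent
      for (uint32_t local_z : {4u, 1u}) {
        const uint32_t groups = div_up(global_z, local_z);
        // Threads actually launched in Z = workgroups * local size.
        std::printf(
            "local_z=%u -> %u groups -> %u threads in Z\n",
            local_z, groups, groups * local_z);
      }
      // local_z=4 launches 8 threads (2 masked off by bounds checks);
      // local_z=1 launches exactly global_z threads.
      return 0;
    }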
@@ -171,6 +213,99 @@ void add_dequantize_per_token_node(
       resize_dequantize_output));
 }
 
+void add_dequantize_per_channel_node(
+    ComputeGraph& graph,
+    const ValueRef& input,
+    const ValueRef& scale,
+    const ValueRef& zero_point,
+    const ValueRef& axis,
+    const ValueRef& quant_min,
+    const ValueRef& quant_max,
+    const ValueRef& output) {
+  std::string kernel_name("dequantize_per_channel");
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(output));
+
+  int axis_val = static_cast<int>(graph.get_int(axis));
+  int quant_min_val = static_cast<int>(graph.get_int(quant_min));
+  int quant_max_val = static_cast<int>(graph.get_int(quant_max));
+
+  // Normalize the axis and convert it from NCHW to WHCN using utility
+  // functions
+  const auto input_sizes = graph.sizes_of(input);
+  const int64_t ndim = graph.dim_of(input);
+
+  // Normalize the axis to handle negative indices
+  axis_val = normalize(axis_val, ndim);
+
+  // Convert the NCHW axis to a WHCN axis for the shader (Vulkan
+  // representation)
+  int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim);
+
+  int num_channels;
+  if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
+    // For batch-axis dequantization of 4D tensors with texture storage, pass
+    // the size of the channel dimension so the shader can correctly unfold
+    // the batch-channel folding in the Z coordinate
+    num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
+  } else {
+    num_channels = static_cast<int>(input_sizes[axis_val]);
+  }
+
+  vkapi::ParamsBindList param_ubos;
+  std::vector<PushConstantDataInfo> push_constants;
+
+  if (graph.is_buffer_storage(input)) {
+    param_ubos = {
+        graph.numel_ubo(input),
+        graph.sizes_ubo(input),
+        graph.strides_ubo(input),
+        graph.sizes_ubo(output),
+        graph.strides_ubo(output),
+    };
+    push_constants = {
+        PushConstantDataInfo(&axis_whcn, sizeof(int)),
+        PushConstantDataInfo(&num_channels, sizeof(int)),
+        PushConstantDataInfo(&quant_min_val, sizeof(int)),
+        PushConstantDataInfo(&quant_max_val, sizeof(int)),
+    };
+  } else {
+    param_ubos = {
+        graph.logical_limits_ubo(input),
+        graph.logical_limits_ubo(output),
+    };
+    push_constants = {
+        PushConstantDataInfo(&axis_whcn, sizeof(int)),
+        PushConstantDataInfo(&num_channels, sizeof(int)),
+        PushConstantDataInfo(&quant_min_val, sizeof(int)),
+        PushConstantDataInfo(&quant_max_val, sizeof(int)),
+    };
+  }
+
+  vkapi::SpecVarList spec_vars = {
+      graph.hashed_layout_of(output),
+      graph.hashed_layout_of(input),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      dequantize_per_channel_global_wg_size,
+      dequantize_per_channel_local_wg_size,
+      // Inputs and Outputs
+      {{output, vkapi::kWrite},
+       {input, vkapi::kRead},
+       {{scale, zero_point}, vkapi::kRead}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      spec_vars,
+      // Resize Args
+      {},
+      // Resizing Logic
+      resize_dequantize_output));
+}
+
 void dequantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
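Note that the four push constants are written in the same order as the members of the GLSL Block declared in both shaders (axis, num_channels, quant_min, quant_max), so the 16-byte layouts line up. A hypothetical struct equivalent, purely for illustration (the patch pushes the four ints individually via PushConstantDataInfo):

    #include <cstdint>

    struct DequantPerChannelPushConstants {
      int32_t axis;         // dequantization axis, already converted to WHCN
      int32_t num_channels; // number of entries in t_scale / t_zero_point
      int32_t quant_min;    // carried for symmetry with the quantize shaders;
      int32_t quant_max;    // not read by dequantize_val in the code above
    };

    static_assert(
        sizeof(DequantPerChannelPushConstants) == 16,
        "must match the push constant Block declared in the shaders");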
@@ -292,6 +427,94 @@ void dequantize_per_token_impl(
       graph, input, scale, zero_point, quant_min, quant_max, output);
 }
 
+void dequantize_per_channel_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef axis = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
+  const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter
+  const ValueRef output = args[arg_idx++];
+
+  // Suppress unused variable warnings - dtype and output_dtype are inferred
+  // from the output tensor
+  (void)dtype;
+  (void)output_dtype;
+
+  // Check tensor types
+  VK_CHECK_COND(graph.val_is_tensor(input));
+  VK_CHECK_COND(graph.val_is_tensor(scale));
+  VK_CHECK_COND(graph.val_is_tensor(zero_point));
+  VK_CHECK_COND(graph.val_is_tensor(output));
+
+  // Verify that the input is an integer type
+  VK_CHECK_COND(
+      graph.dtype_of(input) == vkapi::kByte ||
+      graph.dtype_of(input) == vkapi::kChar ||
+      graph.dtype_of(input) == vkapi::kShort ||
+      graph.dtype_of(input) == vkapi::kInt);
+
+  // Verify that the output is a floating point type
+  VK_CHECK_COND(
+      graph.dtype_of(output) == vkapi::kHalf ||
+      graph.dtype_of(output) == vkapi::kFloat ||
+      graph.dtype_of(output) == vkapi::kDouble);
+
+  // Check that scale and zero_point have buffer storage and width packing
+  VK_CHECK_COND(graph.is_buffer_storage(scale));
+  VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.is_buffer_storage(zero_point));
+  VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim);
+
+  // Check that tensors with texture storage have a standard axis map
+  if (!graph.is_buffer_storage(input)) {
+    VK_CHECK_COND(graph.has_standard_axis_map(input));
+  }
+  if (!graph.is_buffer_storage(output)) {
+    VK_CHECK_COND(graph.has_standard_axis_map(output));
+  }
+
+  // Normalize the axis
+  int axis_val = static_cast<int>(graph.get_int(axis));
+  const auto input_sizes = graph.sizes_of(input);
+  int ndim = graph.dim_of(input);
+  if (axis_val < 0) {
+    axis_val += ndim;
+  }
+
+  // Verify that the axis is valid
+  VK_CHECK_COND(axis_val >= 0 && axis_val < ndim);
+
+  // Get the number of channels along the specified axis
+  int64_t num_channels = input_sizes[axis_val];
+
+  const auto scale_sizes = graph.sizes_of(scale);
+  const auto zero_point_sizes = graph.sizes_of(zero_point);
+
+  // Calculate the total number of elements in the scale and zero_point
+  // tensors
+  int64_t scale_numel = 1;
+  for (size_t i = 0; i < scale_sizes.size(); i++) {
+    scale_numel *= scale_sizes[i];
+  }
+
+  int64_t zero_point_numel = 1;
+  for (size_t i = 0; i < zero_point_sizes.size(); i++) {
+    zero_point_numel *= zero_point_sizes[i];
+  }
+
+  // Check that the total number of elements matches num_channels
+  VK_CHECK_COND(scale_numel == num_channels);
+  VK_CHECK_COND(zero_point_numel == num_channels);
+
+  add_dequantize_per_channel_node(
+      graph, input, scale, zero_point, axis, quant_min, quant_max, output);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.dequantize_per_tensor.default,
@@ -299,6 +522,9 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.dequantize_per_token.default,
       dequantize_per_token_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.dequantize_per_channel.default,
+      dequantize_per_channel_impl);
 }
 
 } // namespace vkcompute
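The numel check at the end of dequantize_per_channel_impl only requires that scale and zero_point carry exactly one entry per channel along the chosen axis; their shapes are otherwise unconstrained. An equivalent standalone formulation of that check (illustrative, not code from the patch):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // True iff `scale_sizes` and `zp_sizes` each describe exactly one element
    // per channel of `input_sizes` along `axis`.
    bool params_match_axis(
        const std::vector<int64_t>& input_sizes,
        const std::vector<int64_t>& scale_sizes,
        const std::vector<int64_t>& zp_sizes,
        int axis) {
      const int64_t num_channels = input_sizes.at(axis);
      auto numel = [](const std::vector<int64_t>& s) {
        return std::accumulate(
            s.begin(), s.end(), int64_t{1}, std::multiplies<int64_t>());
      };
      return numel(scale_sizes) == num_channels &&
          numel(zp_sizes) == num_channels;
    }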
diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp
index f32a93e2b6a..cb9c04ee089 100644
--- a/backends/vulkan/test/op_tests/dequantize_test.cpp
+++ b/backends/vulkan/test/op_tests/dequantize_test.cpp
@@ -557,6 +557,18 @@ void test_vulkan_dequantize_per_token_impl(
     const vkcompute::utils::StorageType in_storage,
     const vkcompute::utils::StorageType out_storage);
 
+void test_vulkan_dequantize_per_channel_impl(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage);
+
 // Wrapper function to test both buffer and texture storage types
 void test_vulkan_dequantize_per_tensor(
     const std::vector<int>& input_sizes,
@@ -637,6 +649,49 @@ void test_vulkan_dequantize_per_token(
       vkcompute::utils::kTexture3D);
 }
 
+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_dequantize_per_channel(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  // Test with buffer storage
+  test_vulkan_dequantize_per_channel_impl(
+      input_sizes,
+      scales,
+      zero_points,
+      axis,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      vkcompute::utils::kBuffer,
+      vkcompute::utils::kBuffer);
+
+  // For texture storage, expect float output instead of double, since the
+  // texture shader can only write 32-bit floats
+  if (out_dtype == at::kDouble) {
+    out_dtype = at::kFloat;
+  }
+
+  // Test with texture storage
+  test_vulkan_dequantize_per_channel_impl(
+      input_sizes,
+      scales,
+      zero_points,
+      axis,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kTexture3D);
+}
+
 void test_reference_dequantize_per_tensor(
     const std::vector<int>& input_sizes,
     float scale,
@@ -1684,6 +1739,214 @@ void test_reference_dequantize_per_channel(
   ASSERT_TRUE(output_correct);
 }
 
+void test_vulkan_dequantize_per_channel_impl(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis);
+
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+
+  // Create a random float tensor
+  at::Tensor float_x =
+      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat));
+
+  // Create the scale and zero_point tensors
+  at::Tensor scale_tensor =
+      at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor zero_point_tensor =
+      at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt));
+
+  // Map the dtype to the corresponding quantized type and quantize the float
+  // tensor
+  c10::ScalarType qtype;
+  at::Tensor adjusted_zero_points = zero_point_tensor;
+
+  if (dtype == at::kByte) {
+    qtype = c10::kQUInt8;
+    // ATEN ONLY: Adjust zero points for unsigned types (must be non-negative)
+    adjusted_zero_points = at::clamp_min(zero_point_tensor, 0);
+  } else if (dtype == at::kChar) {
+    qtype = c10::kQInt8;
+  } else if (dtype == at::kInt) {
+    qtype = c10::kQInt32;
+  } else {
+    std::cout << "invalid dtype for ATEN: " << dtype << std::endl;
+    std::cout << " --> Delegating to c10::kQInt32" << std::endl;
+    qtype = c10::kQInt32;
+  }
+
+  // Normalize the axis for ATen (ATen doesn't handle negative axes in
+  // quantize_per_channel)
+  int64_t normalized_axis = axis;
+  if (normalized_axis < 0) {
+    normalized_axis += input_sizes_int64.size();
+  }
+
+  // Quantize using ATen
+  at::Tensor quantized_aten = at::quantize_per_channel(
+      float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype);
+
+  // Get the ATen dequantized output
+  at::Tensor aten_out = at::dequantize(quantized_aten).to(out_dtype);
+
+  // Extract the quantized values (int_repr) to use with our implementations
+  at::Tensor quantized_input = quantized_aten.int_repr().to(dtype);
+
+  // Get the reference output using
+  // torch::executor::native::dequantize_per_channel_aten
+  at::Tensor reference_out =
+      torch::executor::native::dequantize_per_channel_aten(
+          quantized_input,
+          scale_tensor.to(at::kDouble),
+          zero_point_tensor.to(at::kLong),
+          axis,
+          quant_min,
+          quant_max,
+          dtype,
+          out_dtype);
+
+  // Build the Vulkan dequantize_per_channel graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  // Add tensors to the graph
+  IOValueRef r_input = graph.add_input_tensor(
+      quantized_input.sizes().vec(),
+      from_at_scalartype(quantized_input.scalar_type()),
+      in_storage);
+
+  IOValueRef r_scale = graph.add_input_tensor(
+      scale_tensor.sizes().vec(),
+      vkapi::kFloat,
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  IOValueRef r_zero_point = graph.add_input_tensor(
+      adjusted_zero_points.sizes().vec(),
+      vkapi::kInt,
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  ValueRef r_out = graph.add_tensor(
+      quantized_input.sizes().vec(),
+      from_at_scalartype(out_dtype),
+      out_storage);
+
+  const ValueRef r_axis = graph.add_scalar<int64_t>(axis);
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+  const ValueRef r_dtype =
+      graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
+  const ValueRef r_output_dtype =
+      graph.add_scalar<int64_t>(static_cast<int64_t>(out_dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.dequantize_per_channel.default")
+  (graph,
+   {
+       r_input.value,
+       r_scale.value,
+       r_zero_point.value,
+       r_axis,
+       r_quant_min,
+       r_quant_max,
+       r_dtype,
+       r_output_dtype,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Copy the input data to the GPU
+  graph.copy_into_staging(
+      r_input.staging,
+      quantized_input.const_data_ptr(),
+      quantized_input.numel());
+
+  // Copy the scale tensor to the GPU
+  graph.copy_into_staging(
+      r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel());
+
+  // Copy the zero_point tensor to the GPU
+  graph.copy_into_staging(
+      r_zero_point.staging,
+      zero_point_tensor.const_data_ptr(),
+      zero_point_tensor.numel());
+
+  // Execute the graph
+  graph.execute();
+
+  // Copy the output data back to the CPU
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs, with a higher tolerance for half precision due to its
+  // limited precision
+  bool output_correct;
+  if (out_dtype == at::kHalf) {
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
+  } else {
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
+  }
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  axis: " << axis << std::endl;
+    std::cout << "  input sizes:";
+    for (size_t i = 0; i < input_sizes.size(); i++) {
+      std::cout << " " << input_sizes[i] << " ";
+    }
+    std::cout << std::endl;
+    std::cout << "  scale(s):";
+    for (size_t i = 0; i < scales.size(); i++) {
+      std::cout << " " << scales[i] << " ";
+    }
+    std::cout << std::endl;
+    std::cout << "  zero_point(s):";
+    for (size_t i = 0; i < zero_points.size(); i++) {
+      std::cout << " " << zero_points[i] << " ";
+    }
+    std::cout << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  input dtype: " << dtype << std::endl;
+    std::cout << "  output dtype: " << out_dtype << std::endl;
+    std::cout << "  storage: " << in_storage << std::endl;
+    std::cout << std::endl;
+
+    std::cout << "\033[91m quantized_input: \033[0m" << std::endl;
+    std::cout << quantized_input << std::endl;
+    std::cout << "\033[91m aten: \033[0m" << std::endl;
+    std::cout << aten_out << std::endl;
+    std::cout << "\033[91m reference: \033[0m" << std::endl;
+    std::cout << reference_out << std::endl;
+    std::cout << "\033[91m vulkan: \033[0m" << std::endl;
+    std::cout << vk_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
 TEST(
     VulkanDequantizePerChannelTest,
     test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) {
@@ -1751,3 +2014,413 @@ TEST(
       at::kInt,
       at::kFloat);
 }
+
+// END OF REFERENCE TESTS
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_int8_to_float_axis0) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales(9, 0.1f);
+  std::vector<int> zero_points(9, 2);
+
+  // 1D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 2D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 3D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 7, 11}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 17, 5, 5}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {5, 17, 5, 9}, // input sizes
+      scales,
+      zero_points,
+      -1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+}
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_int8_to_float_axis1) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales(14, 0.001f);
+  std::vector<int> zero_points(14, -5);
+
+  // 2D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14}, // input sizes
+      scales,
+      zero_points,
+      1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 3D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11}, // input sizes
+      scales,
+      zero_points,
+      1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 5, 5}, // input sizes
+      scales,
+      zero_points,
+      1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {9, 7, 14, 5}, // input sizes
+      scales,
+      zero_points,
+      -2, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+}
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_int8_to_float_axis2) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales(11, 0.5f);
+  std::vector<int> zero_points(11, 12);
+
+  // 3D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11}, // input sizes
+      scales,
+      zero_points,
+      2, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11, 5}, // input sizes
+      scales,
+      zero_points,
+      2, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {9, 11, 14, 5}, // input sizes
+      scales,
+      zero_points,
+      -3, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+}
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_int8_to_float_axis3) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales(7, 0.5f);
+  std::vector<int> zero_points(7, 12);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      3, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {7, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      -4, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kFloat);
+}
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales = {0.1, 0.2, 0.0001, 0.5, 0.02};
+  std::vector<int> zero_points = {0, 5, -5, 1, 12};
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {5, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kFloat);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 5, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      1, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kFloat);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 5, 7}, // input sizes
+      scales,
+      zero_points,
+      2, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kFloat);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11, 5}, // input sizes
+      scales,
+      zero_points,
+      3, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kFloat);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {5, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      -4, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kFloat);
+}
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_8bit_to_half) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_float16_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
+  std::vector<int> zero_points = {0, 5, 5, 1, 12};
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {5, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kHalf);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 5, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kHalf);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 5, 7}, // input sizes
+      scales,
+      zero_points,
+      2, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kHalf);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11, 5}, // input sizes
+      scales,
+      zero_points,
+      3, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kHalf);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {5, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      -4, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kHalf);
+}
+
+TEST(
+    VulkanDequantizePerChannelTest,
+    test_vulkan_dequantize_per_channel_8bit_to_double) {
+  if (!vkcompute::api::context()
+           ->adapter_ptr()
+           ->has_full_int8_buffers_support()) {
+    GTEST_SKIP();
+  }
+  std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
+  std::vector<int> zero_points = {0, 5, 5, 1, 12};
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {5, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kDouble);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 5, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kDouble);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 5, 7}, // input sizes
+      scales,
+      zero_points,
+      2, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kDouble);
+
+  // 4D Tensor
+  test_vulkan_dequantize_per_channel(
+      {9, 14, 11, 5}, // input sizes
+      scales,
+      zero_points,
+      3, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kChar,
+      at::kDouble);
+
+  // 4D Tensor (negative axis)
+  test_vulkan_dequantize_per_channel(
+      {5, 14, 11, 7}, // input sizes
+      scales,
+      zero_points,
+      -4, // axis
+      0, // quant_min
+      255, // quant_max
+      at::kByte,
+      at::kDouble);
+}