diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 64ea144fbf1..8c5246f6c0c 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -48,6 +48,16 @@ Tensor& quantize_per_token_out( ScalarType dtype, Tensor& out); +Tensor& quantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + // Wrapper function for quantize_per_tensor_out without context Tensor& quantize_per_tensor_out_no_context( const Tensor& input, @@ -74,6 +84,20 @@ Tensor& quantize_per_token_out_no_context( input, scale, zero_point, quant_min, quant_max, dtype, out); } +// Wrapper function for quantize_per_channel_out without context +Tensor& quantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_channel_out( + input, scale, zero_point, axis, quant_min, quant_max, dtype, out); +} + // ATen wrapper for quantize_per_tensor at::Tensor quantize_per_tensor_aten( const at::Tensor& input, @@ -106,6 +130,23 @@ at::Tensor quantize_per_token_aten( return out; } +// ATen wrapper for quantize_per_channel +at::Tensor quantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7) + (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -160,6 +201,40 @@ void check_quantize_args( quant_max); } +/** + * Helper function to validate quantize_per_channel arguments + * Similar to the validation in op_quantize.cpp + */ +void check_quantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -271,6 +346,110 @@ at::Tensor quantize_per_token_reference_impl( return out; } +/* + * Reference implementation of quantize_per_channel + */ +at::Tensor quantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const float* input_data = input.const_data_ptr(); + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = zero_point.const_data_ptr(); + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = zero_point_data[channel_idx]; + + // Get the input value + float input_value = input_data[flat_idx]; + + // Apply quantization formula: round(input / scale) + zero_point + float inv_scale = 1.0f / static_cast(channel_scale); + int64_t quantized_value = static_cast( + static_cast(channel_zero_point) + + std::nearbyint(static_cast(inv_scale * input_value))); + + // Clamp to quantization bounds + quantized_value = std::max(quantized_value, quant_min); + quantized_value = std::min(quantized_value, quant_max); + + // Store the result based on output dtype + switch (dtype) { + case at::kByte: { + uint8_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kChar: { + int8_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kShort: { + int16_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kInt: { + int32_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + default: + assert(false && "Unsupported output dtype"); + } + } + + return output; +} + // Forward declaration of implementation functions void test_vulkan_quantize_per_tensor_impl( const std::vector& input_sizes, @@ -513,7 +692,10 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::allclose(reference_int, vk_int); + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -889,7 +1071,10 @@ void test_vulkan_quantize_per_token_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::allclose(reference_int, vk_int); + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -924,7 +1109,7 @@ void test_vulkan_quantize_per_token_impl( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_float_to_int8) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -940,7 +1125,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_float_to_int32) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -956,7 +1141,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_half_to_int32) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -972,7 +1157,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_half_to_uint8) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -988,7 +1173,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_uint8) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1009,9 +1194,7 @@ TEST( at::kByte); } -TEST( - VulkanQuantizePerTensorTest, - test_vulkan_quantize_per_token_float_to_int8) { +TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() ->has_full_int8_buffers_support()) { @@ -1032,7 +1215,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int32) { std::vector scales = { -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; @@ -1049,7 +1232,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int32_small_scales) { std::vector scales = { 0, @@ -1070,7 +1253,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1095,7 +1278,7 @@ TEST( at::kByte); } -TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { +TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() ->has_full_float16_buffers_support()) { @@ -1115,7 +1298,7 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_double_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1134,3 +1317,189 @@ TEST( at::kDouble, // input dtype at::kChar); // output dtype } + +void test_reference_quantize_per_channel( + const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = quantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + // Get direct ATen implementation output + c10::ScalarType aten_dtype = dtype; + if (dtype == at::kChar) { + aten_dtype = c10::kQInt8; + } else if (dtype == at::kByte) { + aten_dtype = c10::kQUInt8; + } + + // Normalize axis for ATen (it doesn't handle negative values) + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + at::Tensor aten_ref = at::quantize_per_channel( + input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor my_ref_int = my_ref.to(at::kInt); + at::Tensor cpu_ref_int = cpu_ref.to(at::kInt); + // For quantized tensors, we need to use int_repr() to get the underlying + // integer values + at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt); + + const bool output_correct = at::equal(my_ref_int, cpu_ref_int); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "aten_ref:" << std::endl; + std::cout << aten_ref_int << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref_int << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_quantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +}