@@ -100,7 +100,15 @@ Tensor& dequantize_per_channel_out_no_context(
100100 executorch::aten::optional<ScalarType> out_dtype,
101101 Tensor& out) {
102102 return torch::executor::native::dequantize_per_channel_out (
103- input, scale, zero_points, axis, quant_min, quant_max, dtype, out_dtype, out);
103+ input,
104+ scale,
105+ zero_points,
106+ axis,
107+ quant_min,
108+ quant_max,
109+ dtype,
110+ out_dtype,
111+ out);
104112}
105113
106114// ATen wrapper for dequantize_per_tensor
@@ -480,7 +488,8 @@ at::Tensor dequantize_per_channel_reference_impl(
480488 }
481489
482490 // Store casted values to avoid repeated casting
483- const int32_t channel_zero_point_int32 = static_cast <int32_t >(channel_zero_point);
491+ const int32_t channel_zero_point_int32 =
492+ static_cast <int32_t >(channel_zero_point);
484493 const float channel_scale_float = static_cast <float >(channel_scale);
485494
486495 // Get the input value and dequantize
@@ -490,19 +499,24 @@ at::Tensor dequantize_per_channel_reference_impl(
490499 // Following the CPU implementation pattern: (input - zero_point) * scale
491500 if (dtype == at::kByte ) {
492501 uint8_t qvalue = input.flatten ()[flat_idx].item <uint8_t >();
493- dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
502+ dequantized_value =
503+ (qvalue - channel_zero_point_int32) * channel_scale_float;
494504 } else if (dtype == at::kChar ) {
495505 int8_t qvalue = input.flatten ()[flat_idx].item <int8_t >();
496- dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
506+ dequantized_value =
507+ (qvalue - channel_zero_point_int32) * channel_scale_float;
497508 } else if (dtype == at::kShort ) {
498509 int16_t qvalue = input.flatten ()[flat_idx].item <int16_t >();
499- dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
510+ dequantized_value =
511+ (qvalue - channel_zero_point_int32) * channel_scale_float;
500512 } else if (dtype == at::kInt ) {
501513 int32_t qvalue = input.flatten ()[flat_idx].item <int32_t >();
502- dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
514+ dequantized_value =
515+ (qvalue - channel_zero_point_int32) * channel_scale_float;
503516 } else if (dtype == at::kLong ) {
504517 int64_t qvalue = input.flatten ()[flat_idx].item <int64_t >();
505- dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float;
518+ dequantized_value =
519+ (qvalue - channel_zero_point_int32) * channel_scale_float;
506520 } else {
507521 throw std::runtime_error (" Unsupported input dtype" );
508522 }
@@ -878,7 +892,8 @@ void test_vulkan_dequantize_per_tensor_impl(
878892 output_correct =
879893 at::allclose (reference_out, vk_out, /* rtol=*/ 1e-2 , /* atol=*/ 1e-2 );
880894 } else {
881- output_correct = at::allclose (reference_out, vk_out, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 );
895+ output_correct =
896+ at::allclose (reference_out, vk_out, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 );
882897 }
883898 if (!output_correct) {
884899 std::cout << " \n "
@@ -1358,7 +1373,8 @@ void test_vulkan_dequantize_per_token_impl(
13581373 output_correct =
13591374 at::allclose (reference_out, vk_out, /* rtol=*/ 1e-2 , /* atol=*/ 1e-2 );
13601375 } else {
1361- output_correct = at::allclose (reference_out, vk_out, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 );
1376+ output_correct =
1377+ at::allclose (reference_out, vk_out, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 );
13621378 }
13631379 if (!output_correct) {
13641380 std::cout << " \n "
@@ -1737,16 +1753,21 @@ void test_vulkan_dequantize_per_channel_impl(
17371753 check_dequantize_args (quant_min, quant_max, dtype, out_dtype);
17381754 check_dequantize_per_channel_args (input_sizes, scales, zero_points, axis);
17391755
1740- std::vector<int64_t > input_sizes_int64 (input_sizes.begin (), input_sizes.end ());
1756+ std::vector<int64_t > input_sizes_int64 (
1757+ input_sizes.begin (), input_sizes.end ());
17411758
17421759 // Create random float tensor
1743- at::Tensor float_x = at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
1760+ at::Tensor float_x =
1761+ at::rand (input_sizes_int64, at::device (at::kCPU ).dtype (at::kFloat ));
17441762
17451763 // Create scale and zero_point tensors
1746- at::Tensor scale_tensor = at::tensor (scales, at::device (at::kCPU ).dtype (at::kFloat ));
1747- at::Tensor zero_point_tensor = at::tensor (zero_points, at::device (at::kCPU ).dtype (at::kInt ));
1764+ at::Tensor scale_tensor =
1765+ at::tensor (scales, at::device (at::kCPU ).dtype (at::kFloat ));
1766+ at::Tensor zero_point_tensor =
1767+ at::tensor (zero_points, at::device (at::kCPU ).dtype (at::kInt ));
17481768
1749- // Map the dtype to the corresponding quantized type and quantize the float tensor
1769+ // Map the dtype to the corresponding quantized type and quantize the float
1770+ // tensor
17501771 c10::ScalarType qtype;
17511772 at::Tensor adjusted_zero_points = zero_point_tensor;
17521773
@@ -1764,36 +1785,35 @@ void test_vulkan_dequantize_per_channel_impl(
17641785 qtype = c10::kQInt32 ;
17651786 }
17661787
1767- // Normalize axis for ATen (ATen doesn't handle negative axes in quantize_per_channel)
1788+ // Normalize axis for ATen (ATen doesn't handle negative axes in
1789+ // quantize_per_channel)
17681790 int64_t normalized_axis = axis;
17691791 if (normalized_axis < 0 ) {
17701792 normalized_axis += input_sizes_int64.size ();
17711793 }
17721794
17731795 // Quantize using ATen
17741796 at::Tensor quantized_aten = at::quantize_per_channel (
1775- float_x,
1776- scale_tensor,
1777- adjusted_zero_points,
1778- normalized_axis,
1779- qtype);
1797+ float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype);
17801798
17811799 // Get ATen dequantized output
17821800 at::Tensor aten_out = at::dequantize (quantized_aten).to (out_dtype);
17831801
17841802 // Extract the quantized values (int_repr) to use with our implementations
17851803 at::Tensor quantized_input = quantized_aten.int_repr ().to (dtype);
17861804
1787- // Get reference output using torch::executor::native::dequantize_per_channel_aten
1788- at::Tensor reference_out = torch::executor::native::dequantize_per_channel_aten (
1789- quantized_input,
1790- scale_tensor.to (at::kDouble ),
1791- zero_point_tensor.to (at::kLong ),
1792- axis,
1793- quant_min,
1794- quant_max,
1795- dtype,
1796- out_dtype);
1805+ // Get reference output using
1806+ // torch::executor::native::dequantize_per_channel_aten
1807+ at::Tensor reference_out =
1808+ torch::executor::native::dequantize_per_channel_aten (
1809+ quantized_input,
1810+ scale_tensor.to (at::kDouble ),
1811+ zero_point_tensor.to (at::kLong ),
1812+ axis,
1813+ quant_min,
1814+ quant_max,
1815+ dtype,
1816+ out_dtype);
17971817
17981818 // Build Vulkan dequantize_per_channel graph
17991819 using namespace vkcompute ;
@@ -1828,8 +1848,10 @@ void test_vulkan_dequantize_per_channel_impl(
18281848 const ValueRef r_axis = graph.add_scalar <int64_t >(axis);
18291849 const ValueRef r_quant_min = graph.add_scalar <int64_t >(quant_min);
18301850 const ValueRef r_quant_max = graph.add_scalar <int64_t >(quant_max);
1831- const ValueRef r_dtype = graph.add_scalar <int64_t >(static_cast <int64_t >(dtype));
1832- const ValueRef r_output_dtype = graph.add_scalar <int64_t >(static_cast <int64_t >(out_dtype));
1851+ const ValueRef r_dtype =
1852+ graph.add_scalar <int64_t >(static_cast <int64_t >(dtype));
1853+ const ValueRef r_output_dtype =
1854+ graph.add_scalar <int64_t >(static_cast <int64_t >(out_dtype));
18331855
18341856 VK_GET_OP_FN (" quantized_decomposed.dequantize_per_channel.default" )
18351857 (graph,
@@ -1854,7 +1876,9 @@ void test_vulkan_dequantize_per_channel_impl(
18541876
18551877 // Copy input data to GPU
18561878 graph.copy_into_staging (
1857- r_input.staging , quantized_input.const_data_ptr (), quantized_input.numel ());
1879+ r_input.staging ,
1880+ quantized_input.const_data_ptr (),
1881+ quantized_input.numel ());
18581882
18591883 // copy scale tensor to GPU
18601884 graph.copy_into_staging (
@@ -1881,7 +1905,8 @@ void test_vulkan_dequantize_per_channel_impl(
18811905 output_correct =
18821906 at::allclose (reference_out, vk_out, /* rtol=*/ 1e-2 , /* atol=*/ 1e-2 );
18831907 } else {
1884- output_correct = at::allclose (reference_out, vk_out, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 );
1908+ output_correct =
1909+ at::allclose (reference_out, vk_out, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 );
18851910 }
18861911 if (!output_correct) {
18871912 std::cout << " \n "
@@ -1992,7 +2017,9 @@ TEST(
19922017
19932018// END OF REFERENCE TESTS
19942019
1995- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis0) {
2020+ TEST (
2021+ VulkanDequantizePerChannelTest,
2022+ test_vulkan_dequantize_per_channel_int8_to_float_axis0) {
19962023 std::vector<float > scales (9 , 0 .1f );
19972024 std::vector<int > zero_points (9 , 2 );
19982025
@@ -2052,7 +2079,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
20522079 at::kFloat );
20532080}
20542081
2055- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis1) {
2082+ TEST (
2083+ VulkanDequantizePerChannelTest,
2084+ test_vulkan_dequantize_per_channel_int8_to_float_axis1) {
20562085 std::vector<float > scales (14 , 0 .001f );
20572086 std::vector<int > zero_points (14 , -5 );
20582087
@@ -2101,7 +2130,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
21012130 at::kFloat );
21022131}
21032132
2104- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis2) {
2133+ TEST (
2134+ VulkanDequantizePerChannelTest,
2135+ test_vulkan_dequantize_per_channel_int8_to_float_axis2) {
21052136 std::vector<float > scales (11 , 0 .5f );
21062137 std::vector<int > zero_points (11 , 12 );
21072138
@@ -2139,7 +2170,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
21392170 at::kFloat );
21402171}
21412172
2142- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_float_axis3) {
2173+ TEST (
2174+ VulkanDequantizePerChannelTest,
2175+ test_vulkan_dequantize_per_channel_int8_to_float_axis3) {
21432176 std::vector<float > scales (7 , 0 .5f );
21442177 std::vector<int > zero_points (7 , 12 );
21452178
@@ -2166,7 +2199,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_int8_to_
21662199 at::kFloat );
21672200}
21682201
2169- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) {
2202+ TEST (
2203+ VulkanDequantizePerChannelTest,
2204+ test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) {
21702205 std::vector<float > scales = {0.1 , 0.2 , 0.0001 , 0.5 , 0.02 };
21712206 std::vector<int > zero_points = {0 , 5 , -5 , 1 , 12 };
21722207
@@ -2226,7 +2261,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_uint8_to
22262261 at::kFloat );
22272262}
22282263
2229- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_half) {
2264+ TEST (
2265+ VulkanDequantizePerChannelTest,
2266+ test_vulkan_dequantize_per_channel_8bit_to_half) {
22302267 std::vector<float > scales = {0.1 , 0.2 , 0.01 , 0.5 , 0.02 };
22312268 std::vector<int > zero_points = {0 , 5 , 5 , 1 , 12 };
22322269
@@ -2286,7 +2323,9 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_
22862323 at::kHalf );
22872324}
22882325
2289- TEST (VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_double) {
2326+ TEST (
2327+ VulkanDequantizePerChannelTest,
2328+ test_vulkan_dequantize_per_channel_8bit_to_double) {
22902329 std::vector<float > scales = {0.1 , 0.2 , 0.01 , 0.5 , 0.02 };
22912330 std::vector<int > zero_points = {0 , 5 , 5 , 1 , 12 };
22922331
0 commit comments