From 9a8b16e4c940bd6f7b9658da213e2ed687eb0699 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:16 -0700 Subject: [PATCH 1/8] [ET-VK][Ops] aligning Q/DQ/CQP op inputs with ATen impl Pull Request resolved: https://github.com/pytorch/executorch/pull/12199 # Context A few operators have been recently created, namely: - quantize_per_tensor - quantize_per_token - dequantize_per_tensor - dequantize_per_token - choose_qparams.tensor - choose_qparams_per_token_asymmetric They don't have a namespace associated with them, and since we are trying to align with the ATen implementation in their respective quantized_decomposed namespace, this diff is necessary to align in that regard. Furthermore, our operators need to match inputs with the ATen version, so we also pass dtypes. # Changes The primary change is adding the namespace quantized_decomposed to all the above named operators. Furthermore, we change the testing framework to pass dummy dtypes that is expected for the ATen implementation. We also change the `choose_qparams` logic to properly pass the eps, since this is actually a relevant variable and cannot be set by default, despite the existing op_quantize cpu reference in executorch not distinctly using this variable. ghstack-source-id: 295972783 @exported-using-ghexport Differential Revision: [D77746144](https://our.internmc.facebook.com/intern/diff/D77746144/) --- .../graph/ops/glsl/choose_qparams.glslh | 16 ++--- .../graph/ops/glsl/choose_qparams_buffer.glsl | 5 +- .../ops/glsl/choose_qparams_texture.glsl | 5 +- .../runtime/graph/ops/impl/ChooseQParams.cpp | 67 +++++++++++++------ .../runtime/graph/ops/impl/Dequantize.cpp | 42 ++++++++++-- .../runtime/graph/ops/impl/Quantize.cpp | 35 ++++++++-- .../test/op_tests/choose_qparams_test.cpp | 29 ++++++-- .../vulkan/test/op_tests/dequantize_test.cpp | 14 +++- .../vulkan/test/op_tests/quantize_test.cpp | 12 +++- 9 files changed, 169 insertions(+), 56 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index 66620e9b174..d6d27d2e3a3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,15 +9,13 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// equivalent of the eps defined in the cpu implementation -#define SMALL_SCALE_THRESHOLD 6.1e-5 - // Calculate scale and zero point from min and max values void calculate_scale_and_zero_point( float min_val, float max_val, int qmin, int qmax, + float eps_threshold, out float scale_val, out int zero_point_val) { // ensure we have zero included in our range @@ -31,18 +29,18 @@ void calculate_scale_and_zero_point( scale_val = 0.1; } - // Cut off small scale - if (scale_val < SMALL_SCALE_THRESHOLD) { + // Cut off small scale using the provided eps threshold + if (scale_val < eps_threshold) { float org_scale = scale_val; - scale_val = SMALL_SCALE_THRESHOLD; + scale_val = eps_threshold; // Adjust min and max based on new scale if (min_val == 0.0) { - max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + max_val = eps_threshold * float(qmax - qmin); } else if (max_val == 0.0) { - min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + min_val = -eps_threshold * float(qmax - qmin); } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + float amplifier = eps_threshold / org_scale; min_val *= amplifier; max_val *= amplifier; } diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index dcbfe493f34..48681a46c30 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -29,6 +29,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -175,7 +176,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; @@ -260,7 +261,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); t_scale[token_id] = scale_val; t_zero_point[token_id] = zero_point_val; diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 282f1de170a..5076b2d68e9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -30,6 +30,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -234,7 +235,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); @@ -372,7 +373,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 1dc2d34afbf..5e9599b91e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -150,6 +150,7 @@ void add_choose_qparams_tensor_node( const ValueRef& input, const ValueRef& quant_min, const ValueRef& quant_max, + const ValueRef& eps, const ValueRef& scale_out, const ValueRef& zero_point_out) { std::string kernel_name("choose_qparams_tensor"); @@ -158,6 +159,7 @@ void add_choose_qparams_tensor_node( int quant_min_val = static_cast(graph.get_int(quant_min)); int quant_max_val = static_cast(graph.get_int(quant_max)); + float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; @@ -180,6 +182,7 @@ void 
add_choose_qparams_tensor_node( push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), + PushConstantDataInfo(&eps_val, sizeof(float)), }; graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -275,8 +278,22 @@ void choose_qparams_tensor_impl( const ValueRef input = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused dtype parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -289,13 +306,10 @@ void choose_qparams_tensor_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -303,7 +317,7 @@ void choose_qparams_tensor_impl( } add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, scale_out, zero_point_out); + graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); } void choose_qparams_per_token_asymmetric_impl( @@ -311,8 +325,21 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -325,22 +352,20 @@ void choose_qparams_per_token_asymmetric_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + 
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); } REGISTER_OPERATORS { - VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); VK_REGISTER_OP( - choose_qparams_per_token_asymmetric.default, + quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.choose_qparams_per_token_asymmetric.default, choose_qparams_per_token_asymmetric_impl); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 77a51ce24f9..3838da9a151 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -180,8 +180,15 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -212,8 +219,15 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -257,18 +271,34 @@ void dequantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_dequantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); - VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.default, + dequantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_token.default, + dequantize_per_token_impl); } } // namespace vkcompute diff --git 
a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 49277b4d718..f8f930bf0fb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -180,8 +180,12 @@ void quantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -205,8 +209,12 @@ void quantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -243,18 +251,33 @@ void quantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_quantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); - VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_tensor.default, + quantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index 55e96151387..75b7cbc8960 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -433,14 +433,23 @@ void test_vulkan_choose_qparams_tensor_impl( const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams.tensor") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add eps and dtype parameters to match ATen signature + const ValueRef r_eps = graph.add_scalar(6.1e-5); + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") (graph, { 
r_input.value, r_quant_min, r_quant_max, - r_scale, - r_zero_point, + r_eps, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); @@ -647,12 +656,20 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl( const ValueRef r_zero_point = graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add dtype parameter to match ATen signature + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN( + "quantized_decomposed.choose_qparams_per_token_asymmetric.default") (graph, { r_input.value, - r_scale, - r_zero_point, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 6c604076c41..82f316abe82 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -585,7 +585,10 @@ void test_vulkan_dequantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.default") (graph, { r_input.value, @@ -593,6 +596,8 @@ void test_vulkan_dequantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); @@ -1046,7 +1051,10 @@ void test_vulkan_dequantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") (graph, { r_input.value, @@ -1054,6 +1062,8 @@ void test_vulkan_dequantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 150bda6989e..64ea144fbf1 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -476,7 +476,10 @@ void test_vulkan_quantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.default") (graph, { r_input.value, @@ -484,6 +487,7 @@ void test_vulkan_quantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, r_out, }); @@ -835,7 +839,10 @@ void test_vulkan_quantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") (graph, { r_input.value, @@ -843,6 +850,7 @@ void test_vulkan_quantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, r_out, }); From 2673309be65e29439e6862d62eb40dfa95c355d9 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:18 -0700 
Subject: [PATCH 2/8] [ET-VK][ez][Ops] registering Q/DQ/CQP ops and specifying optimal storage Pull Request resolved: https://github.com/pytorch/executorch/pull/12200 # Context Certain quantization operators need scales and zeros to be set with a storage layout as buffers. Since the existing op_registry does not allow specifying how input parameters are set with their memory or storage layout, we need to specify that the optimal storage type is buffer so that is conversion pass is added to ensure that the inputs are also buffers. # Changes This moves the quantized_decomposed operators in their own registration, while also specifying that buffer is preferred. ghstack-source-id: 295972779 @exported-using-ghexport Differential Revision: [D77746131](https://our.internmc.facebook.com/intern/diff/D77746131/) --- backends/vulkan/op_registry.py | 36 +++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9a63d178e2d..1f77b30cda3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -221,13 +221,6 @@ def update_features_impl(op: OpKey): @update_features( [ operator.getitem, - # Quantization related ops will be fused via graph passes - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, @@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_token.default, + exir_ops.edge.quantized_decomposed.dequantize_per_token.default, + exir_ops.edge.quantized_decomposed.choose_qparams.tensor, + exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, + ] +) +def register_quantization_op(features: OpFeatures): + # Quantization requires buffer storage and width packing for scales/zero_points + # but we need to provide texture impl features for the partitioner to work properly + features.texture_impl = TextureImplFeatures( + uses_axis_map=True, + valid_packed_dims={ + PackedDim.WIDTH, + }, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.BUFFER + return features + + @update_features( [ exir_ops.edge.aten.add.Tensor, From c640c4fd09136b316e275ae3d6b9b7ecf072e8d9 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:20 -0700 Subject: [PATCH 3/8] [ET-VK][ez] enabling fp64->fp32 converison for vulkan compatibility Pull Request resolved: https://github.com/pytorch/executorch/pull/12201 # Context We need this conversion so that certain operators can handle floating point values that need to be 64bit. This is predominantly applicable to choose_qparams.tensor where it expects a 64bit output. 
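As a rough illustration of the narrowing rule (a minimal C++ sketch only; the real logic is the Python change to `vulkan_graph_builder.py` shown in the diff below, and the enum here is a stand-in for the serialized `VkDataType`):

```cpp
// Stand-in for the serialized datatype values involved here (illustrative).
enum class Dtype { Float32, Float64, Int32, Int64 };

// With downcast_64_bit enabled (the default), 64-bit dtypes are serialized as
// their 32-bit counterparts; otherwise the new FLOAT64/INT64 enum values are
// emitted and mapped to kDouble/kLong by the runtime.
Dtype effective_dtype(Dtype dtype, bool downcast_64_bit) {
  if (downcast_64_bit && dtype == Dtype::Float64) {
    return Dtype::Float32;
  }
  if (downcast_64_bit && dtype == Dtype::Int64) {
    return Dtype::Int32;
  }
  return dtype;
}
```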
# Changes Simply adding an additional conversion for float64 to vulkan fp32. ghstack-source-id: 295972781 @exported-using-ghexport Differential Revision: [D77746137](https://our.internmc.facebook.com/intern/diff/D77746137/) --- backends/vulkan/runtime/VulkanBackend.cpp | 4 ++++ backends/vulkan/serialization/schema.fbs | 2 ++ .../serialization/vulkan_graph_builder.py | 20 +++++++++++++++---- .../serialization/vulkan_graph_schema.py | 2 ++ backends/vulkan/vulkan_preprocess.py | 9 +++++++-- 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7077a9df59c..28e7574537c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { return vkapi::kChar; case vkgraph::VkDataType::INT32: return vkapi::kInt; + case vkgraph::VkDataType::INT64: + return vkapi::kLong; case vkgraph::VkDataType::FLOAT16: return vkapi::kHalf; case vkgraph::VkDataType::FLOAT32: return vkapi::kFloat; + case vkgraph::VkDataType::FLOAT64: + return vkapi::kDouble; } } diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index f112581c498..99ba6a86594 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -18,6 +18,8 @@ enum VkDataType : byte { INT32 = 3, FLOAT16 = 4, FLOAT32 = 5, + FLOAT64 = 6, + INT64 = 7, } // Describes what kind of GPU resource should be used to represent a tensor. The diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 5bae0475c28..cd876bd6305 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -45,9 +45,11 @@ def __init__( self, program: ExportedProgram, delegate_mapping_builder: DelegateMappingBuilder, + downcast_64_bit: bool = True, ) -> None: self.program = program self.delegate_mapping_builder = delegate_mapping_builder + self.downcast_64_bit = downcast_64_bit self.chain = [] self.values = [] self.input_ids = [] @@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: return vk_graph_schema.VkDataType.INT8 elif torch_dtype == torch.int32: return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.int64: + return vk_graph_schema.VkDataType.INT64 elif torch_dtype == torch.float16: return vk_graph_schema.VkDataType.FLOAT16 elif torch_dtype == torch.float32: return vk_graph_schema.VkDataType.FLOAT32 - # Narrowing conversion for index tensor produced by max_poolNd_with_indices. 
- elif torch_dtype == torch.int64: - return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.float64: + return vk_graph_schema.VkDataType.FLOAT64 else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") @@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # pyre-ignore[16] memory_layout = spec.vk_memory_layout + # Apply downcast logic before getting VK datatype + effective_dtype = spec.dtype + if self.downcast_64_bit and spec.dtype == torch.float64: + effective_dtype = torch.float32 + elif self.downcast_64_bit and spec.dtype == torch.int64: + effective_dtype = torch.int32 + + datatype = self.get_vk_datatype(effective_dtype) + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( value=vk_graph_schema.VkTensor( - datatype=self.get_vk_datatype(spec.dtype), + datatype=datatype, dims=spec.shape, constant_id=constant_id, mem_obj_id=mem_obj_id, diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 35113bc623a..f845e5601a7 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -29,6 +29,8 @@ class VkDataType(IntEnum): INT32 = 3 FLOAT16 = 4 FLOAT32 = 5 + FLOAT64 = 6 + INT64 = 7 class VkStorageType(IntEnum): diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index a22afc3f42e..a6d5737dbb8 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -67,7 +67,6 @@ # pyre-ignore def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: for p in passes: - if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): new_gm = program.graph_module # This is a workaround to allow the memory planning pass to work without @@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: if spec.key == "skip_tag_memory_metadata": options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + if spec.key == "downcast_64_bit": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options @@ -142,6 +144,7 @@ def preprocess( # noqa: C901 default_memory_layout = compile_options.get( "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED ) + downcast_64_bit = compile_options.get("downcast_64_bit", True) program = unsafe_remove_auto_functionalized_pass(program) @@ -213,7 +216,9 @@ def preprocess( # noqa: C901 ) graph_builder = VkGraphBuilder( - program, DelegateMappingBuilder(generated_identifiers=True) + program, + DelegateMappingBuilder(generated_identifiers=True), + downcast_64_bit=downcast_64_bit, ) vk_graph = graph_builder.build_graph() From 97f4606dc0071ded1f49ee58b0a7ec3d5168b84e Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:24 -0700 Subject: [PATCH 4/8] [ET] correcting cpu ref quantize_per_channel logic to align with ATen Pull Request resolved: https://github.com/pytorch/executorch/pull/12203 # Context The quantize_per_channel was not perfectly aligned with the ATen implementation, and demonstrated errors when specifying different axis. This bug wasn't distinctly acknowledged given that the test cases only has one test for the whole operator. In order to align more closely with ATen this change simply does a single loop imlpementation with direct channel index calculation over the old `apply_over_dim_list` approach. 
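To make the indexing scheme concrete, here is a minimal standalone C++ sketch of how the channel index is recovered from a flat element index of a contiguous tensor (illustrative only; the actual kernel change is the `QUANTIZE_IMPL` macro in the diff below, and `channel_index` is a made-up helper name):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// For a contiguous tensor, the stride of `axis` is the product of all sizes
// after `axis`, so the coordinate along `axis` of flat index `i` is
// (i / axis_block_size) % sizes[axis].
int64_t channel_index(int64_t i, const std::vector<int64_t>& sizes, int64_t axis) {
  int64_t axis_block_size = 1;
  for (int64_t d = axis + 1; d < static_cast<int64_t>(sizes.size()); ++d) {
    axis_block_size *= sizes[d];
  }
  return (i / axis_block_size) % sizes[axis];
}

int main() {
  // Shape (2, 3, 4), quantization axis 1: flat index 13 is element (1, 0, 1),
  // so it belongs to channel 0.
  std::cout << channel_index(13, {2, 3, 4}, 1) << "\n";  // prints 0
  return 0;
}
```

This is the per-element replacement for the old `apply_over_dim_list` traversal: a single pass over `input.numel()` elements, with each element's scale and zero point looked up by its recovered channel index.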
# Changes We change the core logic for quantize_per_channel to more properly align with ATen's implementation, and we also change it from `apply_over_dim_list` approach to a single loop implementation with direct channel index calculation. This also adds more comprehensive testing for quantize_per_channel so that a bug isn't missed again. ghstack-source-id: 295972782 @exported-using-ghexport Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/) --- kernels/quantized/cpu/op_quantize.cpp | 68 ++---- kernels/quantized/cpu/targets.bzl | 6 - kernels/quantized/test/op_quantize_test.cpp | 240 ++++++++++++++++++++ 3 files changed, 263 insertions(+), 51 deletions(-) diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index d0b7c882f8e..5586f8a77eb 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -6,7 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -#include #include #include #include @@ -282,55 +281,34 @@ Tensor& quantize_per_channel_out( check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - // a list contains all dimensions except axis - int64_t dims[kTensorDimensionLimit]; - for (int64_t i = 0; i < input.dim() - 1; i++) { - if (i < axis) { - dims[i] = i; - } else { - dims[i] = i - 1; - } - } const double* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data = zero_point.const_data_ptr(); - std::optional> optional_dim_list{ - executorch::aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. 
- // in other words you are quantizing in_data[in_ix] + // High-performance single loop with direct channel calculation #define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \ - double _scale = scale_data[channel_ix]; \ - int64_t _zero_point = zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ + case ScalarType::out_dtype: { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const int64_t input_numel = input.numel(); \ + const int64_t axis_size = input.size(axis); \ + /* Calculate the stride pattern for efficient channel index calculation */ \ + int64_t axis_block_size = 1; \ + for (int64_t i = axis + 1; i < input.dim(); i++) { \ + axis_block_size *= input.size(i); \ } \ - break; + /* Single loop over all elements */ \ + for (int64_t i = 0; i < input_numel; i++) { \ + /* Calculate which channel this element belongs to */ \ + int64_t channel_idx = (i / axis_block_size) % axis_size; \ + /* Get quantization parameters for this channel */ \ + double _scale = scale_data[channel_idx]; \ + int64_t _zero_point = zero_point_data[channel_idx]; \ + /* Apply quantization */ \ + out_data_ptr[i] = quantize_val( \ + _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \ + } \ + } break; + #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ case ScalarType::in_dtype: \ switch (out.scalar_type()) { \ diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl index 3ba9715506a..f29f1f013b7 100644 --- a/kernels/quantized/cpu/targets.bzl +++ b/kernels/quantized/cpu/targets.bzl @@ -51,12 +51,6 @@ _QUANT_OPS = ( ), op_target( name = "op_quantize", - deps = [ - "//executorch/kernels/portable/cpu/util:reduce_util", - ], - _aten_mode_deps = [ - "//executorch/kernels/portable/cpu/util:reduce_util_aten", - ], ), ) diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 5cd17223d80..4ac835c24ce 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) { EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 2}, 4); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {100, 50, 25}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 2}); + // Channel 0: 4 / 0.5 + 100 = 108 + // Channel 1: 4 / 1.0 + 50 = 54 + // Channel 2: 4 / 2.0 + 25 = 27 + Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27}); + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel3D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 3D tensor with axis=1 
(middle dimension) + Tensor input = tf_float.full({2, 3, 4}, 6); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3, 4}); + // Channel 0: 6 / 0.5 + 10 = 22 + // Channel 1: 6 / 1.0 + 20 = 26 + // Channel 2: 6 / 1.5 + 30 = 34 + Tensor expected = tfo.make( + {2, 3, 4}, + { + 22, 22, 22, 22, // First batch, channel 0 + 26, 26, 26, 26, // First batch, channel 1 + 34, 34, 34, 34, // First batch, channel 2 + 22, 22, 22, 22, // Second batch, channel 0 + 26, 26, 26, 26, // Second batch, channel 1 + 34, 34, 34, 34 // Second batch, channel 2 + }); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel4D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W) + Tensor input = tf_float.full({2, 2, 3, 2}, 8); + Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2, 3, 2}); + // Channel 0: 8 / 0.25 + 0 = 32 + // Channel 1: 8 / 0.5 + 10 = 26 + // Channel 2: 8 / 1.0 + 20 = 28 + std::vector expected_data; + for (int n = 0; n < 2; n++) { + for (int c = 0; c < 2; c++) { + for (int h = 0; h < 3; h++) { + for (int w = 0; w < 2; w++) { + int8_t val = (h == 0) ? 32 : (h == 1) ? 26 : 28; + expected_data.push_back(val); + } + } + } + } + Tensor expected = tfo.make({2, 2, 3, 2}, expected_data); + quantize_per_channel_out( + input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 3}, 5); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Using axis=-1 should be equivalent to axis=1 for 2D tensor + // Channel 0: 5 / 0.5 + 0 = 10 + // Channel 1: 5 / 1.0 + 10 = 15 + // Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5) + Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22}); + quantize_per_channel_out( + input, + scale, + zero_point, + -1, + quant_min, + quant_max, + ScalarType::Byte, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 1, 4}, 7); + Tensor scale = tf_double.make({1}, {0.5}); + Tensor zero_point = tf_long.make({1}, {128}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 1, 4}); + // Single channel: 7 / 0.5 + 128 = 142 + Tensor expected = tfo.full({3, 1, 4}, 142); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) { + TensorFactory tf_double_input; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_double_input.full({2, 2}, 3.14159); + Tensor 
scale = tf_double.make({2}, {0.01, 0.02}); + Tensor zero_point = tf_long.make({2}, {0, 100}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127 + // Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127 + Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 2}, 10); + Tensor scale = tf_double.make({2}, {1.0, 2.0}); + Tensor zero_point = tf_long.make({2}, {1000, 2000}); + int64_t quant_min = -32768; + int64_t quant_max = 32767; + + // Test with 16-bit output + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 10 / 1.0 + 1000 = 1010 + // Channel 1: 10 / 2.0 + 2000 = 2005 + Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005}); + quantize_per_channel_out( + input, + scale, + zero_point, + 1, + quant_min, + quant_max, + ScalarType::Short, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test with different input values per position + Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32] + // Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34] + Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test values that will exceed quant_min/quant_max bounds + Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0}); + Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 0, 0}); + int64_t quant_min = -10; + int64_t quant_max = 10; + + TensorFactory tfo; + Tensor out = tfo.zeros({1, 3}); + // Values: [-100, 0, 100] should be clamped to [-10, 0, 10] + Tensor expected = tfo.make({1, 3}, {-10, 0, 10}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} From 52a110c3a59763ccb78bc99b98ea4467fbf1e0aa Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:29 -0700 Subject: [PATCH 5/8] [ET-VK][Ops] quantize_per_channel reference impl and testing Pull Request resolved: https://github.com/pytorch/executorch/pull/12204 # Context In order to properly enable dynamic quantization, we create the quantize_per_channel operator as its seemingly useful to have for the pipeline. # Changes This creates the wrapper for the cpu reference implementation, and also a dummy reference implementation I created just to test against it. 
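For reference, the per-value math that the dummy reference implementation and the CPU kernel are both expected to agree on is the usual affine quantization (a minimal self-contained C++ sketch; `quantize_value` is an illustrative name, not the function used in the test):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
int64_t quantize_value(
    float x, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) {
  const float inv_scale = 1.0f / static_cast<float>(scale);
  int64_t q = zero_point + static_cast<int64_t>(std::nearbyint(inv_scale * x));
  return std::min(std::max(q, quant_min), quant_max);
}

int main() {
  // Worked example matching the per-channel tests: x = 4, scale = 0.5,
  // zero_point = 100, range [0, 255]  ->  4 / 0.5 + 100 = 108.
  std::cout << quantize_value(4.0f, 0.5, 100, 0, 255) << "\n";  // prints 108
  return 0;
}
```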
ghstack-source-id: 295972785 @exported-using-ghexport Differential Revision: [D77746132](https://our.internmc.facebook.com/intern/diff/D77746132/) --- .../vulkan/test/op_tests/quantize_test.cpp | 399 +++++++++++++++++- 1 file changed, 384 insertions(+), 15 deletions(-) diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 64ea144fbf1..8c5246f6c0c 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -48,6 +48,16 @@ Tensor& quantize_per_token_out( ScalarType dtype, Tensor& out); +Tensor& quantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + // Wrapper function for quantize_per_tensor_out without context Tensor& quantize_per_tensor_out_no_context( const Tensor& input, @@ -74,6 +84,20 @@ Tensor& quantize_per_token_out_no_context( input, scale, zero_point, quant_min, quant_max, dtype, out); } +// Wrapper function for quantize_per_channel_out without context +Tensor& quantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_channel_out( + input, scale, zero_point, axis, quant_min, quant_max, dtype, out); +} + // ATen wrapper for quantize_per_tensor at::Tensor quantize_per_tensor_aten( const at::Tensor& input, @@ -106,6 +130,23 @@ at::Tensor quantize_per_token_aten( return out; } +// ATen wrapper for quantize_per_channel +at::Tensor quantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7) + (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -160,6 +201,40 @@ void check_quantize_args( quant_max); } +/** + * Helper function to validate quantize_per_channel arguments + * Similar to the validation in op_quantize.cpp + */ +void check_quantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -271,6 +346,110 @@ at::Tensor quantize_per_token_reference_impl( return out; } 
+/* + * Reference implementation of quantize_per_channel + */ +at::Tensor quantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const float* input_data = input.const_data_ptr(); + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = zero_point.const_data_ptr(); + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = zero_point_data[channel_idx]; + + // Get the input value + float input_value = input_data[flat_idx]; + + // Apply quantization formula: round(input / scale) + zero_point + float inv_scale = 1.0f / static_cast(channel_scale); + int64_t quantized_value = static_cast( + static_cast(channel_zero_point) + + std::nearbyint(static_cast(inv_scale * input_value))); + + // Clamp to quantization bounds + quantized_value = std::max(quantized_value, quant_min); + quantized_value = std::min(quantized_value, quant_max); + + // Store the result based on output dtype + switch (dtype) { + case at::kByte: { + uint8_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kChar: { + int8_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kShort: { + int16_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kInt: { + int32_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + default: + assert(false && "Unsupported output dtype"); + } + } + + return output; +} + // Forward declaration of implementation functions void test_vulkan_quantize_per_tensor_impl( const std::vector& input_sizes, @@ -513,7 +692,10 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::allclose(reference_int, vk_int); + // 
Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -889,7 +1071,10 @@ void test_vulkan_quantize_per_token_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::allclose(reference_int, vk_int); + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -924,7 +1109,7 @@ void test_vulkan_quantize_per_token_impl( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_float_to_int8) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -940,7 +1125,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_float_to_int32) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -956,7 +1141,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_half_to_int32) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -972,7 +1157,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_half_to_uint8) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -988,7 +1173,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_uint8) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1009,9 +1194,7 @@ TEST( at::kByte); } -TEST( - VulkanQuantizePerTensorTest, - test_vulkan_quantize_per_token_float_to_int8) { +TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() ->has_full_int8_buffers_support()) { @@ -1032,7 +1215,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int32) { std::vector scales = { -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; @@ -1049,7 +1232,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int32_small_scales) { std::vector scales = { 0, @@ -1070,7 +1253,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1095,7 +1278,7 @@ TEST( at::kByte); } -TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { +TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() ->has_full_float16_buffers_support()) { @@ -1115,7 +1298,7 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_double_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1134,3 +1317,189 @@ TEST( at::kDouble, // input dtype at::kChar); // output dtype } + +void test_reference_quantize_per_channel( 
+ const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = quantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + // Get direct ATen implementation output + c10::ScalarType aten_dtype = dtype; + if (dtype == at::kChar) { + aten_dtype = c10::kQInt8; + } else if (dtype == at::kByte) { + aten_dtype = c10::kQUInt8; + } + + // Normalize axis for ATen (it doesn't handle negative values) + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + at::Tensor aten_ref = at::quantize_per_channel( + input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor my_ref_int = my_ref.to(at::kInt); + at::Tensor cpu_ref_int = cpu_ref.to(at::kInt); + // For quantized tensors, we need to use int_repr() to get the underlying + // integer values + at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt); + + const bool output_correct = at::equal(my_ref_int, cpu_ref_int); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "aten_ref:" << std::endl; + std::cout << aten_ref_int << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref_int << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + 
VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_quantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} From 6597aa248a464b66a072fc19434337c3c6816177 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:40 -0700 Subject: [PATCH 6/8] [ET-VK][Ops] quantize_per_channel shaders and impl Pull Request resolved: https://github.com/pytorch/executorch/pull/12205 # Context We need to enable the core logic for quantize_per_channel in the vulkan shader. This implements the shader itself and its cpp header. TODO: add more of a description regarding the operator # Changes This creates an extension of the existing files for quantize_per_channel. 
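For intuition, the per-element math the new shader applies is the usual affine quantization, with the scale and zero point picked by the channel index along the quantization axis. A minimal CPU-side sketch of that math (hypothetical helper names, contiguous NCHW layout assumed; this is illustration only, not the shader or ATen code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Minimal sketch of per-channel quantization for a contiguous NCHW float
// tensor, quantizing along `axis`. Hypothetical helper for illustration;
// the Vulkan shader performs the same math per buffer element / texel lane.
std::vector<int32_t> quantize_per_channel_sketch(
    const std::vector<float>& input,
    const std::vector<int64_t>& sizes,       // e.g. {N, C, H, W}
    const std::vector<float>& scales,        // one per channel along `axis`
    const std::vector<int32_t>& zero_points, // one per channel along `axis`
    int64_t axis,
    int32_t quant_min,
    int32_t quant_max) {
  if (axis < 0) {
    axis += static_cast<int64_t>(sizes.size());
  }
  // In a contiguous layout, the channel index repeats with a period equal to
  // the product of the dimensions to the right of `axis`.
  int64_t inner = 1;
  for (size_t d = static_cast<size_t>(axis) + 1; d < sizes.size(); ++d) {
    inner *= sizes[d];
  }
  const int64_t num_channels = sizes[axis];

  std::vector<int32_t> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const int64_t c = (static_cast<int64_t>(i) / inner) % num_channels;
    // Affine mapping: round(x / scale) + zero_point, clamped to [qmin, qmax].
    int32_t q =
        static_cast<int32_t>(std::nearbyint(input[i] / scales[c])) +
        zero_points[c];
    out[i] = std::min(std::max(q, quant_min), quant_max);
  }
  return out;
}
```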
ghstack-source-id: 295972786 @exported-using-ghexport Differential Revision: [D77746140](https://our.internmc.facebook.com/intern/diff/D77746140/) --- .../graph/ops/glsl/quantize_buffer.glsl | 51 +- .../graph/ops/glsl/quantize_buffer.yaml | 2 + .../graph/ops/glsl/quantize_texture.glsl | 96 ++- .../graph/ops/glsl/quantize_texture.yaml | 2 + .../runtime/graph/ops/impl/Quantize.cpp | 226 ++++++- .../vulkan/test/op_tests/quantize_test.cpp | 625 ++++++++++++++++++ 6 files changed, 995 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl index ea0c2f7dce7..c3e58286efe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -42,6 +42,16 @@ $if MODE == "per_token": int quant_min; int quant_max; }; +$if MODE == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int axis; + int num_channels; + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -137,7 +147,7 @@ void quantize_per_tensor() { t_out[out_bufi] = qvalue; } -#else +#elif defined(per_token) void quantize_per_token() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -172,6 +182,45 @@ void quantize_per_token() { t_out[out_bufi] = qvalue; } +#else // per_channel + +void quantize_per_channel() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + // Calculate channel index based on the quantization axis (already converted to WHCN) + // The axis parameter is now in WHCN coordinate system: + // axis 0 -> W dimension (tidx.x) + // axis 1 -> H dimension (tidx.y) + // axis 2 -> C dimension (tidx.z) + // axis 3 -> N dimension (tidx.w) + int channel_idx = 0; + + if (axis == 0) { + channel_idx = out_tidx.x; + } else if (axis == 1) { + channel_idx = out_tidx.y; + } else if (axis == 2) { + channel_idx = out_tidx.z; + } else if (axis == 3) { + channel_idx = out_tidx.w; + } + + channel_idx = min(channel_idx, num_channels - 1); + + OUT_T qvalue = quantize_val(value, t_scale[channel_idx], t_zero_point[channel_idx]); + + t_out[out_bufi] = qvalue; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 4d95d610314..1dd8e6e2ffe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -17,3 +17,5 @@ quantize_buffer: MODE: per_tensor - NAME: quantize_per_token_buffer MODE: per_token + - NAME: quantize_per_channel_buffer + MODE: per_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl index 9ba7074f75b..bdaba3ffaf9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -26,6 +26,8 @@ ${define_required_extensions(OUT_DTYPE)} layout(std430) buffer; +#include "indexing_utils.h" + ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} 
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} @@ -45,11 +47,23 @@ $if MODE == "per_token": int quant_min; int quant_max; }; +$if MODE == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int axis; + int num_channels; + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} -#include "indexing_utils.h" +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + #include "quantize.glslh" layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -138,7 +152,7 @@ void quantize_per_tensor() { write_texel(t_out, pos, outtex); } -#else +#elif defined(per_token) void quantize_per_token() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -177,6 +191,84 @@ void quantize_per_token() { write_texel(t_out, pos, outtex); } +#else // per_channel + +void quantize_per_channel() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + // Calculate channel index based on the quantization axis (already converted to WHCN) + // The axis parameter is now in WHCN coordinate system: + // axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component) + // axis 1 -> H dimension (pos.y) + // axis 2 -> C dimension (pos.z / C), but for 4D tensors this includes batch-channel folding + // axis 3 -> N dimension (pos.z / N), but for 4D tensors this includes batch-channel folding + + if (axis == 0) { + // Width dimension - each texel component has different channel index + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + int channel_idx = pos.x * 4 + i; + channel_idx = min(channel_idx, num_channels - 1); + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } else if (axis == 1) { + // Height dimension - all texel components use same channel index + int channel_idx = pos.y; + channel_idx = min(channel_idx, num_channels - 1); + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } else if (axis == 2) { + // Channel dimension - for 4D tensors, need to account for batch-channel folding + // The Z coordinate contains folded batch*channel information + // We need to extract the actual channel index from the folded dimension + int folded_idx = pos.z; + int channel_idx = folded_idx % num_channels; + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } else if (axis == 3) { + // Batch dimension - for 4D tensors, need to account for batch-channel folding + // The Z coordinate contains folded batch*channel information + // We need to extract the actual batch index from the folded dimension + int folded_idx = pos.z; + int batch_idx = folded_idx / num_channels; + + float 
scale_val = t_scale[batch_idx]; + int zero_point_val = t_zero_point[batch_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 65002ce26b6..47e532be8b9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -17,3 +17,5 @@ quantize_texture: MODE: per_tensor - NAME: quantize_per_token_texture3d MODE: per_token + - NAME: quantize_per_channel_texture3d + MODE: per_channel diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index f8f930bf0fb..74dee249b0a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -12,11 +12,10 @@ #include #include -#include -namespace vkcompute { +#include -namespace { +namespace vkcompute { void resize_quantize_output( ComputeGraph* graph, @@ -28,7 +27,52 @@ void resize_quantize_output( graph->virtual_resize(out, graph->sizes_of(in)); } -} // namespace +utils::uvec3 quantize_per_channel_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + + return global_wg_size; +} + +utils::uvec3 quantize_per_channel_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)args; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. For per-channel quantization along the batch axis, + // we need to ensure that we dispatch the correct number of workgroups in the + // Z dimension to cover all batch-channel combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. 
+ const auto input_sizes = graph->sizes_of(input); + if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + local_wg_size[2] = 1; + } + + return local_wg_size; +} void add_quantize_per_tensor_node( ComputeGraph& graph, @@ -171,6 +215,99 @@ void add_quantize_per_token_node( resize_quantize_output)); } +void add_quantize_per_channel_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& axis, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_channel"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int axis_val = static_cast(graph.get_int(axis)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + // Normalize axis and convert from NCHW to WHCN using utility functions + const auto input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + // Normalize axis to handle negative indices + axis_val = normalize(axis_val, ndim); + + // Convert from NCHW axis to WHCN axis for shader (vulkan representation) + int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); + + int num_channels; + if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { + // For batch dimension quantization in 4D tensors, pass the actual number of + // channels so the shader can correctly unfold the batch-channel folding + num_channels = static_cast(input_sizes[1]); // Channel dimension + } else { + num_channels = static_cast(input_sizes[axis_val]); + } + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantize_per_channel_global_wg_size, + quantize_per_channel_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + void quantize_per_tensor_impl( ComputeGraph& graph, const std::vector& args) { @@ -272,12 +409,93 @@ void quantize_per_token_impl( graph, input, scale, zero_point, quant_min, quant_max, output); } +void quantize_per_channel_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef 
scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef axis = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output = args[arg_idx++]; + + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Normalize axis + int axis_val = static_cast(graph.get_int(axis)); + const auto input_sizes = graph.sizes_of(input); + int64_t ndim = graph.dim_of(input); + if (axis_val < 0) { + axis_val += ndim; + } + + // Verify axis is valid + VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); + + // Get number of channels along the specified axis + int64_t num_channels = input_sizes[axis_val]; + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_channels + VK_CHECK_COND(scale_numel == num_channels); + VK_CHECK_COND(zero_point_numel == num_channels); + + add_quantize_per_channel_node( + graph, input, scale, zero_point, axis, quant_min, quant_max, output); +} + REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.quantize_per_tensor.default, quantize_per_tensor_impl); VK_REGISTER_OP( quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_channel.default, + quantize_per_channel_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 8c5246f6c0c..ebb12bc1b3a 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -473,6 +473,18 @@ void test_vulkan_quantize_per_token_impl( const vkcompute::utils::StorageType in_storage, const vkcompute::utils::StorageType out_storage); +void test_vulkan_quantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const 
vkcompute::utils::StorageType out_storage); + // Wrapper function to test both buffer and texture storage types void test_vulkan_quantize_per_tensor( const std::vector& input_sizes, @@ -553,6 +565,48 @@ void test_vulkan_quantize_per_token( vkcompute::utils::kTexture3D); } +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + + test_vulkan_quantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + void test_reference_quantize_per_tensor( const std::vector& input_sizes, float scale, @@ -1436,6 +1490,167 @@ void test_reference_quantize_per_channel( ASSERT_TRUE(output_correct); } +void test_vulkan_quantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? 
eps : s; + } + + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::quantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_axis = graph.add_scalar(axis); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_channel.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_axis, + r_quant_min, + r_quant_max, + r_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + 
std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + TEST( VulkanQuantizePerChannelTest, test_reference_quantize_per_channel_float_to_int8_3D_axis0) { @@ -1503,3 +1718,413 @@ TEST( at::kFloat, at::kChar); } + +// END OF REFERENCE TESTS + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis0) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(9, 0.1f); + std::vector zero_points(9, 2); + + // 1D Tensor + test_vulkan_quantize_per_channel( + {9}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 2D Tensor + test_vulkan_quantize_per_channel( + {9, 14}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 3D Tensor + test_vulkan_quantize_per_channel( + {9, 7, 11}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 17, 5, 5}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 17, 5, 9}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis1) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(14, 0.001f); + std::vector zero_points(14, -5); + + // 2D Tensor + test_vulkan_quantize_per_channel( + {9, 14}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 3D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 5}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {9, 7, 14, 5}, // input sizes + scales, + zero_points, + -2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis2) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(11, 0.5f); + std::vector zero_points(11, 12); + + // 3D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor + 
test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {9, 11, 14, 5}, // input sizes + scales, + zero_points, + -3, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis3) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(7, 0.5f); + std::vector zero_points(7, 12); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 7}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {7, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; + std::vector zero_points = {0, 5, -5, 1, 12}; + + // 4D Tensor + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_half_to_8bit) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; + std::vector zero_points = {0, 5, 5, 1, 12}; + + // 4D Tensor + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kHalf, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kHalf, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kHalf, + at::kChar); + + // 4D Tensor 
(negative axis) + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_double_to_8bit) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; + std::vector zero_points = {0, 5, 5, 1, 12}; + + // 4D Tensor + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kDouble, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kDouble, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kDouble, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kDouble, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kDouble, + at::kByte); +} From 85648e00a838c47db6e7126e0404075455e886e9 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:49 -0700 Subject: [PATCH 7/8] [ET-VK][Ops] dequantize_per_channel reference impl and testing Pull Request resolved: https://github.com/pytorch/executorch/pull/12206 # Context In order to properly enable dynamic quantization, we create the dequantize_per_channel operator as its seemingly useful to have for the pipeline. To provide some more context on the ATen to ETen change, there was an issue that the optionals did not perfectly handle cases that were const and ref, so this change primarily plans to add functionality to handle these templating mismatch cases. # Changes This creates the wrapper for the cpu reference implementation, and also a dummy reference implementation I created just to test against it. We also created a test case for ATen to ETen for the new changes. 
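The reference implementation added here boils down to the affine inverse, (q - zero_point[channel]) * scale[channel], with the channel index taken along the quantization axis. A minimal standalone sketch of that math (hypothetical helper names, contiguous NCHW int8 input assumed; illustration only, not the shipped reference code):

```cpp
#include <cstdint>
#include <vector>

// Minimal sketch of per-channel dequantization over a contiguous NCHW int8
// tensor. Hypothetical helper mirroring the math in the reference
// implementation added in this patch.
std::vector<float> dequantize_per_channel_sketch(
    const std::vector<int8_t>& input,
    const std::vector<int64_t>& sizes,        // e.g. {N, C, H, W}
    const std::vector<double>& scales,        // one per channel along `axis`
    const std::vector<int64_t>& zero_points,  // one per channel along `axis`
    int64_t axis) {
  if (axis < 0) {
    axis += static_cast<int64_t>(sizes.size());
  }
  int64_t inner = 1;
  for (size_t d = static_cast<size_t>(axis) + 1; d < sizes.size(); ++d) {
    inner *= sizes[d];
  }
  const int64_t num_channels = sizes[axis];

  std::vector<float> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const int64_t c = (static_cast<int64_t>(i) / inner) % num_channels;
    // Affine inverse: (q - zero_point) * scale.
    out[i] = static_cast<float>(
        (static_cast<int64_t>(input[i]) - zero_points[c]) * scales[c]);
  }
  return out;
}
```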
ghstack-source-id: 295972788 @exported-using-ghexport Differential Revision: [D77746138](https://our.internmc.facebook.com/intern/diff/D77746138/) --- .../vulkan/test/op_tests/dequantize_test.cpp | 406 +++++++++++++++++- .../make_aten_functor_from_et_functor.h | 40 +- ...make_aten_functor_from_et_functor_test.cpp | 89 ++++ 3 files changed, 523 insertions(+), 12 deletions(-) diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 82f316abe82..f32a93e2b6a 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -49,6 +49,17 @@ Tensor& dequantize_per_token_out( ScalarType out_dtype, Tensor& out); +Tensor& dequantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + // Wrapper function for dequantize_per_tensor_out without context Tensor& dequantize_per_tensor_out_no_context( const Tensor& input, @@ -77,6 +88,29 @@ Tensor& dequantize_per_token_out_no_context( input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); } +// Wrapper function for dequantize_per_channel_out without context +Tensor& dequantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_channel_out( + input, + scale, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + // ATen wrapper for dequantize_per_tensor at::Tensor dequantize_per_tensor_aten( const at::Tensor& input, @@ -131,6 +165,36 @@ at::Tensor dequantize_per_token_aten( return out; } +// ATen wrapper for dequantize_per_channel +at::Tensor dequantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) + (input, + scale, + zero_points, + axis, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -183,6 +247,40 @@ void check_dequantize_args( } } +/** + * Helper function to validate dequantize_per_channel arguments + * Similar to the validation in quantize_test.cpp + */ +void check_dequantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << 
input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -322,6 +420,120 @@ at::Tensor dequantize_per_token_reference_impl( return out; } +/* + * Reference implementation of dequantize_per_channel + */ +at::Tensor dequantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, out_dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = nullptr; + if (zero_point.has_value()) { + zero_point_data = zero_point.value().const_data_ptr(); + } + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = 0; + if (zero_point_data != nullptr) { + channel_zero_point = zero_point_data[channel_idx]; + } + + // Store casted values to avoid repeated casting + const int32_t channel_zero_point_int32 = + static_cast(channel_zero_point); + const float channel_scale_float = static_cast(channel_scale); + + // Get the input value and dequantize + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - 
channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store the result based on output dtype + if (out_dtype == at::kFloat) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + output.flatten()[flat_idx] = dequantized_value; + } else if (out_dtype == at::kHalf) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } + } + + return output; +} + // Forward declaration of implementation functions void test_vulkan_dequantize_per_tensor_impl( const std::vector& input_sizes, @@ -625,7 +837,8 @@ void test_vulkan_dequantize_per_tensor_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1105,7 +1318,8 @@ void test_vulkan_dequantize_per_token_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1349,3 +1563,191 @@ TEST( at::kChar, // input dtype at::kDouble); // output dtype } + +void test_reference_dequantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create input tensor with quantized values + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original 
dimensions + input = flat_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = dequantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(my_ref, cpu_ref); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_dequantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, + at::kFloat); +} diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index 
cb7b36a5fc1..104531f0fbb 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -155,19 +155,39 @@ struct type_convert< }; // Optionals: ATen to ETen. -template -struct type_convert, torch::executor::optional> final { +template +struct type_convert< + AOptional, + EOptional, + std::enable_if_t< + std::is_same_v< + typename remove_const_ref::type, + std::optional< + typename remove_const_ref::type::value_type>> && + std::is_same_v< + typename remove_const_ref::type, + torch::executor::optional< + typename remove_const_ref::type::value_type>>>> + final { public: - std::optional val; - std::unique_ptr> convert_struct; - explicit type_convert(std::optional value) : val(value) {} - torch::executor::optional call() { + typename remove_const_ref::type val; + std::unique_ptr::type::value_type, + typename remove_const_ref::type::value_type>> + convert_struct; + explicit type_convert(AOptional value) : val(value) {} + typename remove_const_ref::type call() { if (val.has_value()) { - convert_struct = std::make_unique>( - type_convert(val.value())); - return torch::executor::optional(convert_struct->call()); + convert_struct = std::make_unique::type::value_type, + typename remove_const_ref::type::value_type>>( + type_convert< + typename remove_const_ref::type::value_type, + typename remove_const_ref::type::value_type>( + val.value())); + return typename remove_const_ref::type(convert_struct->call()); } else { - return torch::executor::optional(); + return typename remove_const_ref::type(); } } }; diff --git a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp index 17d0f7a4d63..a5b53096ae2 100644 --- a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp +++ b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp @@ -421,3 +421,92 @@ TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) { EXPECT_EQ(stack.size(), 1); EXPECT_EQ(stack[0].toTensor().const_data_ptr()[0], 4); } + +TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ConstRefOptionals) { + // Test const optional scalar conversion + const std::optional const_optional_at_in = + std::optional(42); + auto const_optional_et = + type_convert< + const std::optional, + torch::executor::optional>(const_optional_at_in) + .call(); + EXPECT_TRUE(const_optional_et.has_value()); + EXPECT_EQ(const_optional_et.value(), 42); + + // Test optional scalar reference conversion + std::optional optional_at_ref_in = std::optional(24); + auto optional_et_from_ref = + type_convert&, torch::executor::optional>( + optional_at_ref_in) + .call(); + EXPECT_TRUE(optional_et_from_ref.has_value()); + EXPECT_EQ(optional_et_from_ref.value(), 24); + + // Test const optional scalar reference conversion + const std::optional const_optional_at_ref_in = + std::optional(84); + auto const_optional_et_from_ref = + type_convert< + const std::optional&, + torch::executor::optional>(const_optional_at_ref_in) + .call(); + EXPECT_TRUE(const_optional_et_from_ref.has_value()); + EXPECT_EQ(const_optional_et_from_ref.value(), 84); + + // Test const optional tensor conversion + const std::optional const_optional_tensor_at_in = + std::optional(torch::tensor({5})); + auto const_optional_tensor_converter = type_convert< + const std::optional, + torch::executor::optional>( + const_optional_tensor_at_in); + auto const_optional_tensor_et = const_optional_tensor_converter.call(); + 
EXPECT_TRUE(const_optional_tensor_et.has_value()); + EXPECT_EQ(const_optional_tensor_et.value().const_data_ptr()[0], 5); + + // Test optional tensor reference conversion + std::optional optional_tensor_at_ref_in = + std::optional(torch::tensor({7})); + auto optional_tensor_converter_from_ref = type_convert< + std::optional&, + torch::executor::optional>( + optional_tensor_at_ref_in); + auto optional_tensor_et_from_ref = optional_tensor_converter_from_ref.call(); + EXPECT_TRUE(optional_tensor_et_from_ref.has_value()); + EXPECT_EQ( + optional_tensor_et_from_ref.value().const_data_ptr()[0], 7); + + // Test const optional tensor reference conversion + const std::optional const_optional_tensor_at_ref_in = + std::optional(torch::tensor({9})); + auto const_optional_tensor_converter_from_ref = type_convert< + const std::optional&, + torch::executor::optional>( + const_optional_tensor_at_ref_in); + auto const_optional_tensor_et_from_ref = + const_optional_tensor_converter_from_ref.call(); + EXPECT_TRUE(const_optional_tensor_et_from_ref.has_value()); + EXPECT_EQ( + const_optional_tensor_et_from_ref.value().const_data_ptr()[0], + 9); + + // Test empty const optional conversions + const std::optional empty_const_optional_at_in = std::nullopt; + auto empty_const_optional_et = + type_convert< + const std::optional, + torch::executor::optional>(empty_const_optional_at_in) + .call(); + EXPECT_FALSE(empty_const_optional_et.has_value()); + + const std::optional empty_const_optional_tensor_at_in = + std::nullopt; + auto empty_const_optional_tensor_et = + type_convert< + const std::optional, + torch::executor::optional>( + empty_const_optional_tensor_at_in) + .call(); + EXPECT_FALSE(empty_const_optional_tensor_et.has_value()); +} From 504baa95f19b0795c62dca81f2bb5e8bed948878 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:55 -0700 Subject: [PATCH 8/8] [ET-VK][Ops] dequantize_per_channel shaders and impl Pull Request resolved: https://github.com/pytorch/executorch/pull/12207 # Context We need to enable the core logic for dequantize_per_channel in the vulkan shader. This implements the shader itself and its cpp header. TODO: add more of a description regarding the operator # Changes This creates an extension of the existing files for dequantize_per_channel. 
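The non-obvious part of the texture-storage shader is recovering the channel index from the texel position once the axis has been converted from NCHW to WHCN and, for 4D tensors, batch and channel have been folded into the Z coordinate. A rough C++ sketch of that index selection (hypothetical helper mirroring the shader's branches, width-packed texels assumed):

```cpp
#include <algorithm>

// Rough sketch of how the per-channel shaders pick the (de)quantization
// channel index from a texel position. `axis_whcn` is the quantization axis
// after the NCHW -> WHCN conversion; for 4D tensors, batch and channel are
// folded into pos_z. Hypothetical helper for illustration only.
int channel_index_sketch(
    int axis_whcn,
    int pos_x, int pos_y, int pos_z,
    int texel_component,  // 0..3, lane within a width-packed texel
    int num_channels) {
  if (axis_whcn == 0) {
    // Width axis: each lane of a width-packed texel maps to its own channel.
    return std::min(pos_x * 4 + texel_component, num_channels - 1);
  } else if (axis_whcn == 1) {
    // Height axis: the whole texel shares one channel.
    return std::min(pos_y, num_channels - 1);
  } else if (axis_whcn == 2) {
    // Channel axis: unfold the batch-channel folding in the Z coordinate.
    return pos_z % num_channels;
  }
  // Batch axis: the other half of the folded Z coordinate.
  return pos_z / num_channels;
}
```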
ghstack-source-id: 295972778 @exported-using-ghexport Differential Revision: [D77746141](https://our.internmc.facebook.com/intern/diff/D77746141/) --- .../graph/ops/glsl/dequantize_buffer.glsl | 51 +- .../graph/ops/glsl/dequantize_buffer.yaml | 2 + .../graph/ops/glsl/dequantize_texture.glsl | 103 ++- .../graph/ops/glsl/dequantize_texture.yaml | 2 + .../runtime/graph/ops/impl/Dequantize.cpp | 232 +++++- .../vulkan/test/op_tests/dequantize_test.cpp | 673 ++++++++++++++++++ 6 files changed, 1058 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl index 2a1f62719a0..faafa3fd266 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -42,6 +42,16 @@ $if MODE == "per_token": int quant_min; int quant_max; }; +$if MODE == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int axis; + int num_channels; + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -141,7 +151,7 @@ void dequantize_per_tensor() { t_out[out_bufi] = value; } -#else +#elif defined(per_token) void dequantize_per_token() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -176,6 +186,45 @@ void dequantize_per_token() { t_out[out_bufi] = value; } +#else // per_channel + +void dequantize_per_channel() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + // Calculate channel index based on the dequantization axis (already converted to WHCN) + // The axis parameter is now in WHCN coordinate system: + // axis 0 -> W dimension (tidx.x) + // axis 1 -> H dimension (tidx.y) + // axis 2 -> C dimension (tidx.z) + // axis 3 -> N dimension (tidx.w) + int channel_idx = 0; + + if (axis == 0) { + channel_idx = out_tidx.x; + } else if (axis == 1) { + channel_idx = out_tidx.y; + } else if (axis == 2) { + channel_idx = out_tidx.z; + } else if (axis == 3) { + channel_idx = out_tidx.w; + } + + channel_idx = min(channel_idx, num_channels - 1); + + OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]); + + t_out[out_bufi] = value; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index fb0d2ee61bf..b9a53217452 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -17,3 +17,5 @@ dequantize_buffer: MODE: per_tensor - NAME: dequantize_per_token_buffer MODE: per_token + - NAME: dequantize_per_channel_buffer + MODE: per_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index 801f4a2f6a2..ef3f5cca869 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -45,6 +45,16 @@ $if MODE == "per_token": int quant_min; int quant_max; }; +$if MODE == "per_channel": + 
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int axis; + int num_channels; + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} @@ -147,7 +157,7 @@ void dequantize_per_tensor() { write_texel(t_out, pos, outtex); } -#else +#elif defined(per_token) void dequantize_per_token() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -189,6 +199,97 @@ void dequantize_per_token() { write_texel(t_out, pos, outtex); } +#else // per_channel + +void dequantize_per_channel() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + // Calculate channel index based on the dequantization axis (already converted to WHCN) + // The axis parameter is now in WHCN coordinate system: + // axis 0 -> W dimension (pos.x) + // axis 1 -> H dimension (pos.y) + // axis 2 -> C dimension (pos.z) + // axis 3 -> N dimension (batch folding in texture storage) + + if (axis == 0) { + // Width dimension - each texel component has different channel index + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + int channel_idx = pos.x * 4 + i; + channel_idx = min(channel_idx, num_channels - 1); + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + } else if (axis == 1) { + int channel_idx = pos.y; + channel_idx = min(channel_idx, num_channels - 1); + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + } else if (axis == 2) { + // Channel dimension - for 4D tensors, need to account for batch-channel folding + // The Z coordinate contains folded batch*channel information + // We need to extract the actual channel index from the folded dimension + int folded_idx = pos.z; + int channel_idx = folded_idx % num_channels; + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + } else if (axis == 3) { + // Batch dimension - for 4D tensors, need to account for batch-channel folding + // The Z coordinate contains folded batch*channel information + // We need to extract the actual channel index from the folded dimension + int folded_idx = pos.z; + // In this case num_channels actually corresponds to the number of channels + // the C dimension N(C)HW + int channel_idx = folded_idx / num_channels; + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + } + + write_texel(t_out, 
pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index 7d19a543a03..88ccc6e3274 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -17,3 +17,5 @@ dequantize_texture: MODE: per_tensor - NAME: dequantize_per_token_texture3d MODE: per_token + - NAME: dequantize_per_channel_texture3d + MODE: per_channel diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 3838da9a151..8845d6f6254 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -13,11 +13,10 @@ #include #include #include +#include namespace vkcompute { -namespace { - void resize_dequantize_output( ComputeGraph* graph, const std::vector& args, @@ -28,7 +27,50 @@ void resize_dequantize_output( graph->virtual_resize(out, graph->sizes_of(in)); } -} // namespace +utils::uvec3 dequantize_per_channel_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + + return global_wg_size; +} + +utils::uvec3 dequantize_per_channel_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. For per-channel dequantization along the batch + // axis, we need to ensure that we dispatch the correct number of workgroups + // in the Z dimension to cover all batch-channel combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. 
+ const auto input_sizes = graph->sizes_of(input); + if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + local_wg_size[2] = 1; + } + + return local_wg_size; +} void add_dequantize_per_tensor_node( ComputeGraph& graph, @@ -171,6 +213,99 @@ void add_dequantize_per_token_node( resize_dequantize_output)); } +void add_dequantize_per_channel_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& axis, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_channel"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int axis_val = static_cast(graph.get_int(axis)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + // Normalize axis and convert from NCHW to WHCN using utility functions + const auto input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + // Normalize axis to handle negative indices + axis_val = normalize(axis_val, ndim); + + // Convert from NCHW axis to WHCN axis for shader (vulkan representation) + int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); + + int num_channels; + if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { + // For batch dimension dequantization in 4D tensors, pass the actual number + // of channels so the shader can correctly unfold the batch-channel folding + num_channels = static_cast(input_sizes[1]); // Channel dimension + } else { + num_channels = static_cast(input_sizes[axis_val]); + } + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + dequantize_per_channel_global_wg_size, + dequantize_per_channel_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + void dequantize_per_tensor_impl( ComputeGraph& graph, const std::vector& args) { @@ -292,6 +427,94 @@ void dequantize_per_token_impl( graph, input, scale, zero_point, quant_min, quant_max, output); } +void dequantize_per_channel_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = 
args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef axis = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef output = args[arg_idx++]; + + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Normalize axis + int axis_val = static_cast(graph.get_int(axis)); + const auto input_sizes = graph.sizes_of(input); + int ndim = graph.dim_of(input); + if (axis_val < 0) { + axis_val += ndim; + } + + // Verify axis is valid + VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); + + // Get number of channels along the specified axis + int64_t num_channels = input_sizes[axis_val]; + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_channels + VK_CHECK_COND(scale_numel == num_channels); + VK_CHECK_COND(zero_point_numel == num_channels); + + add_dequantize_per_channel_node( + graph, input, scale, zero_point, axis, quant_min, quant_max, output); +} + REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.dequantize_per_tensor.default, @@ -299,6 +522,9 @@ REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.dequantize_per_token.default, dequantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_channel.default, + dequantize_per_channel_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index f32a93e2b6a..cb9c04ee089 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -557,6 +557,18 @@ void 
test_vulkan_dequantize_per_token_impl( const vkcompute::utils::StorageType in_storage, const vkcompute::utils::StorageType out_storage); +void test_vulkan_dequantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + // Wrapper function to test both buffer and texture storage types void test_vulkan_dequantize_per_tensor( const std::vector& input_sizes, @@ -637,6 +649,49 @@ void test_vulkan_dequantize_per_token( vkcompute::utils::kTexture3D); } +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_dequantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + void test_reference_dequantize_per_tensor( const std::vector& input_sizes, float scale, @@ -1684,6 +1739,214 @@ void test_reference_dequantize_per_channel( ASSERT_TRUE(output_correct); } +void test_vulkan_dequantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create random float tensor + at::Tensor float_x = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); + + // Map the dtype to the corresponding quantized type and quantize the float + // tensor + c10::ScalarType qtype; + at::Tensor adjusted_zero_points = zero_point_tensor; + + if (dtype == at::kByte) { + qtype = c10::kQUInt8; + // ATEN ONLY: Adjust zero points for unsigned types (must be non-negative) + adjusted_zero_points = at::clamp_min(zero_point_tensor, 0); + } else if (dtype == at::kChar) { + qtype = c10::kQInt8; + } else if (dtype == at::kInt) { + qtype = c10::kQInt32; + } else { + std::cout << "invalid dtype for ATEN: " << dtype << std::endl; + std::cout << " --> Delegating to c10::kQInt32" << std::endl; + qtype = c10::kQInt32; + } + + // Normalize axis for ATen (ATen doesn't handle negative axes in + // 
quantize_per_channel) + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes_int64.size(); + } + + // Quantize using ATen + at::Tensor quantized_aten = at::quantize_per_channel( + float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype); + + // Get ATen dequantized output + at::Tensor aten_out = at::dequantize(quantized_aten).to(out_dtype); + + // Extract the quantized values (int_repr) to use with our implementations + at::Tensor quantized_input = quantized_aten.int_repr().to(dtype); + + // Get reference output using + // torch::executor::native::dequantize_per_channel_aten + at::Tensor reference_out = + torch::executor::native::dequantize_per_channel_aten( + quantized_input, + scale_tensor.to(at::kDouble), + zero_point_tensor.to(at::kLong), + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Build Vulkan dequantize_per_channel graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + // Add tensors to graph + IOValueRef r_input = graph.add_input_tensor( + quantized_input.sizes().vec(), + from_at_scalartype(quantized_input.scalar_type()), + in_storage); + + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + + IOValueRef r_zero_point = graph.add_input_tensor( + adjusted_zero_points.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + ValueRef r_out = graph.add_tensor( + quantized_input.sizes().vec(), + from_at_scalartype(out_dtype), + out_storage); + + const ValueRef r_axis = graph.add_scalar(axis); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + const ValueRef r_output_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_channel.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_axis, + r_quant_min, + r_quant_max, + r_dtype, + r_output_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, + quantized_input.const_data_ptr(), + quantized_input.numel()); + + // copy scale tensor to GPU + graph.copy_into_staging( + r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); + + // copy zero_point tensor to GPU + graph.copy_into_staging( + r_zero_point.staging, + zero_point_tensor.const_data_ptr(), + zero_point_tensor.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); + } + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 
0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + std::cout << " storage: " << in_storage << std::endl; + std::cout << std::endl; + + std::cout << "\033[91m quantized_input: \033[0m" << std::endl; + std::cout << quantized_input << std::endl; + std::cout << "\033[91m aten: \033[0m" << std::endl; + std::cout << aten_out << std::endl; + std::cout << "\033[91m reference: \033[0m" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "\033[91m vulkan: \033[0m" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + TEST( VulkanDequantizePerChannelTest, test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { @@ -1751,3 +2014,413 @@ TEST( at::kInt, at::kFloat); } + +// END OF REFERENCE TESTS + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_int8_to_float_axis0) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(9, 0.1f); + std::vector zero_points(9, 2); + + // 1D Tensor + test_vulkan_dequantize_per_channel( + {9}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 2D Tensor + test_vulkan_dequantize_per_channel( + {9, 14}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 3D Tensor + test_vulkan_dequantize_per_channel( + {9, 7, 11}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 17, 5, 5}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {5, 17, 5, 9}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_int8_to_float_axis1) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(14, 0.001f); + std::vector zero_points(14, -5); + + // 2D Tensor + test_vulkan_dequantize_per_channel( + {9, 14}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 3D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 5, 5}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {9, 
7, 14, 5}, // input sizes + scales, + zero_points, + -2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_int8_to_float_axis2) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(11, 0.5f); + std::vector zero_points(11, 12); + + // 3D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {9, 11, 14, 5}, // input sizes + scales, + zero_points, + -3, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_int8_to_float_axis3) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(7, 0.5f); + std::vector zero_points(7, 12); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11, 7}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {7, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; + std::vector zero_points = {0, 5, -5, 1, 12}; + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_8bit_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; + std::vector zero_points = {0, 5, 5, 1, 12}; + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + 
-128, // quant_min + 127, // quant_max + at::kChar, + at::kHalf); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kHalf); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kHalf); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kHalf); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kHalf); +} + +TEST( + VulkanDequantizePerChannelTest, + test_vulkan_dequantize_per_channel_8bit_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; + std::vector zero_points = {0, 5, 5, 1, 12}; + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kDouble); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kDouble); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kDouble); + + // 4D Tensor + test_vulkan_dequantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kDouble); + + // 4D Tensor (negative axis) + test_vulkan_dequantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kDouble); +}