From 9a8b16e4c940bd6f7b9658da213e2ed687eb0699 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:16 -0700 Subject: [PATCH 1/4] [ET-VK][Ops] aligning Q/DQ/CQP op inputs with ATen impl Pull Request resolved: https://github.com/pytorch/executorch/pull/12199 # Context A few operators have been recently created, namely: - quantize_per_tensor - quantize_per_token - dequantize_per_tensor - dequantize_per_token - choose_qparams.tensor - choose_qparams_per_token_asymmetric They don't have a namespace associated with them, and since we are trying to align with the ATen implementation in their respective quantized_decomposed namespace, this diff is necessary to align in that regard. Furthermore, our operators need to match inputs with the ATen version, so we also pass dtypes. # Changes The primary change is adding the namespace quantized_decomposed to all the above named operators. Furthermore, we change the testing framework to pass dummy dtypes that is expected for the ATen implementation. We also change the `choose_qparams` logic to properly pass the eps, since this is actually a relevant variable and cannot be set by default, despite the existing op_quantize cpu reference in executorch not distinctly using this variable. ghstack-source-id: 295972783 @exported-using-ghexport Differential Revision: [D77746144](https://our.internmc.facebook.com/intern/diff/D77746144/) --- .../graph/ops/glsl/choose_qparams.glslh | 16 ++--- .../graph/ops/glsl/choose_qparams_buffer.glsl | 5 +- .../ops/glsl/choose_qparams_texture.glsl | 5 +- .../runtime/graph/ops/impl/ChooseQParams.cpp | 67 +++++++++++++------ .../runtime/graph/ops/impl/Dequantize.cpp | 42 ++++++++++-- .../runtime/graph/ops/impl/Quantize.cpp | 35 ++++++++-- .../test/op_tests/choose_qparams_test.cpp | 29 ++++++-- .../vulkan/test/op_tests/dequantize_test.cpp | 14 +++- .../vulkan/test/op_tests/quantize_test.cpp | 12 +++- 9 files changed, 169 insertions(+), 56 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index 66620e9b174..d6d27d2e3a3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,15 +9,13 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// equivalent of the eps defined in the cpu implementation -#define SMALL_SCALE_THRESHOLD 6.1e-5 - // Calculate scale and zero point from min and max values void calculate_scale_and_zero_point( float min_val, float max_val, int qmin, int qmax, + float eps_threshold, out float scale_val, out int zero_point_val) { // ensure we have zero included in our range @@ -31,18 +29,18 @@ void calculate_scale_and_zero_point( scale_val = 0.1; } - // Cut off small scale - if (scale_val < SMALL_SCALE_THRESHOLD) { + // Cut off small scale using the provided eps threshold + if (scale_val < eps_threshold) { float org_scale = scale_val; - scale_val = SMALL_SCALE_THRESHOLD; + scale_val = eps_threshold; // Adjust min and max based on new scale if (min_val == 0.0) { - max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + max_val = eps_threshold * float(qmax - qmin); } else if (max_val == 0.0) { - min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + min_val = -eps_threshold * float(qmax - qmin); } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + float amplifier = eps_threshold / org_scale; min_val *= amplifier; max_val *= amplifier; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index dcbfe493f34..48681a46c30 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -29,6 +29,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -175,7 +176,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; @@ -260,7 +261,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); t_scale[token_id] = scale_val; t_zero_point[token_id] = zero_point_val; diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 282f1de170a..5076b2d68e9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -30,6 +30,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -234,7 +235,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); @@ -372,7 +373,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 1dc2d34afbf..5e9599b91e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -150,6 +150,7 @@ void add_choose_qparams_tensor_node( const ValueRef& input, const ValueRef& quant_min, const ValueRef& quant_max, + const ValueRef& eps, const ValueRef& scale_out, const ValueRef& zero_point_out) { std::string kernel_name("choose_qparams_tensor"); @@ -158,6 +159,7 @@ void add_choose_qparams_tensor_node( int quant_min_val = static_cast(graph.get_int(quant_min)); int quant_max_val = static_cast(graph.get_int(quant_max)); + float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; @@ -180,6 +182,7 @@ void add_choose_qparams_tensor_node( push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), + PushConstantDataInfo(&eps_val, sizeof(float)), }; graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -275,8 +278,22 @@ void choose_qparams_tensor_impl( const ValueRef input = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused dtype parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -289,13 +306,10 @@ void choose_qparams_tensor_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -303,7 +317,7 @@ void choose_qparams_tensor_impl( } add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, scale_out, zero_point_out); + graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); } void choose_qparams_per_token_asymmetric_impl( @@ -311,8 +325,21 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -325,22 +352,20 @@ void choose_qparams_per_token_asymmetric_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); } REGISTER_OPERATORS { - VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); VK_REGISTER_OP( - choose_qparams_per_token_asymmetric.default, + quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.choose_qparams_per_token_asymmetric.default, choose_qparams_per_token_asymmetric_impl); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 77a51ce24f9..3838da9a151 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -180,8 +180,15 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -212,8 +219,15 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -257,18 +271,34 @@ void dequantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_dequantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); - VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.default, + dequantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_token.default, + dequantize_per_token_impl); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 49277b4d718..f8f930bf0fb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -180,8 +180,12 @@ void quantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -205,8 +209,12 @@ void quantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -243,18 +251,33 @@ void quantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_quantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); - VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_tensor.default, + quantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index 55e96151387..75b7cbc8960 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -433,14 +433,23 @@ void test_vulkan_choose_qparams_tensor_impl( const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams.tensor") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add eps and dtype parameters to match ATen signature + const ValueRef r_eps = graph.add_scalar(6.1e-5); + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") (graph, { r_input.value, r_quant_min, r_quant_max, - r_scale, - r_zero_point, + r_eps, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); @@ -647,12 +656,20 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl( const ValueRef r_zero_point = graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add dtype parameter to match ATen signature + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN( + "quantized_decomposed.choose_qparams_per_token_asymmetric.default") (graph, { r_input.value, - r_scale, - r_zero_point, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 6c604076c41..82f316abe82 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -585,7 +585,10 @@ void test_vulkan_dequantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.default") (graph, { r_input.value, @@ -593,6 +596,8 @@ void test_vulkan_dequantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); @@ -1046,7 +1051,10 @@ void test_vulkan_dequantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") (graph, { r_input.value, @@ -1054,6 +1062,8 @@ void test_vulkan_dequantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 150bda6989e..64ea144fbf1 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -476,7 +476,10 @@ void test_vulkan_quantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.default") (graph, { r_input.value, @@ -484,6 +487,7 @@ void test_vulkan_quantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, r_out, }); @@ -835,7 +839,10 @@ void test_vulkan_quantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") (graph, { r_input.value, @@ -843,6 +850,7 @@ void test_vulkan_quantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, r_out, }); From 2673309be65e29439e6862d62eb40dfa95c355d9 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:18 -0700 Subject: [PATCH 2/4] [ET-VK][ez][Ops] registering Q/DQ/CQP ops and specifying optimal storage Pull Request resolved: https://github.com/pytorch/executorch/pull/12200 # Context Certain quantization operators need scales and zeros to be set with a storage layout as buffers. Since the existing op_registry does not allow specifying how input parameters are set with their memory or storage layout, we need to specify that the optimal storage type is buffer so that is conversion pass is added to ensure that the inputs are also buffers. # Changes This moves the quantized_decomposed operators in their own registration, while also specifying that buffer is preferred. ghstack-source-id: 295972779 @exported-using-ghexport Differential Revision: [D77746131](https://our.internmc.facebook.com/intern/diff/D77746131/) --- backends/vulkan/op_registry.py | 36 +++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9a63d178e2d..1f77b30cda3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -221,13 +221,6 @@ def update_features_impl(op: OpKey): @update_features( [ operator.getitem, - # Quantization related ops will be fused via graph passes - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, @@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_token.default, + exir_ops.edge.quantized_decomposed.dequantize_per_token.default, + exir_ops.edge.quantized_decomposed.choose_qparams.tensor, + exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, + ] +) +def register_quantization_op(features: OpFeatures): + # Quantization requires buffer storage and width packing for scales/zero_points + # but we need to provide texture impl features for the partitioner to work properly + features.texture_impl = TextureImplFeatures( + uses_axis_map=True, + valid_packed_dims={ + PackedDim.WIDTH, + }, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.BUFFER + return features + + @update_features( [ exir_ops.edge.aten.add.Tensor, From c640c4fd09136b316e275ae3d6b9b7ecf072e8d9 Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:20 -0700 Subject: [PATCH 3/4] [ET-VK][ez] enabling fp64->fp32 converison for vulkan compatibility Pull Request resolved: https://github.com/pytorch/executorch/pull/12201 # Context We need this conversion so that certain operators can handle floating point values that need to be 64bit. This is predominantly applicable to choose_qparams.tensor where it expects a 64bit output. # Changes Simply adding an additional conversion for float64 to vulkan fp32. ghstack-source-id: 295972781 @exported-using-ghexport Differential Revision: [D77746137](https://our.internmc.facebook.com/intern/diff/D77746137/) --- backends/vulkan/runtime/VulkanBackend.cpp | 4 ++++ backends/vulkan/serialization/schema.fbs | 2 ++ .../serialization/vulkan_graph_builder.py | 20 +++++++++++++++---- .../serialization/vulkan_graph_schema.py | 2 ++ backends/vulkan/vulkan_preprocess.py | 9 +++++++-- 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7077a9df59c..28e7574537c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { return vkapi::kChar; case vkgraph::VkDataType::INT32: return vkapi::kInt; + case vkgraph::VkDataType::INT64: + return vkapi::kLong; case vkgraph::VkDataType::FLOAT16: return vkapi::kHalf; case vkgraph::VkDataType::FLOAT32: return vkapi::kFloat; + case vkgraph::VkDataType::FLOAT64: + return vkapi::kDouble; } } diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index f112581c498..99ba6a86594 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -18,6 +18,8 @@ enum VkDataType : byte { INT32 = 3, FLOAT16 = 4, FLOAT32 = 5, + FLOAT64 = 6, + INT64 = 7, } // Describes what kind of GPU resource should be used to represent a tensor. The diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 5bae0475c28..cd876bd6305 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -45,9 +45,11 @@ def __init__( self, program: ExportedProgram, delegate_mapping_builder: DelegateMappingBuilder, + downcast_64_bit: bool = True, ) -> None: self.program = program self.delegate_mapping_builder = delegate_mapping_builder + self.downcast_64_bit = downcast_64_bit self.chain = [] self.values = [] self.input_ids = [] @@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: return vk_graph_schema.VkDataType.INT8 elif torch_dtype == torch.int32: return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.int64: + return vk_graph_schema.VkDataType.INT64 elif torch_dtype == torch.float16: return vk_graph_schema.VkDataType.FLOAT16 elif torch_dtype == torch.float32: return vk_graph_schema.VkDataType.FLOAT32 - # Narrowing conversion for index tensor produced by max_poolNd_with_indices. - elif torch_dtype == torch.int64: - return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.float64: + return vk_graph_schema.VkDataType.FLOAT64 else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") @@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # pyre-ignore[16] memory_layout = spec.vk_memory_layout + # Apply downcast logic before getting VK datatype + effective_dtype = spec.dtype + if self.downcast_64_bit and spec.dtype == torch.float64: + effective_dtype = torch.float32 + elif self.downcast_64_bit and spec.dtype == torch.int64: + effective_dtype = torch.int32 + + datatype = self.get_vk_datatype(effective_dtype) + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( value=vk_graph_schema.VkTensor( - datatype=self.get_vk_datatype(spec.dtype), + datatype=datatype, dims=spec.shape, constant_id=constant_id, mem_obj_id=mem_obj_id, diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 35113bc623a..f845e5601a7 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -29,6 +29,8 @@ class VkDataType(IntEnum): INT32 = 3 FLOAT16 = 4 FLOAT32 = 5 + FLOAT64 = 6 + INT64 = 7 class VkStorageType(IntEnum): diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index a22afc3f42e..a6d5737dbb8 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -67,7 +67,6 @@ # pyre-ignore def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: for p in passes: - if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): new_gm = program.graph_module # This is a workaround to allow the memory planning pass to work without @@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: if spec.key == "skip_tag_memory_metadata": options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + if spec.key == "downcast_64_bit": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options @@ -142,6 +144,7 @@ def preprocess( # noqa: C901 default_memory_layout = compile_options.get( "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED ) + downcast_64_bit = compile_options.get("downcast_64_bit", True) program = unsafe_remove_auto_functionalized_pass(program) @@ -213,7 +216,9 @@ def preprocess( # noqa: C901 ) graph_builder = VkGraphBuilder( - program, DelegateMappingBuilder(generated_identifiers=True) + program, + DelegateMappingBuilder(generated_identifiers=True), + downcast_64_bit=downcast_64_bit, ) vk_graph = graph_builder.build_graph() From 97f4606dc0071ded1f49ee58b0a7ec3d5168b84e Mon Sep 17 00:00:00 2001 From: morelos Date: Sun, 13 Jul 2025 21:36:24 -0700 Subject: [PATCH 4/4] [ET] correcting cpu ref quantize_per_channel logic to align with ATen Pull Request resolved: https://github.com/pytorch/executorch/pull/12203 # Context The quantize_per_channel was not perfectly aligned with the ATen implementation, and demonstrated errors when specifying different axis. This bug wasn't distinctly acknowledged given that the test cases only has one test for the whole operator. In order to align more closely with ATen this change simply does a single loop imlpementation with direct channel index calculation over the old `apply_over_dim_list` approach. # Changes We change the core logic for quantize_per_channel to more properly align with ATen's implementation, and we also change it from `apply_over_dim_list` approach to a single loop implementation with direct channel index calculation. This also adds more comprehensive testing for quantize_per_channel so that a bug isn't missed again. ghstack-source-id: 295972782 @exported-using-ghexport Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/) --- kernels/quantized/cpu/op_quantize.cpp | 68 ++---- kernels/quantized/cpu/targets.bzl | 6 - kernels/quantized/test/op_quantize_test.cpp | 240 ++++++++++++++++++++ 3 files changed, 263 insertions(+), 51 deletions(-) diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index d0b7c882f8e..5586f8a77eb 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -6,7 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -#include #include #include #include @@ -282,55 +281,34 @@ Tensor& quantize_per_channel_out( check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - // a list contains all dimensions except axis - int64_t dims[kTensorDimensionLimit]; - for (int64_t i = 0; i < input.dim() - 1; i++) { - if (i < axis) { - dims[i] = i; - } else { - dims[i] = i - 1; - } - } const double* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data = zero_point.const_data_ptr(); - std::optional> optional_dim_list{ - executorch::aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. - // in other words you are quantizing in_data[in_ix] + // High-performance single loop with direct channel calculation #define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \ - double _scale = scale_data[channel_ix]; \ - int64_t _zero_point = zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ + case ScalarType::out_dtype: { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const int64_t input_numel = input.numel(); \ + const int64_t axis_size = input.size(axis); \ + /* Calculate the stride pattern for efficient channel index calculation */ \ + int64_t axis_block_size = 1; \ + for (int64_t i = axis + 1; i < input.dim(); i++) { \ + axis_block_size *= input.size(i); \ } \ - break; + /* Single loop over all elements */ \ + for (int64_t i = 0; i < input_numel; i++) { \ + /* Calculate which channel this element belongs to */ \ + int64_t channel_idx = (i / axis_block_size) % axis_size; \ + /* Get quantization parameters for this channel */ \ + double _scale = scale_data[channel_idx]; \ + int64_t _zero_point = zero_point_data[channel_idx]; \ + /* Apply quantization */ \ + out_data_ptr[i] = quantize_val( \ + _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \ + } \ + } break; + #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ case ScalarType::in_dtype: \ switch (out.scalar_type()) { \ diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl index 3ba9715506a..f29f1f013b7 100644 --- a/kernels/quantized/cpu/targets.bzl +++ b/kernels/quantized/cpu/targets.bzl @@ -51,12 +51,6 @@ _QUANT_OPS = ( ), op_target( name = "op_quantize", - deps = [ - "//executorch/kernels/portable/cpu/util:reduce_util", - ], - _aten_mode_deps = [ - "//executorch/kernels/portable/cpu/util:reduce_util_aten", - ], ), ) diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 5cd17223d80..4ac835c24ce 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) { EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 2}, 4); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {100, 50, 25}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 2}); + // Channel 0: 4 / 0.5 + 100 = 108 + // Channel 1: 4 / 1.0 + 50 = 54 + // Channel 2: 4 / 2.0 + 25 = 27 + Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27}); + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel3D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 3D tensor with axis=1 (middle dimension) + Tensor input = tf_float.full({2, 3, 4}, 6); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3, 4}); + // Channel 0: 6 / 0.5 + 10 = 22 + // Channel 1: 6 / 1.0 + 20 = 26 + // Channel 2: 6 / 1.5 + 30 = 34 + Tensor expected = tfo.make( + {2, 3, 4}, + { + 22, 22, 22, 22, // First batch, channel 0 + 26, 26, 26, 26, // First batch, channel 1 + 34, 34, 34, 34, // First batch, channel 2 + 22, 22, 22, 22, // Second batch, channel 0 + 26, 26, 26, 26, // Second batch, channel 1 + 34, 34, 34, 34 // Second batch, channel 2 + }); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel4D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W) + Tensor input = tf_float.full({2, 2, 3, 2}, 8); + Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2, 3, 2}); + // Channel 0: 8 / 0.25 + 0 = 32 + // Channel 1: 8 / 0.5 + 10 = 26 + // Channel 2: 8 / 1.0 + 20 = 28 + std::vector expected_data; + for (int n = 0; n < 2; n++) { + for (int c = 0; c < 2; c++) { + for (int h = 0; h < 3; h++) { + for (int w = 0; w < 2; w++) { + int8_t val = (h == 0) ? 32 : (h == 1) ? 26 : 28; + expected_data.push_back(val); + } + } + } + } + Tensor expected = tfo.make({2, 2, 3, 2}, expected_data); + quantize_per_channel_out( + input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 3}, 5); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Using axis=-1 should be equivalent to axis=1 for 2D tensor + // Channel 0: 5 / 0.5 + 0 = 10 + // Channel 1: 5 / 1.0 + 10 = 15 + // Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5) + Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22}); + quantize_per_channel_out( + input, + scale, + zero_point, + -1, + quant_min, + quant_max, + ScalarType::Byte, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 1, 4}, 7); + Tensor scale = tf_double.make({1}, {0.5}); + Tensor zero_point = tf_long.make({1}, {128}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 1, 4}); + // Single channel: 7 / 0.5 + 128 = 142 + Tensor expected = tfo.full({3, 1, 4}, 142); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) { + TensorFactory tf_double_input; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_double_input.full({2, 2}, 3.14159); + Tensor scale = tf_double.make({2}, {0.01, 0.02}); + Tensor zero_point = tf_long.make({2}, {0, 100}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127 + // Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127 + Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 2}, 10); + Tensor scale = tf_double.make({2}, {1.0, 2.0}); + Tensor zero_point = tf_long.make({2}, {1000, 2000}); + int64_t quant_min = -32768; + int64_t quant_max = 32767; + + // Test with 16-bit output + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 10 / 1.0 + 1000 = 1010 + // Channel 1: 10 / 2.0 + 2000 = 2005 + Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005}); + quantize_per_channel_out( + input, + scale, + zero_point, + 1, + quant_min, + quant_max, + ScalarType::Short, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test with different input values per position + Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32] + // Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34] + Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test values that will exceed quant_min/quant_max bounds + Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0}); + Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 0, 0}); + int64_t quant_min = -10; + int64_t quant_max = 10; + + TensorFactory tfo; + Tensor out = tfo.zeros({1, 3}); + // Values: [-100, 0, 100] should be clamped to [-10, 0, 10] + Tensor expected = tfo.make({1, 3}, {-10, 0, 10}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +}