diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9a63d178e2d..1f77b30cda3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -221,13 +221,6 @@ def update_features_impl(op: OpKey): @update_features( [ operator.getitem, - # Quantization related ops will be fused via graph passes - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, @@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_token.default, + exir_ops.edge.quantized_decomposed.dequantize_per_token.default, + exir_ops.edge.quantized_decomposed.choose_qparams.tensor, + exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, + ] +) +def register_quantization_op(features: OpFeatures): + # Quantization requires buffer storage and width packing for scales/zero_points + # but we need to provide texture impl features for the partitioner to work properly + features.texture_impl = TextureImplFeatures( + uses_axis_map=True, + valid_packed_dims={ + PackedDim.WIDTH, + }, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.BUFFER + return features + + @update_features( [ exir_ops.edge.aten.add.Tensor, diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7077a9df59c..28e7574537c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { return vkapi::kChar; case vkgraph::VkDataType::INT32: return vkapi::kInt; + case vkgraph::VkDataType::INT64: + return vkapi::kLong; case vkgraph::VkDataType::FLOAT16: return vkapi::kHalf; case vkgraph::VkDataType::FLOAT32: return vkapi::kFloat; + case vkgraph::VkDataType::FLOAT64: + return vkapi::kDouble; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index 66620e9b174..d6d27d2e3a3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,15 +9,13 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// equivalent of the eps defined in the cpu implementation -#define SMALL_SCALE_THRESHOLD 6.1e-5 - // Calculate scale and zero point from min and max values void calculate_scale_and_zero_point( float min_val, float max_val, int qmin, int qmax, + float eps_threshold, out float scale_val, out int zero_point_val) { // ensure we have zero included in our range @@ -31,18 +29,18 @@ void 
calculate_scale_and_zero_point( scale_val = 0.1; } - // Cut off small scale - if (scale_val < SMALL_SCALE_THRESHOLD) { + // Cut off small scale using the provided eps threshold + if (scale_val < eps_threshold) { float org_scale = scale_val; - scale_val = SMALL_SCALE_THRESHOLD; + scale_val = eps_threshold; // Adjust min and max based on new scale if (min_val == 0.0) { - max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + max_val = eps_threshold * float(qmax - qmin); } else if (max_val == 0.0) { - min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + min_val = -eps_threshold * float(qmax - qmin); } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + float amplifier = eps_threshold / org_scale; min_val *= amplifier; max_val *= amplifier; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index dcbfe493f34..48681a46c30 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -29,6 +29,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -175,7 +176,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; @@ -260,7 +261,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); t_scale[token_id] = scale_val; t_zero_point[token_id] = zero_point_val; diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 282f1de170a..5076b2d68e9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -30,6 +30,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -234,7 +235,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); @@ -372,7 +373,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl 
b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl index ea0c2f7dce7..c3e58286efe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -42,6 +42,16 @@ $if MODE == "per_token": int quant_min; int quant_max; }; +$if MODE == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int axis; + int num_channels; + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -137,7 +147,7 @@ void quantize_per_tensor() { t_out[out_bufi] = qvalue; } -#else +#elif defined(per_token) void quantize_per_token() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -172,6 +182,45 @@ void quantize_per_token() { t_out[out_bufi] = qvalue; } +#else // per_channel + +void quantize_per_channel() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + // Calculate channel index based on the quantization axis (already converted to WHCN) + // The axis parameter is now in WHCN coordinate system: + // axis 0 -> W dimension (tidx.x) + // axis 1 -> H dimension (tidx.y) + // axis 2 -> C dimension (tidx.z) + // axis 3 -> N dimension (tidx.w) + int channel_idx = 0; + + if (axis == 0) { + channel_idx = out_tidx.x; + } else if (axis == 1) { + channel_idx = out_tidx.y; + } else if (axis == 2) { + channel_idx = out_tidx.z; + } else if (axis == 3) { + channel_idx = out_tidx.w; + } + + channel_idx = min(channel_idx, num_channels - 1); + + OUT_T qvalue = quantize_val(value, t_scale[channel_idx], t_zero_point[channel_idx]); + + t_out[out_bufi] = qvalue; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 4d95d610314..1dd8e6e2ffe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -17,3 +17,5 @@ quantize_buffer: MODE: per_tensor - NAME: quantize_per_token_buffer MODE: per_token + - NAME: quantize_per_channel_buffer + MODE: per_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl index 9ba7074f75b..bdaba3ffaf9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -26,6 +26,8 @@ ${define_required_extensions(OUT_DTYPE)} layout(std430) buffer; +#include "indexing_utils.h" + ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} @@ -45,11 +47,23 @@ $if MODE == "per_token": int quant_min; int quant_max; }; +$if MODE == "per_channel": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int axis; + int num_channels; + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} -#include "indexing_utils.h" +${layout_declare_spec_const(C, "int", 
"out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + #include "quantize.glslh" layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -138,7 +152,7 @@ void quantize_per_tensor() { write_texel(t_out, pos, outtex); } -#else +#elif defined(per_token) void quantize_per_token() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -177,6 +191,84 @@ void quantize_per_token() { write_texel(t_out, pos, outtex); } +#else // per_channel + +void quantize_per_channel() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + // Calculate channel index based on the quantization axis (already converted to WHCN) + // The axis parameter is now in WHCN coordinate system: + // axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component) + // axis 1 -> H dimension (pos.y) + // axis 2 -> C dimension (pos.z / C), but for 4D tensors this includes batch-channel folding + // axis 3 -> N dimension (pos.z / N), but for 4D tensors this includes batch-channel folding + + if (axis == 0) { + // Width dimension - each texel component has different channel index + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + int channel_idx = pos.x * 4 + i; + channel_idx = min(channel_idx, num_channels - 1); + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } else if (axis == 1) { + // Height dimension - all texel components use same channel index + int channel_idx = pos.y; + channel_idx = min(channel_idx, num_channels - 1); + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } else if (axis == 2) { + // Channel dimension - for 4D tensors, need to account for batch-channel folding + // The Z coordinate contains folded batch*channel information + // We need to extract the actual channel index from the folded dimension + int folded_idx = pos.z; + int channel_idx = folded_idx % num_channels; + + float scale_val = t_scale[channel_idx]; + int zero_point_val = t_zero_point[channel_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } else if (axis == 3) { + // Batch dimension - for 4D tensors, need to account for batch-channel folding + // The Z coordinate contains folded batch*channel information + // We need to extract the actual batch index from the folded dimension + int folded_idx = pos.z; + int batch_idx = folded_idx / num_channels; + + float scale_val = t_scale[batch_idx]; + int zero_point_val = t_zero_point[batch_idx]; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 65002ce26b6..47e532be8b9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ 
b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -17,3 +17,5 @@ quantize_texture: MODE: per_tensor - NAME: quantize_per_token_texture3d MODE: per_token + - NAME: quantize_per_channel_texture3d + MODE: per_channel diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 1dc2d34afbf..5e9599b91e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -150,6 +150,7 @@ void add_choose_qparams_tensor_node( const ValueRef& input, const ValueRef& quant_min, const ValueRef& quant_max, + const ValueRef& eps, const ValueRef& scale_out, const ValueRef& zero_point_out) { std::string kernel_name("choose_qparams_tensor"); @@ -158,6 +159,7 @@ void add_choose_qparams_tensor_node( int quant_min_val = static_cast(graph.get_int(quant_min)); int quant_max_val = static_cast(graph.get_int(quant_max)); + float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; @@ -180,6 +182,7 @@ void add_choose_qparams_tensor_node( push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), + PushConstantDataInfo(&eps_val, sizeof(float)), }; graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -275,8 +278,22 @@ void choose_qparams_tensor_impl( const ValueRef input = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused dtype parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -289,13 +306,10 @@ void choose_qparams_tensor_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -303,7 +317,7 @@ void choose_qparams_tensor_impl( } add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, scale_out, zero_point_out); + graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); } void choose_qparams_per_token_asymmetric_impl( @@ -311,8 +325,21 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef dtype = + 
args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -325,22 +352,20 @@ void choose_qparams_per_token_asymmetric_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); } REGISTER_OPERATORS { - VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); VK_REGISTER_OP( - choose_qparams_per_token_asymmetric.default, + quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.choose_qparams_per_token_asymmetric.default, choose_qparams_per_token_asymmetric_impl); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 77a51ce24f9..3838da9a151 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -180,8 +180,15 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -212,8 +219,15 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -257,18 +271,34 @@ void dequantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in 
scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_dequantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); - VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.default, + dequantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_token.default, + dequantize_per_token_impl); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 49277b4d718..74dee249b0a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -12,11 +12,10 @@ #include #include -#include -namespace vkcompute { +#include -namespace { +namespace vkcompute { void resize_quantize_output( ComputeGraph* graph, @@ -28,7 +27,52 @@ void resize_quantize_output( graph->virtual_resize(out, graph->sizes_of(in)); } -} // namespace +utils::uvec3 quantize_per_channel_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + + return global_wg_size; +} + +utils::uvec3 quantize_per_channel_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)args; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. For per-channel quantization along the batch axis, + // we need to ensure that we dispatch the correct number of workgroups in the + // Z dimension to cover all batch-channel combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. 
+ const auto input_sizes = graph->sizes_of(input); + if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + local_wg_size[2] = 1; + } + + return local_wg_size; +} void add_quantize_per_tensor_node( ComputeGraph& graph, @@ -171,6 +215,99 @@ void add_quantize_per_token_node( resize_quantize_output)); } +void add_quantize_per_channel_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& axis, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_channel"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int axis_val = static_cast(graph.get_int(axis)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + // Normalize axis and convert from NCHW to WHCN using utility functions + const auto input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + // Normalize axis to handle negative indices + axis_val = normalize(axis_val, ndim); + + // Convert from NCHW axis to WHCN axis for shader (vulkan representation) + int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); + + int num_channels; + if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { + // For batch dimension quantization in 4D tensors, pass the actual number of + // channels so the shader can correctly unfold the batch-channel folding + num_channels = static_cast(input_sizes[1]); // Channel dimension + } else { + num_channels = static_cast(input_sizes[axis_val]); + } + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + quantize_per_channel_global_wg_size, + quantize_per_channel_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + void quantize_per_tensor_impl( ComputeGraph& graph, const std::vector& args) { @@ -180,8 +317,12 @@ void quantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = 
args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -205,8 +346,12 @@ void quantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -243,18 +388,114 @@ void quantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_quantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } +void quantize_per_channel_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef axis = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output = args[arg_idx++]; + + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Normalize axis + int axis_val = static_cast(graph.get_int(axis)); + const auto input_sizes = graph.sizes_of(input); + int64_t ndim = graph.dim_of(input); + if (axis_val < 0) { + axis_val += ndim; + } + + // Verify axis is valid + VK_CHECK_COND(axis_val >= 0 && 
axis_val < ndim); + + // Get number of channels along the specified axis + int64_t num_channels = input_sizes[axis_val]; + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_channels + VK_CHECK_COND(scale_numel == num_channels); + VK_CHECK_COND(zero_point_numel == num_channels); + + add_quantize_per_channel_node( + graph, input, scale, zero_point, axis, quant_min, quant_max, output); +} + REGISTER_OPERATORS { - VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); - VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_tensor.default, + quantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_channel.default, + quantize_per_channel_impl); } } // namespace vkcompute diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index f112581c498..99ba6a86594 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -18,6 +18,8 @@ enum VkDataType : byte { INT32 = 3, FLOAT16 = 4, FLOAT32 = 5, + FLOAT64 = 6, + INT64 = 7, } // Describes what kind of GPU resource should be used to represent a tensor. The diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 5bae0475c28..cd876bd6305 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -45,9 +45,11 @@ def __init__( self, program: ExportedProgram, delegate_mapping_builder: DelegateMappingBuilder, + downcast_64_bit: bool = True, ) -> None: self.program = program self.delegate_mapping_builder = delegate_mapping_builder + self.downcast_64_bit = downcast_64_bit self.chain = [] self.values = [] self.input_ids = [] @@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: return vk_graph_schema.VkDataType.INT8 elif torch_dtype == torch.int32: return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.int64: + return vk_graph_schema.VkDataType.INT64 elif torch_dtype == torch.float16: return vk_graph_schema.VkDataType.FLOAT16 elif torch_dtype == torch.float32: return vk_graph_schema.VkDataType.FLOAT32 - # Narrowing conversion for index tensor produced by max_poolNd_with_indices. 
- elif torch_dtype == torch.int64: - return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.float64: + return vk_graph_schema.VkDataType.FLOAT64 else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") @@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # pyre-ignore[16] memory_layout = spec.vk_memory_layout + # Apply downcast logic before getting VK datatype + effective_dtype = spec.dtype + if self.downcast_64_bit and spec.dtype == torch.float64: + effective_dtype = torch.float32 + elif self.downcast_64_bit and spec.dtype == torch.int64: + effective_dtype = torch.int32 + + datatype = self.get_vk_datatype(effective_dtype) + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( value=vk_graph_schema.VkTensor( - datatype=self.get_vk_datatype(spec.dtype), + datatype=datatype, dims=spec.shape, constant_id=constant_id, mem_obj_id=mem_obj_id, diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 35113bc623a..f845e5601a7 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -29,6 +29,8 @@ class VkDataType(IntEnum): INT32 = 3 FLOAT16 = 4 FLOAT32 = 5 + FLOAT64 = 6 + INT64 = 7 class VkStorageType(IntEnum): diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index 55e96151387..75b7cbc8960 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -433,14 +433,23 @@ void test_vulkan_choose_qparams_tensor_impl( const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams.tensor") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add eps and dtype parameters to match ATen signature + const ValueRef r_eps = graph.add_scalar(6.1e-5); + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") (graph, { r_input.value, r_quant_min, r_quant_max, - r_scale, - r_zero_point, + r_eps, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); @@ -647,12 +656,20 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl( const ValueRef r_zero_point = graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add dtype parameter to match ATen signature + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN( + "quantized_decomposed.choose_qparams_per_token_asymmetric.default") (graph, { r_input.value, - r_scale, - r_zero_point, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 6c604076c41..f32a93e2b6a 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -49,6 +49,17 @@ Tensor& dequantize_per_token_out( ScalarType out_dtype, Tensor& out); +Tensor& dequantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const 
std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + // Wrapper function for dequantize_per_tensor_out without context Tensor& dequantize_per_tensor_out_no_context( const Tensor& input, @@ -77,6 +88,29 @@ Tensor& dequantize_per_token_out_no_context( input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); } +// Wrapper function for dequantize_per_channel_out without context +Tensor& dequantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_channel_out( + input, + scale, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + // ATen wrapper for dequantize_per_tensor at::Tensor dequantize_per_tensor_aten( const at::Tensor& input, @@ -131,6 +165,36 @@ at::Tensor dequantize_per_token_aten( return out; } +// ATen wrapper for dequantize_per_channel +at::Tensor dequantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) + (input, + scale, + zero_points, + axis, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -183,6 +247,40 @@ void check_dequantize_args( } } +/** + * Helper function to validate dequantize_per_channel arguments + * Similar to the validation in quantize_test.cpp + */ +void check_dequantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -322,6 +420,120 @@ at::Tensor dequantize_per_token_reference_impl( return out; } +/* + * Reference implementation of dequantize_per_channel + */ +at::Tensor dequantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType 
dtype, + at::ScalarType out_dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, out_dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = nullptr; + if (zero_point.has_value()) { + zero_point_data = zero_point.value().const_data_ptr(); + } + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = 0; + if (zero_point_data != nullptr) { + channel_zero_point = zero_point_data[channel_idx]; + } + + // Store casted values to avoid repeated casting + const int32_t channel_zero_point_int32 = + static_cast(channel_zero_point); + const float channel_scale_float = static_cast(channel_scale); + + // Get the input value and dequantize + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store the result based on output dtype + if (out_dtype == at::kFloat) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + output.flatten()[flat_idx] = dequantized_value; + } else if (out_dtype == at::kHalf) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } + } + + return output; +} + // Forward 
declaration of implementation functions void test_vulkan_dequantize_per_tensor_impl( const std::vector& input_sizes, @@ -585,7 +797,10 @@ void test_vulkan_dequantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.default") (graph, { r_input.value, @@ -593,6 +808,8 @@ void test_vulkan_dequantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); @@ -620,7 +837,8 @@ void test_vulkan_dequantize_per_tensor_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1046,7 +1264,10 @@ void test_vulkan_dequantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") (graph, { r_input.value, @@ -1054,6 +1275,8 @@ void test_vulkan_dequantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); @@ -1095,7 +1318,8 @@ void test_vulkan_dequantize_per_token_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1339,3 +1563,191 @@ TEST( at::kChar, // input dtype at::kDouble); // output dtype } + +void test_reference_dequantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create input tensor with quantized values + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + 
flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = dequantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(my_ref, cpu_ref); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_dequantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + std::numeric_limits::min(), // 
quant_min + std::numeric_limits::max(), // quant_max + at::kInt, + at::kFloat); +} diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 150bda6989e..ebb12bc1b3a 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -48,6 +48,16 @@ Tensor& quantize_per_token_out( ScalarType dtype, Tensor& out); +Tensor& quantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + // Wrapper function for quantize_per_tensor_out without context Tensor& quantize_per_tensor_out_no_context( const Tensor& input, @@ -74,6 +84,20 @@ Tensor& quantize_per_token_out_no_context( input, scale, zero_point, quant_min, quant_max, dtype, out); } +// Wrapper function for quantize_per_channel_out without context +Tensor& quantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_channel_out( + input, scale, zero_point, axis, quant_min, quant_max, dtype, out); +} + // ATen wrapper for quantize_per_tensor at::Tensor quantize_per_tensor_aten( const at::Tensor& input, @@ -106,6 +130,23 @@ at::Tensor quantize_per_token_aten( return out; } +// ATen wrapper for quantize_per_channel +at::Tensor quantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7) + (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -160,6 +201,40 @@ void check_quantize_args( quant_max); } +/** + * Helper function to validate quantize_per_channel arguments + * Similar to the validation in op_quantize.cpp + */ +void check_quantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -271,6 +346,110 @@ at::Tensor quantize_per_token_reference_impl( return out; } +/* + * Reference implementation of quantize_per_channel + */ +at::Tensor quantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& 
zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const float* input_data = input.const_data_ptr(); + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = zero_point.const_data_ptr(); + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = zero_point_data[channel_idx]; + + // Get the input value + float input_value = input_data[flat_idx]; + + // Apply quantization formula: round(input / scale) + zero_point + float inv_scale = 1.0f / static_cast(channel_scale); + int64_t quantized_value = static_cast( + static_cast(channel_zero_point) + + std::nearbyint(static_cast(inv_scale * input_value))); + + // Clamp to quantization bounds + quantized_value = std::max(quantized_value, quant_min); + quantized_value = std::min(quantized_value, quant_max); + + // Store the result based on output dtype + switch (dtype) { + case at::kByte: { + uint8_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kChar: { + int8_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kShort: { + int16_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + case at::kInt: { + int32_t* output_data = output.mutable_data_ptr(); + output_data[flat_idx] = static_cast(quantized_value); + break; + } + default: + assert(false && "Unsupported output dtype"); + } + } + + return output; +} + // Forward declaration of implementation functions void test_vulkan_quantize_per_tensor_impl( const std::vector& input_sizes, @@ -294,6 +473,18 @@ void test_vulkan_quantize_per_token_impl( const vkcompute::utils::StorageType in_storage, const vkcompute::utils::StorageType out_storage); +void test_vulkan_quantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const 
vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + // Wrapper function to test both buffer and texture storage types void test_vulkan_quantize_per_tensor( const std::vector& input_sizes, @@ -374,6 +565,48 @@ void test_vulkan_quantize_per_token( vkcompute::utils::kTexture3D); } +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + + test_vulkan_quantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + void test_reference_quantize_per_tensor( const std::vector& input_sizes, float scale, @@ -476,7 +709,10 @@ void test_vulkan_quantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.default") (graph, { r_input.value, @@ -484,6 +720,7 @@ void test_vulkan_quantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, r_out, }); @@ -509,7 +746,10 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::allclose(reference_int, vk_int); + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -835,7 +1075,10 @@ void test_vulkan_quantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") (graph, { r_input.value, @@ -843,6 +1086,7 @@ void test_vulkan_quantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, r_out, }); @@ -881,7 +1125,10 @@ void test_vulkan_quantize_per_token_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::allclose(reference_int, vk_int); + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -916,7 +1163,7 @@ void test_vulkan_quantize_per_token_impl( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_float_to_int8) { std::vector scales = {0.1, 
0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -932,7 +1179,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_float_to_int32) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -948,7 +1195,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_half_to_int32) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -964,7 +1211,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_reference_quantize_per_token_half_to_uint8) { std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; std::vector zero_points = {1, 2, 3, 0, -1, -2}; @@ -980,7 +1227,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_uint8) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1001,9 +1248,7 @@ TEST( at::kByte); } -TEST( - VulkanQuantizePerTensorTest, - test_vulkan_quantize_per_token_float_to_int8) { +TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() ->has_full_int8_buffers_support()) { @@ -1024,7 +1269,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int32) { std::vector scales = { -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; @@ -1041,7 +1286,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int32_small_scales) { std::vector scales = { 0, @@ -1062,7 +1307,7 @@ TEST( } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1087,7 +1332,7 @@ TEST( at::kByte); } -TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { +TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() ->has_full_float16_buffers_support()) { @@ -1107,7 +1352,7 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { } TEST( - VulkanQuantizePerTensorTest, + VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_double_to_int8) { if (!vkcompute::api::context() ->adapter_ptr() @@ -1126,3 +1371,760 @@ TEST( at::kDouble, // input dtype at::kChar); // output dtype } + +void test_reference_quantize_per_channel( + const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? 
eps : s; + } + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = quantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + // Get direct ATen implementation output + c10::ScalarType aten_dtype = dtype; + if (dtype == at::kChar) { + aten_dtype = c10::kQInt8; + } else if (dtype == at::kByte) { + aten_dtype = c10::kQUInt8; + } + + // Normalize axis for ATen (it doesn't handle negative values) + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + at::Tensor aten_ref = at::quantize_per_channel( + input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor my_ref_int = my_ref.to(at::kInt); + at::Tensor cpu_ref_int = cpu_ref.to(at::kInt); + // For quantized tensors, we need to use int_repr() to get the underlying + // integer values + at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt); + + const bool output_correct = at::equal(my_ref_int, cpu_ref_int); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "aten_ref:" << std::endl; + std::cout << aten_ref_int << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref_int << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? 
eps : s; + } + + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::quantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_axis = graph.add_scalar(axis); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_channel.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_axis, + r_quant_min, + r_quant_max, + r_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + // Tolerance is 1 to address rounding errors and fp math differences between + // CPU/GPU + const bool output_correct = + at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + 
std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_quantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_reference_quantize_per_channel_float_to_int8_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_quantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +// END OF REFERENCE TESTS + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis0) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(9, 0.1f); + std::vector zero_points(9, 2); + + // 1D Tensor + test_vulkan_quantize_per_channel( + {9}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 2D Tensor + test_vulkan_quantize_per_channel( + {9, 14}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 3D Tensor + test_vulkan_quantize_per_channel( + {9, 7, 11}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 17, 5, 5}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 17, 5, 9}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis1) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(14, 0.001f); + 
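// Note: axis = 1 (or its negative alias) has size 14 in every shape exercised below, so +   // check_quantize_per_channel_args requires exactly 14 scales and zero_points here. + 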
std::vector zero_points(14, -5); + + // 2D Tensor + test_vulkan_quantize_per_channel( + {9, 14}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 3D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 5}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {9, 7, 14, 5}, // input sizes + scales, + zero_points, + -2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis2) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(11, 0.5f); + std::vector zero_points(11, 12); + + // 3D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {9, 11, 14, 5}, // input sizes + scales, + zero_points, + -3, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_int8_axis3) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(7, 0.5f); + std::vector zero_points(7, 12); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 7}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {7, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; + std::vector zero_points = {0, 5, -5, 1, 12}; + + // 4D Tensor + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, 
+ zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_half_to_8bit) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; + std::vector zero_points = {0, 5, 5, 1, 12}; + + // 4D Tensor + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kHalf, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kHalf, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kHalf, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerChannelTest, + test_vulkan_quantize_per_channel_double_to_8bit) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; + std::vector zero_points = {0, 5, 5, 1, 12}; + + // 4D Tensor + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + 0, // axis + -128, // quant_min + 127, // quant_max + at::kDouble, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 5, 11, 7}, // input sizes + scales, + zero_points, + 1, // axis + -128, // quant_min + 127, // quant_max + at::kDouble, + at::kChar); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 5, 7}, // input sizes + scales, + zero_points, + 2, // axis + 0, // quant_min + 255, // quant_max + at::kDouble, + at::kByte); + + // 4D Tensor + test_vulkan_quantize_per_channel( + {9, 14, 11, 5}, // input sizes + scales, + zero_points, + 3, // axis + -128, // quant_min + 127, // quant_max + at::kDouble, + at::kChar); + + // 4D Tensor (negative axis) + test_vulkan_quantize_per_channel( + {5, 14, 11, 7}, // input sizes + scales, + zero_points, + -4, // axis + 0, // quant_min + 255, // quant_max + at::kDouble, + at::kByte); +} diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index a22afc3f42e..a6d5737dbb8 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -67,7 +67,6 @@ # pyre-ignore def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: for p in passes: - if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): new_gm = program.graph_module # This is a workaround to allow the memory planning pass to work without @@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: if spec.key == "skip_tag_memory_metadata": options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + if spec.key == "downcast_64_bit": + options[spec.key] = 
bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options @@ -142,6 +144,7 @@ def preprocess( # noqa: C901 default_memory_layout = compile_options.get( "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED ) + downcast_64_bit = compile_options.get("downcast_64_bit", True) program = unsafe_remove_auto_functionalized_pass(program) @@ -213,7 +216,9 @@ def preprocess( # noqa: C901 ) graph_builder = VkGraphBuilder( - program, DelegateMappingBuilder(generated_identifiers=True) + program, + DelegateMappingBuilder(generated_identifiers=True), + downcast_64_bit=downcast_64_bit, ) vk_graph = graph_builder.build_graph() diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index cb7b36a5fc1..104531f0fbb 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -155,19 +155,39 @@ struct type_convert< }; // Optionals: ATen to ETen. -template -struct type_convert, torch::executor::optional> final { +template +struct type_convert< + AOptional, + EOptional, + std::enable_if_t< + std::is_same_v< + typename remove_const_ref::type, + std::optional< + typename remove_const_ref::type::value_type>> && + std::is_same_v< + typename remove_const_ref::type, + torch::executor::optional< + typename remove_const_ref::type::value_type>>>> + final { public: - std::optional val; - std::unique_ptr> convert_struct; - explicit type_convert(std::optional value) : val(value) {} - torch::executor::optional call() { + typename remove_const_ref::type val; + std::unique_ptr::type::value_type, + typename remove_const_ref::type::value_type>> + convert_struct; + explicit type_convert(AOptional value) : val(value) {} + typename remove_const_ref::type call() { if (val.has_value()) { - convert_struct = std::make_unique>( - type_convert(val.value())); - return torch::executor::optional(convert_struct->call()); + convert_struct = std::make_unique::type::value_type, + typename remove_const_ref::type::value_type>>( + type_convert< + typename remove_const_ref::type::value_type, + typename remove_const_ref::type::value_type>( + val.value())); + return typename remove_const_ref::type(convert_struct->call()); } else { - return torch::executor::optional(); + return typename remove_const_ref::type(); } } }; diff --git a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp index 17d0f7a4d63..a5b53096ae2 100644 --- a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp +++ b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp @@ -421,3 +421,92 @@ TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) { EXPECT_EQ(stack.size(), 1); EXPECT_EQ(stack[0].toTensor().const_data_ptr()[0], 4); } + +TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ConstRefOptionals) { + // Test const optional scalar conversion + const std::optional const_optional_at_in = + std::optional(42); + auto const_optional_et = + type_convert< + const std::optional, + torch::executor::optional>(const_optional_at_in) + .call(); + EXPECT_TRUE(const_optional_et.has_value()); + EXPECT_EQ(const_optional_et.value(), 42); + + // Test optional scalar reference conversion + std::optional optional_at_ref_in = std::optional(24); + auto optional_et_from_ref = + type_convert&, torch::executor::optional>( + optional_at_ref_in) + .call(); + 
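// These conversions all exercise the SFINAE-constrained type_convert specialization +   // added above: remove_const_ref strips const and reference qualifiers, so +   // const std::optional<T>, std::optional<T>&, and const std::optional<T>& each +   // convert to torch::executor::optional<T>. + 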
EXPECT_TRUE(optional_et_from_ref.has_value()); + EXPECT_EQ(optional_et_from_ref.value(), 24); + + // Test const optional scalar reference conversion + const std::optional const_optional_at_ref_in = + std::optional(84); + auto const_optional_et_from_ref = + type_convert< + const std::optional&, + torch::executor::optional>(const_optional_at_ref_in) + .call(); + EXPECT_TRUE(const_optional_et_from_ref.has_value()); + EXPECT_EQ(const_optional_et_from_ref.value(), 84); + + // Test const optional tensor conversion + const std::optional const_optional_tensor_at_in = + std::optional(torch::tensor({5})); + auto const_optional_tensor_converter = type_convert< + const std::optional, + torch::executor::optional>( + const_optional_tensor_at_in); + auto const_optional_tensor_et = const_optional_tensor_converter.call(); + EXPECT_TRUE(const_optional_tensor_et.has_value()); + EXPECT_EQ(const_optional_tensor_et.value().const_data_ptr()[0], 5); + + // Test optional tensor reference conversion + std::optional optional_tensor_at_ref_in = + std::optional(torch::tensor({7})); + auto optional_tensor_converter_from_ref = type_convert< + std::optional&, + torch::executor::optional>( + optional_tensor_at_ref_in); + auto optional_tensor_et_from_ref = optional_tensor_converter_from_ref.call(); + EXPECT_TRUE(optional_tensor_et_from_ref.has_value()); + EXPECT_EQ( + optional_tensor_et_from_ref.value().const_data_ptr()[0], 7); + + // Test const optional tensor reference conversion + const std::optional const_optional_tensor_at_ref_in = + std::optional(torch::tensor({9})); + auto const_optional_tensor_converter_from_ref = type_convert< + const std::optional&, + torch::executor::optional>( + const_optional_tensor_at_ref_in); + auto const_optional_tensor_et_from_ref = + const_optional_tensor_converter_from_ref.call(); + EXPECT_TRUE(const_optional_tensor_et_from_ref.has_value()); + EXPECT_EQ( + const_optional_tensor_et_from_ref.value().const_data_ptr()[0], + 9); + + // Test empty const optional conversions + const std::optional empty_const_optional_at_in = std::nullopt; + auto empty_const_optional_et = + type_convert< + const std::optional, + torch::executor::optional>(empty_const_optional_at_in) + .call(); + EXPECT_FALSE(empty_const_optional_et.has_value()); + + const std::optional empty_const_optional_tensor_at_in = + std::nullopt; + auto empty_const_optional_tensor_et = + type_convert< + const std::optional, + torch::executor::optional>( + empty_const_optional_tensor_at_in) + .call(); + EXPECT_FALSE(empty_const_optional_tensor_et.has_value()); +} diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index d0b7c882f8e..5586f8a77eb 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -6,7 +6,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include #include #include #include @@ -282,55 +281,34 @@ Tensor& quantize_per_channel_out( check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - // a list contains all dimensions except axis - int64_t dims[kTensorDimensionLimit]; - for (int64_t i = 0; i < input.dim() - 1; i++) { - if (i < axis) { - dims[i] = i; - } else { - dims[i] = i - 1; - } - } const double* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data = zero_point.const_data_ptr(); - std::optional> optional_dim_list{ - executorch::aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. - // in other words you are quantizing in_data[in_ix] + // High-performance single loop with direct channel calculation #define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \ - double _scale = scale_data[channel_ix]; \ - int64_t _zero_point = zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ + case ScalarType::out_dtype: { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const int64_t input_numel = input.numel(); \ + const int64_t axis_size = input.size(axis); \ + /* Calculate the stride pattern for efficient channel index calculation */ \ + int64_t axis_block_size = 1; \ + for (int64_t i = axis + 1; i < input.dim(); i++) { \ + axis_block_size *= input.size(i); \ } \ - break; + /* Single loop over all elements */ \ + for (int64_t i = 0; i < input_numel; i++) { \ + /* Calculate which channel this element belongs to */ \ + int64_t channel_idx = (i / axis_block_size) % axis_size; \ + /* Get quantization parameters for this channel */ \ + double _scale = scale_data[channel_idx]; \ + int64_t _zero_point = zero_point_data[channel_idx]; \ + /* Apply quantization */ \ + out_data_ptr[i] = quantize_val( \ + _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \ + } \ + } break; + #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ case ScalarType::in_dtype: \ switch (out.scalar_type()) { \ diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl index 3ba9715506a..f29f1f013b7 100644 --- a/kernels/quantized/cpu/targets.bzl +++ b/kernels/quantized/cpu/targets.bzl @@ -51,12 +51,6 @@ _QUANT_OPS = ( ), op_target( name = "op_quantize", - deps = [ - "//executorch/kernels/portable/cpu/util:reduce_util", - ], - _aten_mode_deps = [ - "//executorch/kernels/portable/cpu/util:reduce_util_aten", - ], ), ) diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 5cd17223d80..4ac835c24ce 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, 
QuantizePerChannel) { EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 2}, 4); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {100, 50, 25}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 2}); + // Channel 0: 4 / 0.5 + 100 = 108 + // Channel 1: 4 / 1.0 + 50 = 54 + // Channel 2: 4 / 2.0 + 25 = 27 + Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27}); + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel3D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 3D tensor with axis=1 (middle dimension) + Tensor input = tf_float.full({2, 3, 4}, 6); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3, 4}); + // Channel 0: 6 / 0.5 + 10 = 22 + // Channel 1: 6 / 1.0 + 20 = 26 + // Channel 2: 6 / 1.5 + 30 = 34 + Tensor expected = tfo.make( + {2, 3, 4}, + { + 22, 22, 22, 22, // First batch, channel 0 + 26, 26, 26, 26, // First batch, channel 1 + 34, 34, 34, 34, // First batch, channel 2 + 22, 22, 22, 22, // Second batch, channel 0 + 26, 26, 26, 26, // Second batch, channel 1 + 34, 34, 34, 34 // Second batch, channel 2 + }); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel4D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W) + Tensor input = tf_float.full({2, 2, 3, 2}, 8); + Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2, 3, 2}); + // Channel 0: 8 / 0.25 + 0 = 32 + // Channel 1: 8 / 0.5 + 10 = 26 + // Channel 2: 8 / 1.0 + 20 = 28 + std::vector expected_data; + for (int n = 0; n < 2; n++) { + for (int c = 0; c < 2; c++) { + for (int h = 0; h < 3; h++) { + for (int w = 0; w < 2; w++) { + int8_t val = (h == 0) ? 32 : (h == 1) ? 
26 : 28; + expected_data.push_back(val); + } + } + } + } + Tensor expected = tfo.make({2, 2, 3, 2}, expected_data); + quantize_per_channel_out( + input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 3}, 5); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Using axis=-1 should be equivalent to axis=1 for 2D tensor + // Channel 0: 5 / 0.5 + 0 = 10 + // Channel 1: 5 / 1.0 + 10 = 15 + // Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5) + Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22}); + quantize_per_channel_out( + input, + scale, + zero_point, + -1, + quant_min, + quant_max, + ScalarType::Byte, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 1, 4}, 7); + Tensor scale = tf_double.make({1}, {0.5}); + Tensor zero_point = tf_long.make({1}, {128}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 1, 4}); + // Single channel: 7 / 0.5 + 128 = 142 + Tensor expected = tfo.full({3, 1, 4}, 142); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) { + TensorFactory tf_double_input; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_double_input.full({2, 2}, 3.14159); + Tensor scale = tf_double.make({2}, {0.01, 0.02}); + Tensor zero_point = tf_long.make({2}, {0, 100}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127 + // Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127 + Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 2}, 10); + Tensor scale = tf_double.make({2}, {1.0, 2.0}); + Tensor zero_point = tf_long.make({2}, {1000, 2000}); + int64_t quant_min = -32768; + int64_t quant_max = 32767; + + // Test with 16-bit output + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 10 / 1.0 + 1000 = 1010 + // Channel 1: 10 / 2.0 + 2000 = 2005 + Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005}); + quantize_per_channel_out( + input, + scale, + zero_point, + 1, + quant_min, + quant_max, + ScalarType::Short, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test with different input values per position + Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, 
{10, 20, 30}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32] + // Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34] + Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test values that will exceed quant_min/quant_max bounds + Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0}); + Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 0, 0}); + int64_t quant_min = -10; + int64_t quant_max = 10; + + TensorFactory tfo; + Tensor out = tfo.zeros({1, 3}); + // Values: [-100, 0, 100] should be clamped to [-10, 0, 10] + Tensor expected = tfo.make({1, 3}, {-10, 0, 10}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +}