diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index 66620e9b174..d6d27d2e3a3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,15 +9,13 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// equivalent of the eps defined in the cpu implementation -#define SMALL_SCALE_THRESHOLD 6.1e-5 - // Calculate scale and zero point from min and max values void calculate_scale_and_zero_point( float min_val, float max_val, int qmin, int qmax, + float eps_threshold, out float scale_val, out int zero_point_val) { // ensure we have zero included in our range @@ -31,18 +29,18 @@ void calculate_scale_and_zero_point( scale_val = 0.1; } - // Cut off small scale - if (scale_val < SMALL_SCALE_THRESHOLD) { + // Cut off small scale using the provided eps threshold + if (scale_val < eps_threshold) { float org_scale = scale_val; - scale_val = SMALL_SCALE_THRESHOLD; + scale_val = eps_threshold; // Adjust min and max based on new scale if (min_val == 0.0) { - max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + max_val = eps_threshold * float(qmax - qmin); } else if (max_val == 0.0) { - min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + min_val = -eps_threshold * float(qmax - qmin); } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + float amplifier = eps_threshold / org_scale; min_val *= amplifier; max_val *= amplifier; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index dcbfe493f34..48681a46c30 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -29,6 +29,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -175,7 +176,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; @@ -260,7 +261,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); t_scale[token_id] = scale_val; t_zero_point[token_id] = zero_point_val; diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 282f1de170a..5076b2d68e9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -30,6 +30,7 @@ $if MODE == "per_tensor": layout(push_constant) uniform restrict Block { int quant_min; int quant_max; + float eps; }; $else: layout(push_constant) uniform restrict Block { @@ -234,7 +235,7 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); @@ -372,7 +373,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 1dc2d34afbf..5e9599b91e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -150,6 +150,7 @@ void add_choose_qparams_tensor_node( const ValueRef& input, const ValueRef& quant_min, const ValueRef& quant_max, + const ValueRef& eps, const ValueRef& scale_out, const ValueRef& zero_point_out) { std::string kernel_name("choose_qparams_tensor"); @@ -158,6 +159,7 @@ void add_choose_qparams_tensor_node( int quant_min_val = static_cast(graph.get_int(quant_min)); int quant_max_val = static_cast(graph.get_int(quant_max)); + float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; @@ -180,6 +182,7 @@ void add_choose_qparams_tensor_node( push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), + PushConstantDataInfo(&eps_val, sizeof(float)), }; graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -275,8 +278,22 @@ void choose_qparams_tensor_impl( const ValueRef input = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused dtype parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -289,13 +306,10 @@ void choose_qparams_tensor_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -303,7 +317,7 @@ void choose_qparams_tensor_impl( } add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, scale_out, zero_point_out); + graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); } void choose_qparams_per_token_asymmetric_impl( @@ -311,8 +325,21 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef scale_out = args[arg_idx++]; - const ValueRef zero_point_out = args[arg_idx++]; + const ValueRef dtype = + args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef out_tuple_ref = args[arg_idx++]; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Void the unused parameter to match ATen signature + (void)dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); @@ -325,22 +352,20 @@ void choose_qparams_per_token_asymmetric_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - accept CPU types but convert to GPU types - VK_CHECK_COND( - graph.dtype_of(scale_out) == vkapi::kFloat || - graph.dtype_of(scale_out) == vkapi::kDouble); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kLong); + // Verify output types - only accept Vulkan-supported types + // The Vulkan backend only supports float32 and int32, not float64/int64 + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); } REGISTER_OPERATORS { - VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); VK_REGISTER_OP( - choose_qparams_per_token_asymmetric.default, + quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.choose_qparams_per_token_asymmetric.default, choose_qparams_per_token_asymmetric_impl); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 77a51ce24f9..3838da9a151 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -180,8 +180,15 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -212,8 +219,15 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter + const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warnings - dtype and output_dtype are inferred + // from output + (void)dtype; + (void)output_dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -257,18 +271,34 @@ void dequantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_dequantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); - VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.default, + dequantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_token.default, + dequantize_per_token_impl); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 49277b4d718..f8f930bf0fb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -180,8 +180,12 @@ void quantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(output)); @@ -205,8 +209,12 @@ void quantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; // Added dtype parameter const ValueRef output = args[arg_idx++]; + // Suppress unused variable warning - dtype is inferred from output + (void)dtype; + // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); VK_CHECK_COND(graph.val_is_tensor(scale)); @@ -243,18 +251,33 @@ void quantize_per_token_impl( const auto scale_sizes = graph.sizes_of(scale); const auto zero_point_sizes = graph.sizes_of(zero_point); - VK_CHECK_COND(scale_sizes.size() == 1); - VK_CHECK_COND(zero_point_sizes.size() == 1); - VK_CHECK_COND(scale_sizes[0] == num_tokens); - VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + // Calculate total number of elements in scale and zero_point tensors + int64_t scale_numel = 1; + for (size_t i = 0; i < scale_sizes.size(); i++) { + scale_numel *= scale_sizes[i]; + } + + int64_t zero_point_numel = 1; + for (size_t i = 0; i < zero_point_sizes.size(); i++) { + zero_point_numel *= zero_point_sizes[i]; + } + + // Check that the total number of elements matches num_tokens + // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors + // (size [num_tokens, 1]) + VK_CHECK_COND(scale_numel == num_tokens); + VK_CHECK_COND(zero_point_numel == num_tokens); add_quantize_per_token_node( graph, input, scale, zero_point, quant_min, quant_max, output); } REGISTER_OPERATORS { - VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); - VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_tensor.default, + quantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index 55e96151387..75b7cbc8960 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -433,14 +433,23 @@ void test_vulkan_choose_qparams_tensor_impl( const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams.tensor") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add eps and dtype parameters to match ATen signature + const ValueRef r_eps = graph.add_scalar(6.1e-5); + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") (graph, { r_input.value, r_quant_min, r_quant_max, - r_scale, - r_zero_point, + r_eps, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); @@ -647,12 +656,20 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl( const ValueRef r_zero_point = graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + // Create output tuple + const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); + + // Add dtype parameter to match ATen signature + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN( + "quantized_decomposed.choose_qparams_per_token_asymmetric.default") (graph, { r_input.value, - r_scale, - r_zero_point, + r_dtype, + r_out_tuple, }); ValueRef staging_scale = graph.set_output_tensor(r_scale); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 6c604076c41..82f316abe82 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -585,7 +585,10 @@ void test_vulkan_dequantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.default") (graph, { r_input.value, @@ -593,6 +596,8 @@ void test_vulkan_dequantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); @@ -1046,7 +1051,10 @@ void test_vulkan_dequantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - VK_GET_OP_FN("dequantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") (graph, { r_input.value, @@ -1054,6 +1062,8 @@ void test_vulkan_dequantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, + r_dtype, r_out, }); diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 150bda6989e..64ea144fbf1 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -476,7 +476,10 @@ void test_vulkan_quantize_per_tensor_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_tensor.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.default") (graph, { r_input.value, @@ -484,6 +487,7 @@ void test_vulkan_quantize_per_tensor_impl( r_zero_point, r_quant_min, r_quant_max, + r_dtype, r_out, }); @@ -835,7 +839,10 @@ void test_vulkan_quantize_per_token_impl( const ValueRef r_out = graph.add_tensor( input.sizes().vec(), from_at_scalartype(dtype), out_storage); - VK_GET_OP_FN("quantize_per_token.default") + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + + VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") (graph, { r_input.value, @@ -843,6 +850,7 @@ void test_vulkan_quantize_per_token_impl( r_zero_point.value, r_quant_min, r_quant_max, + r_dtype, r_out, });