From 29ba2ab76ef3f0b597b2d38a84797aa7c9e2078f Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 21 Oct 2025 17:20:03 -0400 Subject: [PATCH] [ET-VK] Implementation of `pow.Tensor_Scalar` Pull Request resolved: https://github.com/pytorch/executorch/pull/15159 Title says it all! ghstack-source-id: 317683542 @exported-using-ghexport Differential Revision: [D84716452](https://our.internmc.facebook.com/intern/diff/D84716452/) --- backends/vulkan/op_registry.py | 12 +++ .../graph/ops/glsl/binary_op_defs.glslh | 56 +++++++++++++ .../graph/ops/glsl/binary_scalar_buffer.glsl | 45 ++++++++++ .../graph/ops/glsl/binary_scalar_buffer.yaml | 20 +++++ .../graph/ops/glsl/binary_scalar_texture.glsl | 51 ++++++++++++ .../graph/ops/glsl/binary_scalar_texture.yaml | 20 +++++ .../runtime/graph/ops/impl/BinaryScalarOp.cpp | 83 +++++++++++++++++++ backends/vulkan/test/op_tests/cases.py | 25 ++++++ 8 files changed, 312 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 63b57a0e79c..b47a8f383a0 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -228,6 +228,18 @@ def register_binary_op(): ) +@update_features( + [ + exir_ops.edge.aten.pow.Tensor_Scalar, + ] +) +def register_binary_scalar_op(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, + ) + + @update_features( [ exir_ops.edge.aten.abs.default, diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh b/backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh new file mode 100644 index 00000000000..e2bdec703ca --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef BINARY_OP_DEFS_GLSLH +#define BINARY_OP_DEFS_GLSLH + +// +// Power operation that handles negative and zero bases +// +// In GLSL, pow(x, y) is undefined for x < 0. This function provides +// a safe implementation that: +// - Handles x == 0 (returns 0 for y > 0, returns 1 for y == 0) +// - Handles x < 0 by using absolute value and preserving sign for odd integer exponents +// - Uses standard pow() for x > 0 +// + +// Scalar overload +T power_of(T x, T y) { + if (x == 0.0) { + // Handle 0^y: 0^0 = 1, 0^y = 0 for y > 0 + return (y == 0.0) ? T(1.0) : T(0.0); + } + + // Use absolute value to avoid undefined behavior + float result = pow(abs(x), y); + + // For negative bases with odd integer exponents, preserve the negative sign + if (x < 0.0) { + float int_y = round(y); + if (abs(y - int_y) < 1e-5 && int(int_y) % 2 == 1) { + result = -result; + } + } + + return T(result); +} + +#ifdef VEC4_T + +// Vector overload +VEC4_T power_of(VEC4_T x, VEC4_T y) { + VEC4_T result; + for (int i = 0; i < 4; i++) { + result[i] = power_of(x[i], y[i]); + } + return result; +} + +#endif // VEC4_T + +#endif // BINARY_OP_DEFS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl new file mode 100644 index 00000000000..860050bcfb6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.glsl @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define NAME ${VARIANT_NAME} + +#define T ${buffer_scalar_type(DTYPE)} + +#define op(X, Y) ${OPERATOR} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +${layout_declare_ubo(B, "float", "scalar_value")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "binary_op_defs.glslh" + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + t_out[out_bufi] = T(op(t_in[out_bufi], T(scalar_value))); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml new file mode 100644 index 00000000000..b818132cf9b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_buffer.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +binary_scalar_buffer: + parameter_names_with_default_values: + OPERATOR: power_of(X, Y) + NDIM: 3 + DTYPE: float + PACKING: C_packed + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: pow_scalar_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl new file mode 100644 index 00000000000..971f66f93e5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define NAME ${VARIANT_NAME} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +#define op(X, Y) ${OPERATOR} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#define DEBUG_MODE +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + +${layout_declare_ubo(B, "float", "scalar_value")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "binary_op_defs.glslh" + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(pos, outp)) { + return; + } + + VEC4_T in_texel = texelFetch(t_in, pos, 0); + VEC4_T out_texel = VEC4_T(op(in_texel, VEC4_T(scalar_value))); + + imageStore(t_out, pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml new file mode 100644 index 00000000000..3e731bf7a15 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +binary_scalar_texture: + parameter_names_with_default_values: + OPERATOR: power_of(X, Y) + NDIM: 3 + DTYPE: float + PACKING: C_packed + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: pow_scalar_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp new file mode 100644 index 00000000000..a6a6182ad2e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +#include + +namespace vkcompute { + +void resize_binary_scalar_op_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + + const std::vector in_sizes = graph->sizes_of(in); + + graph->virtual_resize(out, in_sizes); +} + +void add_binary_scalar_op_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef scalar, + const ValueRef out, + const std::string& op_name) { + ValueRef arg = prepack_standard_like(graph, in, out, true); + + // Extract scalar value + float scalar_val = graph.extract_scalar(scalar); + + // Pick shader + std::string kernel_name = op_name + "_scalar"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), + graph.meta_ubo(in), + graph.create_params_buffer(scalar_val)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {arg, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_binary_scalar_op_node)); +} + +void pow_tensor_scalar(ComputeGraph& graph, const std::vector& args) { + return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "pow"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 8c5d0c4797b..84926d8f080 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1947,3 +1947,28 @@ def get_where_inputs(): test_suite.atol = "1e-4" test_suite.rtol = "1e-4" return test_suite + + +@register_test_suite("aten.pow.Tensor_Scalar") +def get_pow_tensor_scalar_inputs(): + test_suite = VkTestSuite( + [ + ((M1,), 2.0), + ((M2, M1), 2.0), + ((S1, M1, M2), 0.5), + ((S1, S2, S2, M2), 2.5), + ((S, S1, S2), -1.0), + ((M1, M2), 4.0), + ((S1, S2), 1.5), + ] + ) + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.dtypes = ["at::kFloat"] + return test_suite