Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ def register_binary_op():
)


@update_features([exir_ops.edge.aten.pow.Tensor_Scalar])
def register_binary_scalar_op():
    """Build the feature description for the tensor-with-scalar pow op.

    The returned OpFeatures accepts any input storage type and marks the
    op as supporting resize.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )
    return features


@update_features(
[
exir_ops.edge.aten.abs.default,
Expand Down
56 changes: 56 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op_defs.glslh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef BINARY_OP_DEFS_GLSLH
#define BINARY_OP_DEFS_GLSLH

//
// Power operation that handles negative and zero bases
//
// In GLSL, pow(x, y) is undefined for x < 0. This function provides
// a safe implementation that:
// - Handles x == 0 (returns 1 for y == 0, returns 0 for any other y, including y < 0)
// - Handles x < 0 by using absolute value and preserving sign for odd integer exponents
// - Uses standard pow() for x > 0
//

// Scalar overload
//
// Computes x raised to the power y for one element.
//   - x == 0: returns 1 when y == 0 and 0 for every other y.
//   - x != 0: evaluates pow(|x|, y); when x is negative and y is (within 1e-5
//     of) an odd integer, the result is negated.
// The result is converted back to T before it is returned.
T power_of(T x, T y) {
  if (x == 0.0) {
    // Handle 0^y: 0^0 = 1, every other exponent yields 0
    return (y == 0.0) ? T(1.0) : T(0.0);
  }

  // Use the absolute value because pow() is undefined for a negative base
  float result = pow(abs(x), y);

  // For negative bases with odd integer exponents, preserve the negative sign
  if (x < 0.0) {
    float int_y = round(y);
    // Parity is tested with a bit mask instead of `%`: GLSL leaves the result
    // of `%` undefined when an operand is negative, so `int(int_y) % 2 == 1`
    // is not guaranteed to detect negative odd exponents such as -1 or -3.
    // The low bit of a two's-complement integer is set for every odd value,
    // positive or negative.
    if (abs(y - int_y) < 1e-5 && (int(int_y) & 1) == 1) {
      result = -result;
    }
  }

  return T(result);
}

#ifdef VEC4_T

// Vector overload
//
// Applies the scalar power_of() independently to each of the four components
// of x and y and packs the four results into one VEC4_T.
VEC4_T power_of(VEC4_T x, VEC4_T y) {
  return VEC4_T(
      power_of(x.x, y.x),
      power_of(x.y, y.y),
      power_of(x.z, y.z),
      power_of(x.w, y.w));
}

#endif // VEC4_T

#endif // BINARY_OP_DEFS_GLSLH
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

// The placeholders on the following lines are template parameters; they are
// presumably substituted by the shader code generator before compilation.
#define PRECISION ${PRECISION}

#define NAME ${VARIANT_NAME}

// Element type used for the values read from t_in and written to t_out.
#define T ${buffer_scalar_type(DTYPE)}

// Binary operation applied in main(); X is the tensor element and Y is the
// scalar operand.
#define op(X, Y) ${OPERATOR}

${define_active_storage_type(STORAGE)}
${define_required_extensions(DTYPE)}

// Use std430 as the default layout for buffer blocks.
layout(std430) buffer;

#include "indexing.glslh"

// Output tensor (write-only) and input tensor (read-only).
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}

// Metadata blocks for the output and input tensors; main() only reads outp.
${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}

// Scalar operand of the operation; always declared as float and converted to
// T at the point of use.
${layout_declare_ubo(B, "float", "scalar_value")}

// Work group size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Included after T is defined because power_of() is declared in terms of T.
#include "binary_op_defs.glslh"

// Each invocation handles one buffer element: it reads t_in at its global x
// index, applies op() with the scalar operand, and stores the result at the
// same index of t_out. Invocations outside the output bounds do nothing.
void main() {
  const uint elem_idx = gl_GlobalInvocationID.x;

  if (!out_of_bounds(elem_idx, outp)) {
    const T exponent = T(scalar_value);
    t_out[elem_idx] = T(op(t_in[elem_idx], exponent));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Variant configuration for the binary_scalar_buffer shader template.
binary_scalar_buffer:
# Default values for the template parameters. OPERATOR is the expression the
# shader's op(X, Y) macro expands to.
parameter_names_with_default_values:
OPERATOR: power_of(X, Y)
NDIM: 3
DTYPE: float
PACKING: C_packed
STORAGE: buffer
# Presumably one shader is generated per listed DTYPE value - confirm against
# the shader code generator.
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
- VALUE: int32
# Named variants; parameters not overridden here keep the defaults above.
shader_variants:
- NAME: pow_scalar_buffer
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

// The placeholders on the following lines are template parameters; they are
// presumably substituted by the shader code generator before compilation.
#define PRECISION ${PRECISION}

#define NAME ${VARIANT_NAME}

// VEC4_T is the texel type loaded from t_in; T is the type of one component.
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${texel_load_component_type(DTYPE, STORAGE)}

// Binary operation applied in main(); X is the input texel and Y is the
// scalar operand broadcast to a VEC4_T.
#define op(X, Y) ${OPERATOR}

${define_active_storage_type(STORAGE)}
${define_required_extensions(DTYPE)}

// Use std430 as the default layout for buffer blocks.
layout(std430) buffer;

// DEBUG_MODE is intentionally not defined before this include: the buffer
// variant of this shader does not define it and main() uses no debug helpers.
#include "indexing.glslh"

// Output tensor (write-only) and input tensor (read-only).
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}

// Metadata blocks for the output and input tensors; main() only reads outp.
${layout_declare_ubo(B, "TextureMetadata", "outp")}
${layout_declare_ubo(B, "TextureMetadata", "inp")}

// Scalar operand of the operation; always declared as float.
${layout_declare_ubo(B, "float", "scalar_value")}

// Work group size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Included after T and VEC4_T are defined because power_of() is declared in
// terms of them.
#include "binary_op_defs.glslh"

// Each invocation handles one texel: it fetches the texel of t_in at its
// global invocation position, applies op() with the scalar operand broadcast
// to all four components, and stores the result at the same position of
// t_out. Invocations outside the output bounds do nothing.
void main() {
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  if (!out_of_bounds(texel_pos, outp)) {
    const VEC4_T base = texelFetch(t_in, texel_pos, 0);
    imageStore(t_out, texel_pos, VEC4_T(op(base, VEC4_T(scalar_value))));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Variant configuration for the binary_scalar_texture shader template.
binary_scalar_texture:
# Default values for the template parameters. OPERATOR is the expression the
# shader's op(X, Y) macro expands to.
parameter_names_with_default_values:
OPERATOR: power_of(X, Y)
NDIM: 3
DTYPE: float
PACKING: C_packed
STORAGE: texture3d
# Presumably one shader is generated per listed DTYPE value - confirm against
# the shader code generator.
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
- VALUE: int32
# Named variants; parameters not overridden here keep the defaults above.
shader_variants:
- NAME: pow_scalar_texture3d
83 changes: 83 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/BinaryScalarOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

/*
 * Resize callback for binary scalar op nodes: gives the output tensor the
 * same sizes as the input tensor.
 *
 * `args` holds the output in group 0 and the input in group 1, matching the
 * order in which add_binary_scalar_op_node() passes them to the dispatch
 * node. `resize_args` is unused.
 */
void resize_binary_scalar_op_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;

  const ValueRef out_ref = args.at(0).refs.at(0);
  const ValueRef in_ref = args.at(1).refs.at(0);

  const std::vector<int64_t> new_sizes = graph->sizes_of(in_ref);
  graph->virtual_resize(out_ref, new_sizes);
}

/*
 * Adds a dispatch node that applies a binary operation between a tensor and
 * a scalar.
 *
 * graph:   compute graph the node is appended to.
 * in:      input tensor value; passed through prepack_standard_like() first.
 * scalar:  scalar operand; extracted as a float and handed to the shader
 *          through a params buffer.
 * out:     output tensor value.
 * op_name: operation prefix of the shader name. The kernel name is built as
 *          op_name + "_scalar", followed by the storage type suffix of `out`
 *          and the dtype suffix of the bound input.
 */
void add_binary_scalar_op_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef scalar,
    const ValueRef out,
    const std::string& op_name) {
  // `arg` is the value that is actually bound to the shader. It may differ
  // from `in` (presumably when `in` is constant data that has to be
  // prepacked - confirm against prepack_standard_like), so everything that
  // describes the bound input below is derived from `arg`, not `in`.
  const ValueRef arg = prepack_standard_like(graph, in, out, true);

  // Extract scalar value
  const float scalar_val = graph.extract_scalar<float>(scalar);

  // Pick shader
  std::string kernel_name = op_name + "_scalar";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(arg));

  // Order must match the UBO declaration order in the shader: output
  // metadata, input metadata, scalar value.
  vkapi::ParamsBindList param_ubos = {
      graph.meta_ubo(out),
      graph.meta_ubo(arg),
      graph.create_params_buffer(scalar_val)};

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {arg, vkapi::kRead}},
      // Shader params buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      resize_binary_scalar_op_node));
}

/*
 * Operator entry point for aten.pow.Tensor_Scalar.
 *
 * args[0] is the input tensor, args[1] the scalar exponent and args[2] the
 * output tensor.
 */
void pow_tensor_scalar(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef self = args[0];
  const ValueRef exponent = args[1];
  const ValueRef out = args[2];

  add_binary_scalar_op_node(graph, self, exponent, out, "pow");
}

// Associates the aten.pow.Tensor_Scalar operator name with the
// pow_tensor_scalar implementation.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar);
}

} // namespace vkcompute
25 changes: 25 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,3 +1947,28 @@ def get_where_inputs():
test_suite.atol = "1e-4"
test_suite.rtol = "1e-4"
return test_suite


@register_test_suite("aten.pow.Tensor_Scalar")
def get_pow_tensor_scalar_inputs():
    """Build the test suite for aten.pow.Tensor_Scalar.

    Each case is an (input sizes, scalar exponent) pair. The suite runs with
    buffer and 3D texture storage, width-packed and channels-packed layouts,
    and the float dtype only.
    """
    cases = [
        ((M1,), 2.0),
        ((M2, M1), 2.0),
        ((S1, M1, M2), 0.5),
        ((S1, S2, S2, M2), 2.5),
        ((S, S1, S2), -1.0),
        ((M1, M2), 4.0),
        ((S1, S2), 1.5),
    ]

    suite = VkTestSuite(cases)
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.dtypes = ["at::kFloat"]
    return suite
Loading