Skip to content

Commit 6ea6fcc

Browse files
authored
[ET-VK] Implementation of pow.Tensor_Scalar
Differential Revision: D84716452 Pull Request resolved: #15159
1 parent f1fe4d5 commit 6ea6fcc

File tree

8 files changed

+312
-0
lines changed

8 files changed

+312
-0
lines changed

backends/vulkan/op_registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,18 @@ def register_binary_op():
228228
)
229229

230230

231+
@update_features([exir_ops.edge.aten.pow.Tensor_Scalar])
def register_binary_scalar_op():
    """Build the feature description shared by tensor-with-scalar binary ops.

    Currently registered only for ``aten.pow.Tensor_Scalar``. The returned
    ``OpFeatures`` declares ``utils.ANY_STORAGE`` for the inputs and turns on
    resize support.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )
    return features
241+
242+
231243
@update_features(
232244
[
233245
exir_ops.edge.aten.abs.default,
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef BINARY_OP_DEFS_GLSLH
10+
#define BINARY_OP_DEFS_GLSLH
11+
12+
//
13+
// Power operation that handles negative and zero bases
14+
//
15+
// In GLSL, pow(x, y) is undefined for x < 0. This function provides
16+
// a safe implementation that:
17+
// - Handles x == 0 (returns 1 for y == 0, returns 0 for every other y)
18+
// - Handles x < 0 by using absolute value and preserving sign for odd integer exponents
19+
// - Uses standard pow() for x > 0
20+
//
21+
22+
// Scalar overload
23+
// Computes x raised to the power y for a single component.
//
// - x == 0: returns 1 when y == 0 and 0 for every other exponent.
// - x < 0:  computes pow(|x|, y) and negates the result when y is (within
//           1e-5 of) an odd integer; otherwise the positive magnitude is
//           returned.
// - x > 0:  plain pow(x, y).
T power_of(T x, T y) {
  if (x == 0.0) {
    // Handle 0^y: 0^0 = 1, every other exponent yields 0
    return (y == 0.0) ? T(1.0) : T(0.0);
  }

  // pow() is undefined for a negative base, so operate on the magnitude
  float result = pow(abs(x), y);

  // For negative bases with odd integer exponents, preserve the negative sign
  if (x < 0.0) {
    float int_y = round(y);
    // Check parity through the lowest bit instead of `% 2 == 1`: GLSL leaves
    // `%` undefined for negative operands (and truncating implementations
    // yield -1), so negative odd exponents such as -3 would never match and
    // the sign would be dropped.
    if (abs(y - int_y) < 1e-5 && (int(int_y) & 1) != 0) {
      result = -result;
    }
  }

  return T(result);
}
42+
43+
#ifdef VEC4_T
44+
45+
// Vector overload
46+
// Component-wise overload: applies the scalar power_of() to each of the four
// lanes of x and y and packs the results into a new vector.
VEC4_T power_of(VEC4_T x, VEC4_T y) {
  return VEC4_T(
      power_of(x.x, y.x),
      power_of(x.y, y.y),
      power_of(x.z, y.z),
      power_of(x.w, y.w));
}
53+
54+
#endif // VEC4_T
55+
56+
#endif // BINARY_OP_DEFS_GLSLH
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define NAME ${VARIANT_NAME}
14+
15+
#define T ${buffer_scalar_type(DTYPE)}
16+
17+
#define op(X, Y) ${OPERATOR}
18+
19+
${define_active_storage_type(STORAGE)}
20+
${define_required_extensions(DTYPE)}
21+
22+
layout(std430) buffer;
23+
24+
#include "indexing.glslh"
25+
26+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
27+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
28+
29+
${layout_declare_ubo(B, "BufferMetadata", "outp")}
30+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
31+
32+
${layout_declare_ubo(B, "float", "scalar_value")}
33+
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
#include "binary_op_defs.glslh"
37+
38+
// One invocation per output buffer element: applies op() to the input element
// at the same index and the scalar operand, then stores the result.
void main() {
  const uint idx = gl_GlobalInvocationID.x;

  // Invocations past the end of the output do nothing.
  if (!out_of_bounds(idx, outp)) {
    const T scalar_operand = T(scalar_value);
    t_out[idx] = T(op(t_in[idx], scalar_operand));
  }
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
binary_scalar_buffer:
8+
parameter_names_with_default_values:
9+
OPERATOR: power_of(X, Y)
10+
NDIM: 3
11+
DTYPE: float
12+
PACKING: C_packed
13+
STORAGE: buffer
14+
generate_variant_forall:
15+
DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
- VALUE: int32
19+
shader_variants:
20+
- NAME: pow_scalar_buffer
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define NAME ${VARIANT_NAME}
14+
15+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
16+
#define T ${texel_load_component_type(DTYPE, STORAGE)}
17+
18+
#define op(X, Y) ${OPERATOR}
19+
20+
${define_active_storage_type(STORAGE)}
21+
${define_required_extensions(DTYPE)}
22+
23+
layout(std430) buffer;
24+
25+
// NOTE(review): the buffer variant of this shader includes "indexing.glslh"
// without defining DEBUG_MODE, and nothing in this file reads it directly.
// Confirm this define is intentional and not a debugging leftover.
#define DEBUG_MODE
26+
#include "indexing.glslh"
27+
28+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
29+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
30+
31+
${layout_declare_ubo(B, "TextureMetadata", "outp")}
32+
${layout_declare_ubo(B, "TextureMetadata", "inp")}
33+
34+
${layout_declare_ubo(B, "float", "scalar_value")}
35+
36+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
37+
38+
#include "binary_op_defs.glslh"
39+
40+
// One invocation per output texel: fetches the input texel at the same
// position, applies op() against the scalar broadcast to all four lanes, and
// stores the result.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the output extents do nothing.
  if (!out_of_bounds(pos, outp)) {
    VEC4_T src_texel = texelFetch(t_in, pos, 0);
    VEC4_T scalar_texel = VEC4_T(scalar_value);
    VEC4_T dst_texel = VEC4_T(op(src_texel, scalar_texel));

    imageStore(t_out, pos, dst_texel);
  }
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
binary_scalar_texture:
8+
parameter_names_with_default_values:
9+
OPERATOR: power_of(X, Y)
10+
NDIM: 3
11+
DTYPE: float
12+
PACKING: C_packed
13+
STORAGE: texture3d
14+
generate_variant_forall:
15+
DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
- VALUE: int32
19+
shader_variants:
20+
- NAME: pow_scalar_texture3d
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
17+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
18+
19+
namespace vkcompute {
20+
21+
// Resize callback for tensor-with-scalar binary ops: the output takes on the
// current sizes of the input tensor.
//
// `args` follows the layout used when the dispatch node is created: group 0
// holds the output, group 1 holds the input. `resize_args` is unused.
void resize_binary_scalar_op_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  static_cast<void>(resize_args);

  const ValueRef output_ref = args.at(0).refs.at(0);
  const ValueRef input_ref = args.at(1).refs.at(0);

  graph->virtual_resize(output_ref, graph->sizes_of(input_ref));
}
33+
34+
// Adds a dispatch node computing `out = op(in, scalar)` element-wise.
//
// graph:   compute graph receiving the new execute node.
// in:      input tensor value; passed through prepack_standard_like() first.
// scalar:  scalar operand, extracted as a float and handed to the shader in
//          its own params buffer.
// out:     output tensor value.
// op_name: operator name used as the shader name prefix; the kernel name is
//          `<op_name>_scalar` plus storage-type and dtype suffixes.
void add_binary_scalar_op_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef scalar,
    const ValueRef out,
    const std::string& op_name) {
  const ValueRef arg = prepack_standard_like(graph, in, out, true);

  // Extract scalar value
  const float scalar_val = graph.extract_scalar<float>(scalar);

  // Pick shader
  std::string kernel_name = op_name + "_scalar";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(in));

  // The input metadata must describe `arg`, the value actually bound to the
  // shader below, rather than `in`.
  // NOTE(review): this assumes prepack_standard_like() may return a value
  // different from `in` (e.g. for constant data) - confirm.
  vkapi::ParamsBindList param_ubos = {
      graph.meta_ubo(out),
      graph.meta_ubo(arg),
      graph.create_params_buffer(scalar_val)};

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {arg, vkapi::kRead}},
      // Shader params buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      resize_binary_scalar_op_node));
}
74+
75+
void pow_tensor_scalar(ComputeGraph& graph, const std::vector<ValueRef>& args) {
76+
return add_binary_scalar_op_node(graph, args[0], args[1], args[2], "pow");
77+
}
78+
79+
REGISTER_OPERATORS {
80+
VK_REGISTER_OP(aten.pow.Tensor_Scalar, pow_tensor_scalar);
81+
}
82+
83+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,3 +1947,28 @@ def get_where_inputs():
19471947
test_suite.atol = "1e-4"
19481948
test_suite.rtol = "1e-4"
19491949
return test_suite
1950+
1951+
1952+
@register_test_suite("aten.pow.Tensor_Scalar")
def get_pow_tensor_scalar_inputs():
    """Build the test suite for ``aten.pow.Tensor_Scalar``.

    Each case pairs a size tuple with a float (presumably the input tensor
    sizes and the scalar exponent - confirm against the test generator). The
    suite runs for buffer and 3D-texture storage, width- and channels-packed
    layouts, and the float dtype only.
    """
    cases = [
        ((M1,), 2.0),
        ((M2, M1), 2.0),
        ((S1, M1, M2), 0.5),
        ((S1, S2, S2, M2), 2.5),
        ((S, S1, S2), -1.0),
        ((M1, M2), 4.0),
        ((S1, S2), 1.5),
    ]

    suite = VkTestSuite(cases)
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    suite.dtypes = ["at::kFloat"]
    return suite

0 commit comments

Comments
 (0)