
Commit 9777fb3

[ET-VK] Implement expand (#13692)
Title says it all! Adds an implementation of the expand operator (buffer storage only), using the new `BufferMetadata` struct in the shader so that high dimension tensors are supported. Differential Revision: [D80962719](https://our.internmc.facebook.com/intern/diff/D80962719/) [ghstack-poisoned]
1 parent fbd7fb5 commit 9777fb3

6 files changed (+180 lines, -0 lines)

backends/vulkan/op_registry.py

Lines changed: 5 additions & 0 deletions
@@ -515,6 +515,11 @@ def register_view_ops_with_buffer_meta():
 )


+@update_features(exir_ops.edge.aten.expand_copy.default)
+def register_expand():
+    return OpFeatures(inputs_storage=utils.ANY_BUFFER, supports_resize=False)
+
+
 # Fully featured transfer operators (i.e. operators that copy data from the input
 # tensor(s) to the output tensor(s)), which have memory layout agnostic implementations
 # for both texture and buffer storage types.
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_outp", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_inp", DTYPE, "buffer")}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const uint outp_bufi = gl_GlobalInvocationID.x;
  if (outp_bufi >= numel(outp)) {
    return;
  }

  TensorIndex outp_tidx;
  linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);

  // Map output tensor index to input tensor index by taking modulo
  // with input tensor sizes for each dimension
  TensorIndex inp_tidx = outp_tidx;
  for (int d = 0; d < ndim(inp); ++d) {
    uint inp_size = size_at(inp, d);
    uint outp_idx = idx_at(outp_tidx, d);
    inp_tidx.data[div_4(d)][mod_4(d)] = outp_idx % inp_size;
  }

  const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
  // Copy data from input to output
  t_outp[outp_bufi] = t_inp[inp_bufi];
}
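The modulo mapping in the shader above is the whole algorithm: along any dimension where the input has size 1, every output index folds back to input index 0. Below is a minimal CPU sketch of that mapping, assuming row-major contiguous buffers and an input whose sizes have already been padded to the output's dimensionality; the actual kernel reads sizes and strides from the BufferMetadata UBOs via indexing.glslh, so this is illustrative only.

# Hypothetical CPU reference for the expand index mapping; not the runtime code.
def contiguous_strides(sizes):
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * sizes[d + 1]
    return strides

def expand_copy_reference(inp_buf, inp_sizes, outp_sizes):
    inp_strides = contiguous_strides(inp_sizes)
    outp_strides = contiguous_strides(outp_sizes)
    numel = 1
    for s in outp_sizes:
        numel *= s
    outp_buf = [None] * numel
    for outp_bufi in range(numel):
        # linear_idx_to_tensor_idx: recover the per-dimension output index
        outp_tidx = [(outp_bufi // outp_strides[d]) % outp_sizes[d]
                     for d in range(len(outp_sizes))]
        # Modulo folds every index along a size-1 input dimension back to 0
        inp_tidx = [outp_tidx[d] % inp_sizes[d] for d in range(len(inp_sizes))]
        # tensor_idx_to_linear_idx
        inp_bufi = sum(i * s for i, s in zip(inp_tidx, inp_strides))
        outp_buf[outp_bufi] = inp_buf[inp_bufi]
    return outp_buf

# Example: expand a (1, 3) buffer to (2, 3)
print(expand_copy_reference([1.0, 2.0, 3.0], [1, 3], [2, 3]))
# [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]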
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
expand_buffer:
  parameter_names_with_default_values:
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: int32
  shader_variants:
    - NAME: expand_buffer
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_expand_buffer_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef size,
    const ValueRef out) {
  std::string kernel_name = "expand";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  vkapi::ParamsBindList param_buffers = {
      graph.buffer_meta_ubo(out),
      graph.buffer_meta_ubo(in),
  };

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Parameter buffers
      param_buffers,
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {size},
      // Resizing Logic
      nullptr));
}

void expand(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  int idx = 0;
  const ValueRef in = args.at(idx++);
  const ValueRef size = args.at(idx++);
  const ValueRef implicit = args.at(idx++);
  (void)implicit;
  const ValueRef out = args.at(idx++);

  if (graph.is_buffer_storage(out)) {
    return add_expand_buffer_node(graph, in, size, out);
  }

  VK_THROW("Expand operator only supports buffer storage");
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.expand_copy.default, expand);
}

} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 42 additions & 0 deletions
@@ -1880,6 +1880,48 @@ def get_flip_inputs():
     return test_suite


+@register_test_suite("aten.expand_copy.default")
+def get_expand_inputs():
+    test_suite = VkTestSuite(
+        [
+            # Basic expansion cases
+            ((1,), [5]),
+            ((1, 1), [3, 4]),
+            ((1, 3), [2, 3]),
+            ((3, 1), [3, 4]),
+            ((1, 1, 1), [2, 3, 4]),
+            # Expand with same size (no-op)
+            ((3, 4), [3, 4]),
+            ((2, 3, 4), [2, 3, 4]),
+            # Expand with additional dimensions
+            ((3,), [2, 3]),
+            ((3, 4), [2, 3, 4]),
+            ((2, 3), [1, 2, 3]),
+            # Mixed expansion cases
+            ((1, 3, 1, 4), [2, 3, 5, 4]),
+            ((1, 1, 3, 1), [2, 4, 3, 5]),
+            # Larger tensor cases
+            ((1, S1), [M, S1]),
+            ((S2, 1), [S2, M1]),
+            ((1, 1, S), [S1, S2, S]),
+            ((1, S1, 1, S2), [M, S1, M1, S2]),
+        ]
+    )
+    test_suite.storage_types = [
+        "utils::kBuffer",
+    ]
+    test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kChannelsPacked",
+    ]
+    test_suite.dtypes = [
+        "at::kFloat",
+        "at::kHalf",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    return test_suite
+
+
 @register_test_suite("aten.where.self")
 def get_where_inputs():
     Test = namedtuple("Where", ["condition", "self", "other"])
backends/vulkan/utils.py

Lines changed: 1 addition & 0 deletions
@@ -621,6 +621,7 @@ def make_filtered_tensor_repset(
 CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED})

 ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts)
+ANY_BUFFER = TensorRepSet(all_memory_layouts, set())

 ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts)
 NO_STORAGE = TensorRepSet(set(), set())
