
Commit e2e2129

SS-JIA authored and facebook-github-bot committed

Implement repeat_interleave (pytorch#5830)

Summary:
Pull Request resolved: pytorch#5830

As title; implement the `repeat_interleave` operator. The current implementation has some limitations, which are documented in the code.

ghstack-source-id: 246028695
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D63790717

fbshipit-source-id: 090c9fc77d160619def1d0a2acd01d88185a311e

1 parent 79b7896 commit e2e2129

File tree

6 files changed: +228 −0 lines changed

backends/vulkan/runtime/graph/ComputeGraph.h
backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.glsl
backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml
backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp
backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h
backends/vulkan/test/op_tests/cases.py

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 13 additions & 0 deletions
```diff
@@ -378,6 +378,19 @@ class ComputeGraph final {
     return values_.at(idx).toString();
   }
 
+  template <
+      typename T,
+      typename std::enable_if<
+          std::is_integral<T>::value && std::is_signed<T>::value,
+          int>::type = 0>
+  T extract_whcn_dim(const ValueRef idx, const int64_t ndim) {
+    T dim = extract_scalar<T>(idx);
+    // Normalize dim to account for negative indexing
+    dim = (dim % ndim + ndim) % ndim;
+    // Assume original value is NCHW ordering, obtain the WHCN ordering
+    return ndim - 1 - dim;
+  }
+
   //
   // Utility functions
   //
```
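The new helper maps a user-facing NCHW dim index (possibly negative) to the WHCN dim order used by the Vulkan backend: the dim is first normalized into `[0, ndim)`, then flipped so that the innermost NCHW dim becomes WHCN dim 0. A minimal standalone sketch of the same arithmetic (the `whcn_dim` function is illustrative, not part of the commit):

```cpp
#include <cassert>
#include <cstdint>

// Standalone copy of the extract_whcn_dim arithmetic, for illustration.
int64_t whcn_dim(int64_t dim, const int64_t ndim) {
  dim = (dim % ndim + ndim) % ndim; // normalize, e.g. -1 with ndim=4 -> 3
  return ndim - 1 - dim;            // flip the NCHW index to a WHCN index
}

int main() {
  assert(whcn_dim(-1, 4) == 0); // innermost NCHW dim (W) -> WHCN dim 0
  assert(whcn_dim(1, 4) == 2);  // NCHW channels dim -> WHCN dim 2
  assert(whcn_dim(-2, 3) == 1); // CHW height dim -> WHCN dim 1
  return 0;
}
```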
backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.glsl

Lines changed: 47 additions & 0 deletions
```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_axis_map")}
${layout_declare_ubo(B, "ivec4", "tout_axis_map")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int nrepeats = 1;
layout(constant_id = 4) const int repeat_dim = 1;

#include "indexing_utils.h"

void main() {
  const ivec3 tin_lpos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(tin_lpos, tin_limits))) {
    return;
  }

  const VEC4_T intex = load_texel_lpos(tin, tin_lpos, tin_axis_map);

  ivec3 tout_lpos = tin_lpos;
  tout_lpos[repeat_dim] *= nrepeats;

  for (int i = 0; i < nrepeats; ++i, tout_lpos[repeat_dim]++) {
    write_texel_lpos(tout, tout_lpos, intex, tout_axis_map);
  }
}
```
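Each invocation reads one input texel and writes it unchanged to `nrepeats` consecutive output positions along `repeat_dim`, with `nrepeats` and `repeat_dim` baked in as specialization constants (constant_id 3 and 4). Because whole texels are copied, repeating along the packed dim (where each texel holds four elements) is not supported; the C++ op below guards against it. A 1-D CPU analogue of the index mapping (hypothetical helper, not from the commit):

```cpp
#include <vector>

// CPU analogue of the shader's mapping: input position p along the repeat
// dim maps to output positions [p * nrepeats, (p + 1) * nrepeats).
std::vector<float> repeat_interleave_1d(
    const std::vector<float>& in,
    const int nrepeats) {
  std::vector<float> out(in.size() * nrepeats);
  for (size_t p = 0; p < in.size(); ++p) {
    for (int i = 0; i < nrepeats; ++i) {
      out[p * nrepeats + i] = in[p]; // same value written nrepeats times
    }
  }
  return out;
}
// e.g. {1, 2} with nrepeats = 3 -> {1, 1, 1, 2, 2, 2}
```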
backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml

Lines changed: 10 additions & 0 deletions
```yaml
repeat_interleave:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: repeat_interleave
```
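The `generate_variant_forall` block expands the shader template once per listed `DTYPE`, and the C++ op selects a variant at runtime by appending a dtype suffix to the base kernel name via `add_dtype_suffix`. A tiny sketch of the resulting variant names (the exact suffix scheme is an assumption based on the YAML values):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Illustrative only: one shader variant per DTYPE listed in the YAML.
int main() {
  const std::string base = "repeat_interleave";
  for (const std::string& dtype : std::vector<std::string>{"half", "float"}) {
    std::cout << base + "_" + dtype << "\n"; // repeat_interleave_half, repeat_interleave_float
  }
  return 0;
}
```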
backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp

Lines changed: 94 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void resize_repeat_interleave_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  const int64_t nrepeats = graph->extract_scalar<int64_t>(extra_args[0]);
  int64_t repeat_dim = graph->extract_scalar<int64_t>(extra_args[1]);

  std::vector<int64_t> new_sizes = in->sizes();
  repeat_dim = normalize(repeat_dim, new_sizes.size());
  new_sizes.at(repeat_dim) *= nrepeats;

  out->virtual_resize(new_sizes);
}

void add_repeat_interleave_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef num_repeats,
    const ValueRef dim,
    const ValueRef out) {
  const int32_t nrepeats = graph.extract_scalar<int32_t>(num_repeats);
  const int32_t repeat_dim =
      graph.extract_whcn_dim<int32_t>(dim, graph.dim_of(in));

  VK_CHECK_COND(repeat_dim != graph.packed_dim_of(out));
  VK_CHECK_COND(repeat_dim != graph.packed_dim_of(in));

  std::string kernel_name = "repeat_interleave";
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const utils::uvec3 global_wg_size = graph.logical_limits_of(in);
  const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(kernel_name),
      // Workgroup sizes
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {in, vkapi::MemoryAccessType::READ}},
      // Parameter buffers
      {graph.logical_limits_ubo(in),
       graph.axis_map_ubo(in),
       graph.axis_map_ubo(out)},
      // Specialization Constants
      {nrepeats, repeat_dim},
      // Resizing Logic
      resize_repeat_interleave_node,
      {num_repeats, dim}));
}

void repeat_interleave(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  int args_i = 0;
  const ValueRef in = args[args_i++];
  const ValueRef num_repeats = args[args_i++];
  const ValueRef dim = args[args_i++];
  const ValueRef output_size = args[args_i++];
  const ValueRef out = args[args_i++];

  // Output size is not used in the kernel
  (void)output_size;

  add_repeat_interleave_node(graph, in, num_repeats, dim, out);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.repeat_interleave.self_int, repeat_interleave);
}

} // namespace vkcompute
```
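The resize logic recomputes the output shape whenever input shapes change: the size along the (normalized) repeat dim is multiplied by `nrepeats`; all other dims are unchanged. A standalone sketch of that shape rule (the function name and inline normalization are illustrative; the real code calls a `normalize` helper):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Shape rule used by resize_repeat_interleave_node, in isolation.
std::vector<int64_t> repeat_interleave_sizes(
    std::vector<int64_t> sizes,
    const int64_t nrepeats,
    int64_t dim) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  dim = (dim % ndim + ndim) % ndim; // normalize negative dims
  sizes.at(dim) *= nrepeats;
  return sizes;
}

int main() {
  // e.g. a (4, 32, 256) input with 3 repeats along dim -2 -> (4, 96, 256)
  const std::vector<int64_t> out =
      repeat_interleave_sizes({4, 32, 256}, 3, -2);
  assert((out == std::vector<int64_t>{4, 96, 256}));
  return 0;
}
```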
backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h

Lines changed: 24 additions & 0 deletions
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

void add_repeat_interleave_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef num_repeats,
    const ValueRef dim,
    const ValueRef out);

} // namespace vkcompute
```

backends/vulkan/test/op_tests/cases.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -747,6 +747,46 @@ def get_repeat_inputs():
     return test_suite
 
 
+@register_test_suite("aten.repeat_interleave.self_int")
+def get_repeat_interleave_inputs():
+    test_suite_W = VkTestSuite(
+        [
+            ((4, 32, 256), 3, -2),
+            # Test repeat on each non-packed dim
+            ((16, 32, 64), 5, -2),
+            ((16, 32, 64), 5, -3),
+            # Test batched inputs
+            ((3, 5, 32, 64), 4, -2),
+            ((3, 5, 32, 64), 4, -3),
+        ]
+    )
+    test_suite_W.layouts = [
+        "utils::kWidthPacked",
+    ]
+    test_suite_W.data_gen = "make_seq_tensor"
+    test_suite_W.dtypes = ["at::kFloat"]
+    test_suite_W.test_name_suffix = "W_packed"
+
+    test_suite_C = VkTestSuite(
+        [
+            # Test repeat on each non-packed dim
+            ((32, 32, 16), 5, -1),
+            ((32, 32, 16), 5, -2),
+            # Test batched inputs
+            ((3, 16, 8, 64), 4, -1),
+            ((3, 16, 8, 64), 4, -2),
+        ]
+    )
+    test_suite_C.layouts = [
+        "utils::kChannelsPacked",
+    ]
+    test_suite_C.data_gen = "make_seq_tensor"
+    test_suite_C.dtypes = ["at::kFloat"]
+    test_suite_C.test_name_suffix = "C_packed"
+
+    return [test_suite_W, test_suite_C]
+
+
 @register_test_suite("aten.cat.default")
 def get_cat_inputs():
     # TensorList must be specified as list of tuples
```
