
Commit 2060434

SS-JIA authored and facebook-github-bot committed
Implement slice as a view (pytorch#5590)
Summary: Pull Request resolved: pytorch#5590

## Context

Implement slice as a view operator. This is only valid under the following conditions:

* All dims preceding the sliced dim in the dim order have a size of 1
* start is 0
* step is 1

These restrictions ensure that the slice view has zero offset with respect to the source buffer; more details are in the code comments. To test the operator effectively, this diff also extends the test codegen to handle multiple test suites for one operator, each with a different configuration.

ghstack-source-id: 244431147
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D61666462

fbshipit-source-id: e4645ec672be0699c88eb1bb88fdef5b4e5cfdb1
1 parent 3f447d7 commit 2060434
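The conditions above amount to requiring that the sliced region be a zero-offset, contiguous prefix of the source buffer. A minimal standalone sketch of that check, using plain Python lists and a hypothetical function name rather than the ComputeGraph API:

```python
# Illustrative only: mirrors the slice-as-view conditions from the summary,
# not the backend's actual validation code.
def can_slice_as_view(sizes, dim_order, dim, start, step):
    if start != 0:  # a nonzero start would give the view a buffer offset
        return False
    if step != 1:   # a step > 1 would make the sliced region non-contiguous
        return False
    # All dims preceding `dim` in the dim order must have size 1 so that the
    # sliced region is a contiguous prefix of the source buffer.
    for d in dim_order:
        if d == dim:
            break
        if sizes[d] != 1:
            return False
    return True

# Channel slice of a [1, 17, 1, 10] tensor with a contiguous dim order: OK.
print(can_slice_as_view([1, 17, 1, 10], [0, 1, 2, 3], dim=1, start=0, step=1))  # True
# With batch size 2 the same slice cannot be a view.
print(can_slice_as_view([2, 17, 1, 10], [0, 1, 2, 3], dim=1, start=0, step=1))  # False
```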

9 files changed: +239 −13 lines changed


backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 152 additions & 5 deletions
@@ -10,6 +10,8 @@
 
 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Slice.h>
+
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
@@ -31,7 +33,7 @@ inline int64_t normalize_idx(
   return normalize(index, max);
 }
 
-void add_slice_tensor_out_node(
+void add_slice_tensor_copy_node(
     ComputeGraph& graph,
     ValueRef in,
     ValueRef dim_ref,
@@ -149,8 +151,126 @@ void add_slice_tensor_out_node(
   }
 }
 
-void slice_tensor_out(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_slice_tensor_out_node(
+std::vector<int64_t> get_slice_sizes(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref) {
+  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+  std::optional<int64_t> opt_start =
+      graph.extract_optional_scalar<int64_t>(opt_start_ref);
+  std::optional<int64_t> opt_end =
+      graph.extract_optional_scalar<int64_t>(opt_end_ref);
+
+  int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);
+  int64_t start = opt_start.value_or(0);
+  int64_t end = opt_end.value_or(dim_size);
+
+  start = normalize_idx(start, dim_size, 0);
+  end = normalize_idx(end, dim_size, dim_size);
+
+  std::vector<int64_t> new_out_sizes = graph.sizes_of(in_ref);
+  new_out_sizes.at(dim) = end - start;
+
+  return new_out_sizes;
+}
+
+void resize_slice_view_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)args;
+  vTensorPtr out = graph->get_tensor(extra_args[0]);
+
+  std::vector<int64_t> new_out_sizes = get_slice_sizes(
+      *graph,
+      extra_args[1], // input
+      extra_args[2], // dim
+      extra_args[3], // optional start
+      extra_args[4]); // optional end
+
+  out->virtual_resize(new_out_sizes);
+}
+
+void check_slice_view_args(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef opt_step_ref,
+    ValueRef out_ref) {
+  VK_CHECK_COND(
+      graph.val_is_view_of(out_ref, in_ref),
+      "output must be a view of the input");
+
+  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+  const int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);
+
+  int64_t start =
+      graph.extract_optional_scalar<int64_t>(opt_start_ref).value_or(0);
+  int64_t end = graph.extract_optional_scalar<int64_t>(opt_end_ref).value_or(0);
+  int64_t step =
+      graph.extract_optional_scalar<int64_t>(opt_step_ref).value_or(1);
+
+  start = normalize_idx(start, dim_size, 0);
+  end = normalize_idx(end, dim_size, dim_size);
+
+  // The start idx must be 0; this is to ensure that the start of the slice view
+  // does not have any offset with respect to the base buffer storage. If the
+  // offset is nonzero, then it will potentially change upon a resize; however
+  // the buffer offset of the view tensor will have been "locked in" when the
+  // descriptor for its buffer storage is bound to a compute shader. Therefore
+  // there is no way to update the offset of the view once it has been bound.
+  VK_CHECK_COND(start == 0, "start must be 0 for slice view");
+  VK_CHECK_COND(step == 1, "step must be 1 for slice view");
+
+  VK_CHECK_COND(
+      end < dim_size, "end must be less than dim size for slice view");
+
+  // We must also check that all earlier dims in the dim order have a size of 1.
+  // This ensures that the slice view encompasses a contiguous memory region of
+  // the source tensor's memory buffer.
+  std::vector<int64_t> in_sizes = graph.sizes_of(in_ref);
+  std::vector<int64_t> in_dim_order = graph.dim_order_of(in_ref);
+  for (int i = 0; i < in_dim_order.size(); ++i) {
+    if (in_dim_order[i] == dim) {
+      break;
+    }
+    VK_CHECK_COND(in_sizes[in_dim_order[i]] == 1);
+  }
+}
+
+void add_slice_view_node(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef opt_step_ref,
+    ValueRef out_ref) {
+  check_slice_view_args(
+      graph,
+      in_ref,
+      dim_ref,
+      opt_start_ref,
+      opt_end_ref,
+      opt_step_ref,
+      out_ref);
+
+  std::vector<int64_t> new_out_sizes =
+      get_slice_sizes(graph, in_ref, dim_ref, opt_start_ref, opt_end_ref);
+
+  graph.get_tensor(out_ref)->virtual_resize(new_out_sizes);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      resize_slice_view_node,
+      {out_ref, in_ref, dim_ref, opt_start_ref, opt_end_ref, opt_step_ref}));
+}
+
+void slice_tensor_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_slice_tensor_copy_node(
       graph,
       args[0],
       args[1], // dim
@@ -160,9 +280,36 @@ void slice_tensor_out(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       args[5]);
 }
 
+void slice_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  ValueRef in = args[0];
+  ValueRef out = args[5];
+
+  // Special case if out is a view of in
+  if (graph.val_is_view_of(out, in)) {
+    add_slice_view_node(
+        graph,
+        in,
+        args[1], // dim
+        args[2], // optional start
+        args[3], // optional end
+        args[4], // step
+        out);
+    return;
+  }
+
+  add_slice_tensor_copy_node(
+      graph,
+      in,
+      args[1], // dim
+      args[2], // optional start
+      args[3], // optional end
+      args[4], // step
+      out);
+}
+
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_out);
-  VK_REGISTER_OP(aten.slice.Tensor, slice_tensor_out);
+  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_copy);
+  VK_REGISTER_OP(aten.slice.Tensor, slice_tensor);
 }
 
 } // namespace vkcompute
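The reasoning behind check_slice_view_args can be worked through numerically. Assuming strides derived from the dim order (last dim in the order densest), the slice is representable as a view only when it maps to a zero-offset, contiguous range of the source buffer; the sketch below, with hypothetical helper names, shows that arithmetic for one of the shapes used in the tests:

```python
# Illustration of why the slice-view checks imply a zero-offset, contiguous
# region. Helper name and layout assumptions are illustrative, not the
# backend's actual stride computation.
def contiguous_strides(sizes, dim_order):
    # The last dim in dim_order is densest (stride 1); accumulate backwards.
    strides = [0] * len(sizes)
    acc = 1
    for d in reversed(dim_order):
        strides[d] = acc
        acc *= sizes[d]
    return strides

sizes = [1, 17, 1, 10]            # N, C, H, W
dim_order = [0, 1, 2, 3]          # assumed contiguous dim order
strides = contiguous_strides(sizes, dim_order)  # [170, 10, 10, 1]

dim, start, end = 1, 0, 4         # slice channels [0, 4)

offset = start * strides[dim]          # 0, because start == 0
extent = (end - start) * strides[dim]  # 40 contiguous elements

# Every dim before `dim` in the dim order has size 1, so the slice covers
# buffer elements [0, 40) out of 170 -- a zero-offset contiguous prefix.
print(offset, extent)  # 0 40
```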
backends/vulkan/runtime/graph/ops/impl/Slice.h

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+#include <vector>
+
+namespace vkcompute {
+
+void add_slice_view_node(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef opt_step_ref,
+    ValueRef out_ref);
+
+} // namespace vkcompute

backends/vulkan/test/glsl/scalar_add_texture.glsl

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 layout(std430) buffer;
 
 ${layout_declare_tensor(0, "rw", "t_in", "float", "texture3d")}
-${layout_declare_ubo(1, "uvec3", "extents")}
+${layout_declare_ubo(1, "ivec3", "extents")}
 ${layout_declare_ubo(2, "int", "scalar")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

backends/vulkan/test/op_tests/cases.py

Lines changed: 35 additions & 2 deletions
@@ -466,8 +466,8 @@ def get_view_inputs():
     return test_suite
 
 
-@register_test_suite(["aten.slice.Tensor", "aten.slice_copy.Tensor"])
-def get_slice_inputs():
+@register_test_suite("aten.slice_copy.Tensor")
+def get_slice_out_inputs():
     Test = namedtuple("VkSliceTest", ["self", "dim", "start", "end", "step"])
     Test.__new__.__defaults__ = (None, 0, None, None, 1)
 
@@ -549,6 +549,39 @@ def get_slice_inputs():
     return test_suite
 
 
+def get_slice_view_inputs():
+    Test = namedtuple("VkSliceTest", ["self", "dim", "start", "end", "step"])
+    Test.__new__.__defaults__ = (None, 0, None, None, 1)
+
+    # Slice by channel
+    test_cases = [
+        Test(self=[1, 17, 1, 10], dim=1, start=0, end=4),
+        Test(self=[1, 17, 1, 10], dim=1, start=0, end=8),
+        Test(self=[1, 17, 3, 7], dim=1, start=0, end=12),
+    ]
+
+    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
+
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
+    test_suite.layouts = ["utils::kWidthPacked"]
+    test_suite.data_gen = "make_seq_tensor"
+    test_suite.is_view_op = True
+
+    return test_suite
+
+
+@register_test_suite(["aten.slice.Tensor"])
+def get_slice_inputs():
+    texture_test_suite = get_slice_out_inputs()
+    texture_test_suite.test_name_suffix = "no_view"
+
+    view_test_suite = get_slice_view_inputs()
+    view_test_suite.test_name_suffix = "view"
+
+    return [view_test_suite, texture_test_suite]
+
+
 @register_test_suite(["aten.transpose.int"])
 def get_transpose_inputs():
     Test = namedtuple("VkTransposeViewTest", ["self", "dim0", "dim1"])
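A note on the namedtuple pattern used in these suites: assigning to `Test.__new__.__defaults__` supplies a default for every field, so each case only spells out the fields it cares about, and `tuple(tc)` flattens the case into the positional form the suite stores. A small self-contained example:

```python
from collections import namedtuple

# Same pattern as the VkSliceTest cases above.
Test = namedtuple("VkSliceTest", ["self", "dim", "start", "end", "step"])
Test.__new__.__defaults__ = (None, 0, None, None, 1)

# `step` is omitted, so it falls back to the default of 1.
tc = Test(self=[1, 17, 1, 10], dim=1, start=0, end=4)
print(tuple(tc))  # ([1, 17, 1, 10], 1, 0, 4, 1)
```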

backends/vulkan/test/op_tests/generate_op_benchmarks.py

Lines changed: 6 additions & 2 deletions
@@ -43,9 +43,13 @@ def process_test_suites(
     f_map: Dict[str, NativeFunction],
     test_suites: Dict[str, TestSuite],
 ) -> None:
-    for registry_name, op_test_suite in test_suites.items():
+    for registry_name, op_test_suites in test_suites.items():
         f = f_map[registry_name]
-        cpp_generator.add_suite(registry_name, f, op_test_suite)
+        if isinstance(op_test_suites, list):
+            for suite in op_test_suites:
+                cpp_generator.add_suite(registry_name, f, suite)
+        else:
+            cpp_generator.add_suite(registry_name, f, op_test_suites)
 
 
 @local.parametrize(

backends/vulkan/test/op_tests/generate_op_correctness_tests.py

Lines changed: 6 additions & 2 deletions
@@ -43,9 +43,13 @@ def process_test_suites(
     f_map: Dict[str, NativeFunction],
     test_suites: Dict[str, TestSuite],
 ) -> None:
-    for registry_name, op_test_suite in test_suites.items():
+    for registry_name, op_test_suites in test_suites.items():
         f = f_map[registry_name]
-        cpp_generator.add_suite(registry_name, f, op_test_suite)
+        if isinstance(op_test_suites, list):
+            for suite in op_test_suites:
+                cpp_generator.add_suite(registry_name, f, suite)
+        else:
+            cpp_generator.add_suite(registry_name, f, op_test_suites)
 
 
 @local.parametrize(
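The same change lands in both generators: a registered suite function may now return either a single TestSuite or a list of them (as get_slice_inputs does above), so the processing loop branches on the value's type. An equivalent formulation normalizes to a list first; a sketch of that idiom with simplified types:

```python
# Sketch only: `add_suite` stands in for cpp_generator.add_suite, and the
# dict values are either a single suite object or a list of suite objects.
def process_test_suites(test_suites, f_map, add_suite):
    for registry_name, op_test_suites in test_suites.items():
        f = f_map[registry_name]
        if not isinstance(op_test_suites, list):
            op_test_suites = [op_test_suites]  # normalize a single suite to a list
        for suite in op_test_suites:
            add_suite(registry_name, f, suite)
```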

backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 7 additions & 0 deletions
@@ -658,6 +658,9 @@ def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str:
 
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
+        if self.suite_def.test_name_suffix is not None:
+            op_name += "_" + self.suite_def.test_name_suffix
+
         op_check_fn = self.gen_decl(f"check_{op_name}") + " {\n"
         if self.should_prepack:
             op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n"
@@ -676,6 +679,8 @@ def gen_op_check_fn(self) -> str:
 
     def gen_build_graph_fn(self, include_declarations: bool = False) -> str:
         op_name = self.f.func.name.unambiguous_name()
+        if self.suite_def.test_name_suffix is not None:
+            op_name += "_" + self.suite_def.test_name_suffix
         op_build_graph_fn = self.gen_decl(f"build_graph_{op_name}") + " {\n"
         if self.should_prepack:
             op_build_graph_fn = (
@@ -691,6 +696,8 @@ def gen_build_graph_fn(self, include_declarations: bool = False) -> str:
 
     def gen_op_exec_graph_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
+        if self.suite_def.test_name_suffix is not None:
+            op_name += "_" + self.suite_def.test_name_suffix
         op_benchmark_fn = self.gen_decl(f"benchmark_{op_name}") + " {\n"
         if self.should_prepack:
             op_benchmark_fn = self.gen_decl(f"prepacked_benchmark_{op_name}") + " {\n"
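The suffix matters because one operator can now map to several test suites; without it, the generated check_*/build_graph_*/benchmark_* functions for aten.slice.Tensor would collide. A rough sketch of the resulting naming scheme (the base name "slice_Tensor" is an assumed example of what unambiguous_name() returns, not verified here):

```python
# Illustrates how test_name_suffix keeps generated function names unique when
# one operator registers multiple suites. "slice_Tensor" is an assumed base name.
def generated_names(base_op_name, test_name_suffix=None):
    op_name = base_op_name
    if test_name_suffix is not None:
        op_name += "_" + test_name_suffix
    return [f"check_{op_name}", f"build_graph_{op_name}", f"benchmark_{op_name}"]

print(generated_names("slice_Tensor", "view"))
# ['check_slice_Tensor_view', 'build_graph_slice_Tensor_view', 'benchmark_slice_Tensor_view']
print(generated_names("slice_Tensor", "no_view"))
# ['check_slice_Tensor_no_view', 'build_graph_slice_Tensor_no_view', 'benchmark_slice_Tensor_no_view']
```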

backends/vulkan/test/op_tests/utils/gen_correctness_base.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ def __init__(self, f: NativeFunction, test_suite: TestSuite):
         self.f = f
         self.suite_def = test_suite
         self.op_name = f.func.name.unambiguous_name()
+        if test_suite.test_name_suffix is not None:
+            self.op_name += f"_{test_suite.test_name_suffix}"
 
         self.f_sig = CppSignatureGroup.from_native_function(
             self.f, method=False, fallback_binding=self.f.manual_cpp_binding

backends/vulkan/test/op_tests/utils/test_suite.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any, List, Optional
 
 ###################################
 ## Generic Test Suite definition ##
@@ -29,6 +29,7 @@ def __init__(self, input_cases: List[Any]):
         self.rtol: str = "1e-5"
 
         self.is_view_op: bool = False
+        self.test_name_suffix: Optional[str] = None
 
     def supports_prepack(self):
         return len(self.prepacked_args) > 0
