Skip to content

Commit 3f447d7

Browse files
SS-JIA authored and facebook-github-bot committed
Add transpose op as view operator (pytorch#5589)
Summary: Pull Request resolved: pytorch#5589 ## Context As title. Implement `aten.transpose.int` as a view operator, which creates an alias of the input tensor with different sizes and strides. To effectively test the op, the codegen script is also updated to support view ops. ghstack-source-id: 244431150 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D61666463 fbshipit-source-id: 207ab4d88522d437bb059b6be6ebf5204ff21275
1 parent 90dcea5 commit 3f447d7

File tree

9 files changed

+280
-5
lines changed

9 files changed

+280
-5
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,8 +707,7 @@ void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {
707707
const int dim1_whcn = sizes_.size() - 1 - dim1;
708708
if (packed_dim_ == dim0_whcn) {
709709
packed_dim_ = dim1_whcn;
710-
}
711-
if (packed_dim_ == dim1_whcn) {
710+
} else if (packed_dim_ == dim1_whcn) {
712711
packed_dim_ = dim0_whcn;
713712
}
714713

@@ -719,6 +718,12 @@ void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) {
719718
VK_CHECK_COND(dim0_whcn < 3 && dim1_whcn < 3);
720719
std::iter_swap(
721720
axis_map_.begin() + dim0_whcn, axis_map_.begin() + dim1_whcn);
721+
// Update the "identity" of the concatted dimension
722+
if (axis_map_.at(3) == dim0_whcn) {
723+
axis_map_.at(3) = dim1_whcn;
724+
} else if (axis_map_.at(3) == dim1_whcn) {
725+
axis_map_.at(3) = dim0_whcn;
726+
}
722727
}
723728
update_metadata();
724729
}

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,32 @@ std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
198198
VK_THROW("Could not get sizes of value with type ", val.type());
199199
}
200200

201+
int64_t ComputeGraph::dim_of(const ValueRef idx) const {
202+
const Value& val = values_.at(idx);
203+
if (val.isTensor()) {
204+
return val.toConstTensor().dim();
205+
} else if (val.isTensorRef()) {
206+
return val.toConstTensorRef().sizes.size();
207+
}
208+
VK_THROW("Could not get dim of value with type ", val.type());
209+
}
210+
211+
std::vector<int64_t> ComputeGraph::dim_order_of(const ValueRef idx) const {
212+
const Value& val = values_.at(idx);
213+
if (val.isTensor()) {
214+
return val.toConstTensor().dim_order();
215+
}
216+
VK_THROW("Could not get dim order of value with type ", val.type());
217+
}
218+
219+
std::vector<int64_t> ComputeGraph::strides_of(const ValueRef idx) const {
220+
const Value& val = values_.at(idx);
221+
if (val.isTensor()) {
222+
return val.toConstTensor().strides();
223+
}
224+
VK_THROW("Could not get strides of value with type ", val.type());
225+
}
226+
201227
vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
202228
const Value& val = values_.at(idx);
203229
if (val.isTensor()) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,12 @@ class ComputeGraph final {
282282
VK_THROW("Could not get sizes of value with type ", val.type());
283283
}
284284

285+
// Returns the number of dimensions of the tensor or tensor ref stored at
// `idx`; throws for any other value type.
int64_t dim_of(const ValueRef idx) const;
286+
287+
// Returns the dim order of the tensor stored at `idx`; throws if the value is
// not a tensor.
std::vector<int64_t> dim_order_of(const ValueRef idx) const;
288+
289+
// Returns the strides of the tensor stored at `idx`; throws if the value is
// not a tensor.
std::vector<int64_t> strides_of(const ValueRef idx) const;
290+
285291
vkapi::ScalarType dtype_of(const ValueRef idx) const;
286292

287293
inline const utils::ivec3& logical_limits_of(const ValueRef idx) const {
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.h>
14+
15+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
16+
17+
#include <algorithm>
18+
19+
namespace vkcompute {
20+
21+
// Resize hook for the transpose view node.
//
// `extra_args` is laid out as {out, in, dim0, dim1}. The output sizes are
// recomputed from the input's current sizes by swapping the entries at dim0
// and dim1, and the output tensor is then virtually resized to match.
// `args` is unused.
void resize_transpose_view_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;

  vTensorPtr out = graph->get_tensor(extra_args[0]);
  vTensorPtr in = graph->get_tensor(extra_args[1]);

  const int64_t first_dim = graph->extract_scalar<int64_t>(extra_args[2]);
  const int64_t second_dim = graph->extract_scalar<int64_t>(extra_args[3]);

  // Start from a copy of the (possibly resized) input sizes, then exchange
  // the two transposed dimensions in place.
  std::vector<int64_t> transposed_sizes = in->sizes();
  std::iter_swap(
      transposed_sizes.begin() + first_dim,
      transposed_sizes.begin() + second_dim);

  out->virtual_resize(transposed_sizes);
}
37+
38+
// Validates the arguments of a transpose view before the node is added.
//
// Checks that `out_ref` is a view of `in_ref`, and that both `dim0` and
// `dim1` lie in [0, in_ndim), where in_ndim is the dim of `in_ref` as reported
// by the graph. Negative dims are therefore rejected rather than wrapped.
void check_transpose_view_args(
39+
ComputeGraph& graph,
40+
ValueRef in_ref,
41+
const int64_t dim0,
42+
const int64_t dim1,
43+
ValueRef out_ref) {
44+
VK_CHECK_COND(
45+
graph.val_is_view_of(out_ref, in_ref),
46+
"output tensor must be a view of the input tensor");
47+
48+
// Both dims are validated against the rank of the input value.
const int64_t in_ndim = graph.dim_of(in_ref);
49+
VK_CHECK_COND(
50+
dim0 >= 0 && dim0 < in_ndim, "dim0 is not in the range of [0, in_ndim)");
51+
VK_CHECK_COND(
52+
dim1 >= 0 && dim1 < in_ndim, "dim1 is not in the range of [0, in_ndim)");
53+
}
54+
55+
// Adds a transpose view to the graph.
//
// The scalar values behind `dim0_ref` and `dim1_ref` are read, the arguments
// are validated (the output must be a view of the input, and both dims must be
// in range), and the output tensor is virtually transposed along those dims.
// Finally an ExecuteNode carrying only the resize hook is appended, with
// {out, in, dim0, dim1} as its arguments.
void add_transpose_view_node(
    ComputeGraph& graph,
    ValueRef input_ref,
    ValueRef dim0_ref,
    ValueRef dim1_ref,
    ValueRef out_ref) {
  const int64_t first_dim = graph.extract_scalar<int64_t>(dim0_ref);
  const int64_t second_dim = graph.extract_scalar<int64_t>(dim1_ref);

  // Validate first so that an invalid request never touches the output tensor.
  check_transpose_view_args(graph, input_ref, first_dim, second_dim, out_ref);

  graph.get_tensor(out_ref)->virtual_transpose(first_dim, second_dim);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      resize_transpose_view_node, {out_ref, input_ref, dim0_ref, dim1_ref}));
}
70+
71+
// Operator entry point for aten.transpose.int.
//
// Unpacks `args` as {input, dim0, dim1, out} and forwards them to
// add_transpose_view_node.
void transpose(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef input = args[0];
  const ValueRef dim0 = args[1];
  const ValueRef dim1 = args[2];
  const ValueRef out = args[3];

  add_transpose_view_node(graph, input, dim0, dim1, out);
}
80+
81+
// Registers `transpose` as the implementation of aten.transpose.int.
REGISTER_OPERATORS {
82+
VK_REGISTER_OP(aten.transpose.int, transpose);
83+
}
84+
85+
} // namespace vkcompute
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
14+
15+
#include <vector>
16+
17+
namespace vkcompute {
18+
19+
// Adds a transpose view to `graph`: `out_ref` becomes `input_ref` with the
// dimensions given by the scalar values `dim0_ref` and `dim1_ref` exchanged.
// `out_ref` is expected to already be a view of `input_ref`.
void add_transpose_view_node(
20+
ComputeGraph& graph,
21+
ValueRef input_ref,
22+
ValueRef dim0_ref,
23+
ValueRef dim1_ref,
24+
ValueRef out_ref);
25+
26+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,31 @@ def get_slice_inputs():
549549
return test_suite
550550

551551

552+
@register_test_suite(["aten.transpose.int"])
def get_transpose_inputs():
    """Build the test suite for ``aten.transpose.int`` run as a view op.

    Each case is a ``(self, dim0, dim1)`` tuple where ``self`` is the input
    shape. The suite runs float tensors with both buffer and 3D-texture
    storage, in width-packed and channels-packed layouts, using sequential
    input data, and is flagged as a view op.
    """
    Test = namedtuple("VkTransposeViewTest", ["self", "dim0", "dim1"])
    Test.__new__.__defaults__ = (None, 0, 1)

    # Input shape -> the (dim0, dim1) pairs to transpose for that shape.
    dim_pairs_by_shape = [
        ([M1, M2], [(0, 1)]),
        ([M1, S2, M], [(0, 1), (0, 2), (2, 1)]),
        ([S, M, S2, M2], [(3, 2), (1, 2), (3, 1)]),
    ]

    # A fresh list is built per case so that no two cases share a shape object.
    input_cases = [
        tuple(Test(self=list(shape), dim0=first, dim1=second))
        for shape, dim_pairs in dim_pairs_by_shape
        for first, second in dim_pairs
    ]

    test_suite = VkTestSuite(input_cases)
    test_suite.is_view_op = True
    test_suite.data_gen = "make_seq_tensor"
    test_suite.dtypes = ["at::kFloat"]
    test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    test_suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"]
    return test_suite
575+
576+
552577
@register_test_suite("aten.index_select.default")
553578
def get_index_select_inputs():
554579
Test = namedtuple("VkIndexSelectTest", ["self", "dim", "index"])

backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def create_value_for( # noqa: C901
272272
return ret_str
273273

274274
prepack = self.prepack_ref(ref)
275+
ref_is_view = self.suite_def.is_view_op and ref.is_out
275276

276277
cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
277278
if not include_declarations:
@@ -362,7 +363,15 @@ def create_value_for( # noqa: C901
362363
ret_str = f"IOValueRef {ref.name};\n"
363364
ret_str += f"{ref.name}.value = {self.graph}{self.dot}"
364365

365-
if ref.src_cpp_type == AT_TENSOR and not prepack:
366+
if ref.src_cpp_type == AT_TENSOR and ref_is_view:
367+
input_name = None
368+
for _name, ref in self.refs.items():
369+
if ref.is_in and ref.src_cpp_type == AT_TENSOR:
370+
input_name = ref.name
371+
372+
assert input_name is not None
373+
ret_str += f"add_tensor_view({input_name}.value);"
374+
elif ref.src_cpp_type == AT_TENSOR and not prepack:
366375
ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
367376
ret_str += f"{ref.src_cpp_name}.sizes().vec(), "
368377
ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n"

backends/vulkan/test/op_tests/utils/test_suite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(self, input_cases: List[Any]):
2828
self.atol: str = "1e-5"
2929
self.rtol: str = "1e-5"
3030

31+
self.is_view_op: bool = False
32+
3133
def supports_prepack(self):
3234
return len(self.prepacked_args) > 0
3335

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,8 @@ TEST_F(VulkanComputeAPITest, virtual_transpose_test) {
264264
// (dim0, dim1), new_sizes, new_dim_order, new_axis_map, new_packed_dim_idx
265265
std::vector<std::vector<std::vector<int64_t>>> test_cases = {
266266
{{2, 3}, {7, 9, 13, 11}, {0, 1, 3, 2}, {1, 0, 2, 2}, {1}},
267-
{{2, 1}, {7, 11, 9, 13}, {0, 2, 1, 3}, {0, 2, 1, 2}, {0}},
268-
{{1, 3}, {7, 13, 11, 9}, {0, 3, 2, 1}, {2, 1, 0, 2}, {2}},
267+
{{2, 1}, {7, 11, 9, 13}, {0, 2, 1, 3}, {0, 2, 1, 1}, {0}},
268+
{{1, 3}, {7, 13, 11, 9}, {0, 3, 2, 1}, {2, 1, 0, 0}, {2}},
269269
};
270270

271271
for (const auto& test_case : test_cases) {
@@ -3039,3 +3039,94 @@ TEST(VulkanComputeGraphOpsTest, int4pack_mm_test) {
30393039
test_int4pack_mm({37, 256, 19}, 64, storage_type);
30403040
}
30413041
}
3042+
3043+
void test_transpose_view_mm(
3044+
const int B,
3045+
const int M,
3046+
const int K,
3047+
const int N,
3048+
utils::StorageType storage_type) {
3049+
GraphConfig config;
3050+
config.set_storage_type_override(storage_type);
3051+
ComputeGraph graph(config);
3052+
3053+
std::vector<int64_t> mat1_size = {M, K};
3054+
std::vector<int64_t> mat2_t_size = {N, K};
3055+
std::vector<int64_t> out_size = {M, N};
3056+
3057+
std::vector<int64_t> mat1_small_size = {M - 4, K - 3};
3058+
std::vector<int64_t> mat2_t_small_size = {N - 1, K - 3};
3059+
3060+
if (B > 1) {
3061+
mat1_size.resize(3);
3062+
mat1_size = {B, M, K};
3063+
mat2_t_size.resize(3);
3064+
mat2_t_size = {B, N, K};
3065+
out_size.resize(3);
3066+
out_size = {B, M, N};
3067+
3068+
mat1_small_size.resize(3);
3069+
mat1_small_size = {B, M - 4, K - 3};
3070+
mat2_t_small_size.resize(3);
3071+
mat2_t_small_size = {B, N - 1, K - 3};
3072+
}
3073+
3074+
// Build graph
3075+
3076+
IOValueRef mat1 =
3077+
graph.add_input_tensor(mat1_size, vkapi::kFloat, utils::kWidthPacked);
3078+
IOValueRef mat2_transpose =
3079+
graph.add_input_tensor(mat2_t_size, vkapi::kFloat, utils::kWidthPacked);
3080+
3081+
ValueRef mat2 = graph.add_tensor_view(mat2_transpose.value);
3082+
3083+
ValueRef dim0;
3084+
ValueRef dim1;
3085+
3086+
if (B > 1) {
3087+
dim0 = graph.add_scalar<int64_t>(1);
3088+
dim1 = graph.add_scalar<int64_t>(2);
3089+
} else {
3090+
dim0 = graph.add_scalar<int64_t>(0);
3091+
dim1 = graph.add_scalar<int64_t>(1);
3092+
}
3093+
3094+
IOValueRef out;
3095+
out.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kWidthPacked);
3096+
3097+
VK_GET_OP_FN("aten.transpose.int")
3098+
(graph, {mat2_transpose.value, dim0, dim1, mat2});
3099+
VK_GET_OP_FN("aten.mm.default")(graph, {mat1.value, mat2, out.value});
3100+
3101+
out.staging = graph.set_output_tensor(out.value);
3102+
3103+
graph.prepare();
3104+
graph.encode_prepack();
3105+
graph.prepack();
3106+
graph.encode_execute();
3107+
3108+
for (int i = 1; i < 4; i++) {
3109+
float val_mat1 = i;
3110+
float val_mat2 = i + 1;
3111+
float val_out = K * (val_mat1 * val_mat2);
3112+
3113+
// Try at full size
3114+
graph.resize_input(0, mat1_size);
3115+
graph.resize_input(1, mat2_t_size);
3116+
graph.propagate_resize();
3117+
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
3118+
3119+
// Try at reduced sizes
3120+
val_out = (K - 3) * (val_mat1 * val_mat2);
3121+
graph.resize_input(0, mat1_small_size);
3122+
graph.resize_input(1, mat2_t_small_size);
3123+
graph.propagate_resize();
3124+
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
3125+
}
3126+
}
3127+
3128+
// Runs the transpose-view + mm check for a batched case (B=2, M=7, K=17, N=5)
// with both buffer and 3D-texture storage.
TEST(VulkanComputeGraphOpsTest, test_transpose_with_mm) {
  const int batch = 2;
  const int rows = 7;
  const int inner = 17;
  const int cols = 5;

  for (const utils::StorageType storage_type :
       {utils::kBuffer, utils::kTexture3D}) {
    test_transpose_view_mm(batch, rows, inner, cols, storage_type);
  }
}

0 commit comments

Comments
 (0)