Commit b56e8b7

aten mode clone dim order op (#14382)
Differential Revision: D82558256
Co-authored-by: Gasoonjia <[email protected]>
1 parent 93274bb commit b56e8b7

File tree

kernels/aten/cpu/op__clone_dim_order.cpp
kernels/aten/cpu/targets.bzl
kernels/aten/edge_dialect_aten_op.yaml
kernels/test/op__clone_dim_order_test.cpp
kernels/test/targets.bzl
shim_et/xplat/executorch/kernels/test/util.bzl

6 files changed: +144 −6 lines changed

kernels/aten/cpu/op__clone_dim_order.cpp

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;
using SizesArrayRef = executorch::aten::ArrayRef<executorch::aten::SizesType>;
using DimOrderArrayRef =
    executorch::aten::ArrayRef<executorch::aten::DimOrderType>;
using MemoryFormat = executorch::aten::MemoryFormat;

template <typename T>
using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;

template <typename T>
using Optional = std::optional<T>;

namespace {
Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
  if (!dim_order.has_value()) {
    return executorch::aten::nullopt;
  }
  if (is_contiguous_dim_order(
          dim_order.value().data(), dim_order.value().size())) {
    return MemoryFormat::Contiguous;
  } else if (is_channels_last_dim_order(
                 dim_order.value().data(), dim_order.value().size())) {
    return MemoryFormat::ChannelsLast;
  } else {
    ET_ASSERT_UNREACHABLE();
  }
}
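
// Illustrative note (not part of the original diff): for a 4-D NCHW tensor,
// the contiguous memory format corresponds to dim order {0, 1, 2, 3} and
// channels_last to {0, 2, 3, 1}. get_memory_format() above accepts exactly
// those two permutations and hits ET_ASSERT_UNREACHABLE() for any other
// dim order.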

bool check__clone_dim_order_args(
    const Tensor& input,
    bool non_blocking,
    executorch::aten::OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  // Right now we only support blocking data transfer.
  ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);

  // Ensure input and output dtypes match.
  ET_LOG_AND_RETURN_IF_FALSE(input.scalar_type() == out.scalar_type());

  // If dim_order is set, the target memory format must be either contiguous
  // or channels_last.
  if (dim_order.has_value()) {
    executorch::aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();

    // The dim_order length must equal the input's number of dimensions.
    ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());

    ET_LOG_AND_RETURN_IF_FALSE(
        is_channels_last_dim_order(
            dim_order.value().data(), dim_order.value().size()) ||
        is_contiguous_dim_order(
            dim_order.value().data(), dim_order.value().size()));

    // The out ATen tensor must have the strides implied by dim_order.
    const size_t kMaxNumOfDimensions = 16;
    ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
    executorch::aten::StridesType target_strides[kMaxNumOfDimensions];
    dim_order_to_stride_nocheck(
        out.sizes().data(),
        dim_order_ref.data(),
        dim_order_ref.size(),
        target_strides);
    ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
    for (size_t i = 0; i < dim_order_ref.size(); i++) {
      ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
    }

  } else { // dim_order is not set; preserve the dim order of the input.

    auto out_strides = out.strides();
    auto input_strides = input.strides();
    ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
    for (size_t i = 0; i < input_strides.size(); i++) {
      ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
    }
  }
  return true;
}
} // namespace
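
// Worked example (not part of the original diff) of the stride check above:
// with out.sizes() == {2, 3, 4, 5} (NCHW) and dim_order == {0, 2, 3, 1}
// (channels_last), dim_order_to_stride_nocheck() yields
// target_strides == {60, 1, 15, 3}: the innermost dimension in memory is C
// (stride 1), then W (stride 3), H (stride 15), and N (stride 60), so the
// check passes only if out was allocated with exactly those strides.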

// _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
Tensor& _clone_dim_order_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  // TODO(T181345875): enable sanity check in aten mode
  ET_KERNEL_CHECK(
      ctx,
      check__clone_dim_order_args(self, non_blocking, dim_order, out),
      InvalidArgument,
      out);

  Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
  at::clone_outf(self, memory_format, out);

  return out;
}

Tensor& _clone_dim_order_out(
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  KernelRuntimeContext ctx{};
  return _clone_dim_order_out(ctx, self, non_blocking, dim_order, out);
}

} // namespace native
} // namespace executor
} // namespace torch
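
A minimal usage sketch of the new out-variant kernel (not part of this commit; it assumes ATen mode, where executorch::aten::Tensor aliases at::Tensor and the kernel above is linked in — the shapes and the forward declaration are illustrative):

#include <array>

#include <ATen/ATen.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
// Forward declaration of the ctx-less overload defined in the new file.
executorch::aten::Tensor& _clone_dim_order_out(
    const executorch::aten::Tensor& self,
    bool non_blocking,
    executorch::aten::OptionalArrayRef<int64_t> dim_order,
    executorch::aten::Tensor& out);
} // namespace native
} // namespace executor
} // namespace torch

int main() {
  // Contiguous NCHW input.
  at::Tensor self = at::rand({2, 3, 4, 5});
  // The kernel validates and copies; out must already carry the target
  // (channels_last) strides, i.e. {60, 1, 15, 3} for this shape.
  at::Tensor out = at::empty(
      {2, 3, 4, 5}, self.options(), at::MemoryFormat::ChannelsLast);
  // Channels-last dim order for a 4-D tensor.
  const std::array<int64_t, 4> dim_order = {0, 2, 3, 1};
  torch::executor::native::_clone_dim_order_out(
      self,
      /*non_blocking=*/false,
      executorch::aten::OptionalArrayRef<int64_t>(
          executorch::aten::ArrayRef<int64_t>(
              dim_order.data(), dim_order.size())),
      out);
  return 0;
}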

kernels/aten/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,12 @@ _EDGE_DIALECT_OPS = (
             "//executorch/kernels/aten/cpu/util:copy_ops_util",
         ],
     ),
+    op_target(
+        name = "op__clone_dim_order",
+        deps = [
+            "//executorch/kernels/aten/cpu/util:copy_ops_util",
+        ],
+    ),
 )

 def define_common_targets():

kernels/aten/edge_dialect_aten_op.yaml

Lines changed: 5 additions & 0 deletions
@@ -11,3 +11,8 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::_to_dim_order_copy_out
+
+- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_clone_dim_order_out
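
For reference (a restatement using only names that appear in this diff), the YAML entry binds each piece of the edge-dialect schema to the C++ kernel signature:

// Schema: dim_order_ops::_clone_dim_order.out(
//     Tensor self, *, bool non_blocking=False,
//     int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
//
// Kernel: torch::executor::_clone_dim_order_out
Tensor& _clone_dim_order_out(
    KernelRuntimeContext& ctx, // supplied by the runtime, not in the schema
    const Tensor& self, // Tensor self
    bool non_blocking, // bool non_blocking = False
    OptionalArrayRef<int64_t> dim_order, // int[]? dim_order = None
    Tensor& out); // Tensor(a!) out, also the return value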

kernels/test/op__clone_dim_order_test.cpp

Lines changed: 0 additions & 3 deletions
@@ -7,9 +7,6 @@
  */

 #include <cstdint>
-#include <map>
-#include <typeindex>
-#include <variant>

 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator.
 #include <executorch/kernels/test/TestUtil.h>

kernels/test/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def define_common_targets():

     _common_op_test("op__to_dim_order_copy_test", ["aten", "portable"])
     _common_op_test("op__empty_dim_order_test", ["aten", "portable"])
-    _common_op_test("op__clone_dim_order_test", ["portable"])
+    _common_op_test("op__clone_dim_order_test", ["aten", "portable"])
     _common_op_test("op_abs_test", ["aten", "portable"])
     _common_op_test("op_acos_test", ["aten", "portable"])
     _common_op_test("op_acosh_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/test/util.bzl

Lines changed: 4 additions & 2 deletions
@@ -21,11 +21,13 @@ def op_test(name, deps = [], kernel_name = "portable", use_kernel_prefix = False
     if kernel_name == "aten":
         generated_lib_and_op_deps = [
             "//executorch/kernels/aten:generated_lib",
-            #TODO(T187390274): consolidate all aten ops into one target
-            "//executorch/kernels/aten/cpu:op__to_dim_order_copy_aten",
             "//executorch/kernels/aten:generated_lib_headers",
             "//executorch/kernels/test:supported_features_aten",
         ]
+
+        if "dim_order" in op_root:
+            generated_lib_and_op_deps.append("//executorch/kernels/aten/cpu:" + op_root + "_aten")
+
     else:
         generated_lib_and_op_deps = [
             "//executorch/kernels/{}/cpu:{}".format(kernel_name, op_root),
