
Commit 1f60bb7

Add clone_dim_order kernel (portable + ATen) and layout conversion tests
1 parent d1c87e4 commit 1f60bb7

8 files changed: +233 −0 lines changed

exir/tests/test_memory_format_ops_pass.py

Lines changed: 32 additions & 0 deletions
@@ -28,6 +28,8 @@
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
     PropagateToCopyChannalsLastModule,
+    SimpleCloneChannelsLastModule,
+    SimpleCloneContiguousModule,
     SimpleEmptyChannelLastModule,
     SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
@@ -91,6 +93,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
             ),
         )
 
+    def test_op_clone_replacement_contiguous(self) -> None:
+        model = SimpleCloneContiguousModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
+                ),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_clone_replacement_channels_last(self) -> None:
+        model = SimpleCloneChannelsLastModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
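For reference, the two new test modules exercised above are defined with the shared memory-format test utilities rather than in this diff. They are assumed to be thin wrappers around Tensor.clone with an explicit memory_format, roughly like this hypothetical sketch:

import torch

# Hypothetical sketches of the modules referenced by the new tests; the real
# definitions live with the shared test utilities and may differ in detail.
class SimpleCloneContiguousModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clone into contiguous (NCHW) layout regardless of the input layout.
        return x.clone(memory_format=torch.contiguous_format)


class SimpleCloneChannelsLastModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clone into channels-last (NHWC) layout regardless of the input layout.
        return x.clone(memory_format=torch.channels_last)

Each test then feeds a tensor in the opposite layout and checks that the exported clone ends up in the target memory format.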
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/aten/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using MemoryFormat = executorch::aten::MemoryFormat;
+
+template <typename T>
+using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;
+
+template <typename T>
+using Optional = std::optional<T>;
+
+/**
+ * _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
+ * dim_order=None, Tensor(a!) out) -> Tensor(a!)
+ *
+ * Clones with explicit dim_order, using the corresponding memory format.
+ */
+Tensor& _clone_dim_order_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  // Ensure output has the same layout as input or matches dim_order.
+  ET_KERNEL_CHECK(
+      ctx,
+      check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
+      InvalidArgument,
+      out);
+
+  Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
+  at::clone_outf(self, memory_format, out);
+
+  return out;
+}
+
+Tensor& _clone_dim_order_out(
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  KernelRuntimeContext ctx{};
+  return _clone_dim_order_out(ctx, self, non_blocking, dim_order, out);
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
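Since the ATen-backed kernel above just maps dim_order onto a memory format and defers to at::clone_outf, its observable behavior matches an eager-mode clone with an explicit memory_format. A minimal illustrative sketch of that behavior in Python (not part of this commit):

import torch

# Cloning into a different memory format changes the physical layout (strides)
# but never the logical values.
x = torch.randn(2, 3, 4, 5)                     # contiguous NCHW, strides (60, 20, 5, 1)
y = x.clone(memory_format=torch.channels_last)  # NHWC layout,     strides (60, 1, 15, 3)

assert torch.equal(x, y)
assert y.is_contiguous(memory_format=torch.channels_last)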

kernels/aten/cpu/util/copy_ops_util.cpp

Lines changed: 22 additions & 0 deletions
@@ -15,6 +15,28 @@ namespace torch {
 namespace executor {
 
 using Tensor = executorch::aten::Tensor;
+using MemoryFormat = executorch::aten::MemoryFormat;
+
+/**
+ * Determines the memory format (Contiguous or ChannelsLast) corresponding to
+ * the dim_order. Provides support for bridging torch.memory_format with
+ * ExecuTorch's dim_order.
+ */
+std::optional<MemoryFormat> get_memory_format(
+    executorch::aten::OptionalArrayRef<int64_t> dim_order) {
+  if (!dim_order.has_value()) {
+    return executorch::aten::nullopt;
+  }
+  if (is_contiguous_dim_order(
+          dim_order.value().data(), dim_order.value().size())) {
+    return MemoryFormat::Contiguous;
+  } else if (is_channels_last_dim_order(
+                 dim_order.value().data(), dim_order.value().size())) {
+    return MemoryFormat::ChannelsLast;
+  } else {
+    ET_ASSERT_UNREACHABLE();
+  }
+}
 
 bool check__to_dim_order_copy_args(
     const Tensor& input,
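get_memory_format only has to distinguish the two layouts ATen can express as a memory format; any other permutation is treated as unreachable. The dim_order values it inspects correspond to memory formats as in this illustrative Python snippet (Tensor.dim_order() is assumed to be available in the installed PyTorch build):

import torch

nchw = torch.empty(2, 3, 4, 5)                      # contiguous
nhwc = nchw.to(memory_format=torch.channels_last)   # channels-last

print(nchw.dim_order())  # (0, 1, 2, 3) -> MemoryFormat::Contiguous
print(nhwc.dim_order())  # (0, 2, 3, 1) -> MemoryFormat::ChannelsLast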

kernels/aten/cpu/util/copy_ops_util.h

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@
 namespace torch {
 namespace executor {
 
+std::optional<MemoryFormat> get_memory_format(
+    executorch::aten::OptionalArrayRef<int64_t> dim_order);
+
 bool check__to_dim_order_copy_args(
     const Tensor& input,
     bool non_blocking,

kernels/aten/edge_dialect_aten_op.yaml

Lines changed: 5 additions & 0 deletions
@@ -11,3 +11,8 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::_to_dim_order_copy_out
+
+- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_clone_dim_order_out
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+
+template <typename T>
+using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;
+
+/**
+ * _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
+ * dim_order=None, Tensor(a!) out) -> Tensor(a!)
+ *
+ * Clones via element-wise copy while preserving dim_order.
+ */
+Tensor& _clone_dim_order_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  (void)ctx;
+
+  // Ensure input and output dtype match.
+  ET_KERNEL_CHECK(
+      ctx, self.scalar_type() == out.scalar_type(), InvalidArgument, out);
+
+  // Ensure output has the same layout as input or matches dim_order.
+  ET_KERNEL_CHECK(
+      ctx,
+      check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
+      InvalidArgument,
+      out);
+
+  // Ensure input and output shapes match, resizing if necessary.
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
+      InvalidArgument,
+      out);
+
+  if (self.numel() == 0) {
+    return out;
+  }
+
+  // Select the correct input dtype and copy the tensors.
+  ET_SWITCH_REALHBBF16_TYPES(
+      self.scalar_type(),
+      ctx,
+      "dim_order_ops::_clone_dim_order.out",
+      CTYPE,
+      [&] { _to_dim_order_copy_impl<CTYPE, CTYPE>(self, out); });
+
+  return out;
+}
+
+Tensor& _clone_dim_order_out(
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext context{};
+  return _clone_dim_order_out(context, self, non_blocking, dim_order, out);
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 24 additions & 0 deletions
@@ -9,6 +9,7 @@
 #pragma once
 #include <c10/util/irange.h>
 
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -77,6 +78,29 @@ void as_strided_copy(
   }
 }
 
+/**
+ * Copies and casts a tensor while preserving input dim_order.
+ */
+template <typename SELF_CTYPE, typename OUT_CTYPE>
+void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
+  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
+  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
+
+  // Here we make a slightly off-label use of
+  // BroadcastIndexesRange. It always assumes it doesn't have to care
+  // about different dim_order between input and output, but we can
+  // just force it to respect strides (and thus dim_order) for its
+  // inputs using support_noncontiguous_input_tensors=true, and then pretend
+  // the output is just another input.
+  for (const auto [unused_index, self_data_index, out_data_index] :
+       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
+           /*dummy output*/ self, self, out)) {
+    (void)unused_index;
+    out_data[out_data_index] =
+        static_cast<OUT_CTYPE>(self_data[self_data_index]);
+  }
+}
+
 bool check_cat_args(
     executorch::aten::ArrayRef<Tensor> tensors,
     int64_t dim,
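The BroadcastIndexesRange trick in _to_dim_order_copy_impl boils down to visiting every logical index once and translating it into a strided offset for each tensor separately, so the input and output may use different dim_orders. A standalone Python sketch of that index arithmetic (illustrative only; the helper names here are made up and are not ExecuTorch APIs):

import itertools

def strides_for(sizes, dim_order):
    # Innermost dimension in `dim_order` gets stride 1; grow outward from there.
    strides = [0] * len(sizes)
    acc = 1
    for d in reversed(dim_order):
        strides[d] = acc
        acc *= sizes[d]
    return strides

def copy_preserving_dim_order(src, dst, sizes, src_dim_order, dst_dim_order):
    # Visit each logical index once; map it to a strided offset per tensor.
    src_strides = strides_for(sizes, src_dim_order)
    dst_strides = strides_for(sizes, dst_dim_order)
    for idx in itertools.product(*(range(s) for s in sizes)):
        src_off = sum(i * s for i, s in zip(idx, src_strides))
        dst_off = sum(i * s for i, s in zip(idx, dst_strides))
        dst[dst_off] = src[src_off]

# Copy a 2x3x4x5 contiguous (NCHW) buffer into a channels-last (NHWC) buffer.
sizes = (2, 3, 4, 5)
src = list(range(2 * 3 * 4 * 5))
dst = [0] * len(src)
copy_preserving_dim_order(src, dst, sizes, (0, 1, 2, 3), (0, 2, 3, 1))

After the copy, dst holds the same logical values as src, just laid out channels-last.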

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -1009,3 +1009,8 @@
   kernels:
     - arg_meta: null
      kernel_name: torch::executor::_to_dim_order_copy_out
+
+- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_clone_dim_order_out
