/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;

template <typename T>
using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;

/**
 * _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
 * dim_order=None, Tensor(a!) out) -> Tensor(a!)
 *
 * Clones via element-wise copy while preserving dim_order.
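 *
 * dim_order specifies the dim order of the output as a permutation of the
 * dimension indices (e.g. (0, 2, 3, 1) is channels-last for a 4-D tensor);
 * when dim_order is None, the input's dim order is preserved.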
 */
Tensor& _clone_dim_order_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  (void)ctx;

  // Ensure input and output dtypes match.
  ET_KERNEL_CHECK(
      ctx, self.scalar_type() == out.scalar_type(), InvalidArgument, out);

  // Ensure the output layout matches dim_order if given, or the input's
  // layout otherwise.
  ET_KERNEL_CHECK(
      ctx,
      check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
      InvalidArgument,
      out);

  // Ensure input and output shapes match, resizing the output if necessary.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
      InvalidArgument,
      out);

  // Nothing to copy for an empty tensor.
  if (self.numel() == 0) {
    return out;
  }

  // Dispatch on the input dtype and copy the data element-wise.
  ET_SWITCH_REALHBBF16_TYPES(
      self.scalar_type(),
      ctx,
      "dim_order_ops::_clone_dim_order.out",
      CTYPE,
      [&] { _to_dim_order_copy_impl<CTYPE, CTYPE>(self, out); });

  return out;
}

// Convenience overload without a KernelRuntimeContext; it constructs a
// default context and forwards to the variant above.
Tensor& _clone_dim_order_out(
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext context{};
  return _clone_dim_order_out(context, self, non_blocking, dim_order, out);
}

} // namespace native
} // namespace executor
} // namespace torch