  * LICENSE file in the root directory of this source tree.
  */

-#include <cstring>
-
-#include <c10/util/irange.h>
-#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/stack_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
 namespace executor {
 namespace native {
-namespace impl {
-
-using Tensor = executorch::aten::Tensor;

 Tensor& stack_out(
     KernelRuntimeContext& ctx,
     executorch::aten::ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
-  (void)ctx;
-
-  if (dim < 0) {
-    dim += out.dim();
-  }
-
-  ET_KERNEL_CHECK(
-      ctx, check_stack_args(tensors, dim, out), InvalidArgument, out);
-
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    ET_KERNEL_CHECK(
-        ctx,
-        tensors_have_same_dim_order(tensors[i], out),
-        InvalidArgument,
-        out);
-  }
-
-  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
-
-  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
-  size_t expected_out_dim = 0;
-  get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  const size_t outer = getLeadingDims(out, dim);
-  const size_t inner = getTrailingDims(out, dim);
-  const size_t ninputs = tensors.size();
-
-  const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "stack.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "stack.out", CTYPE_IN, [&] {
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
-          out_ptr += inner;
-        });
-      }
-    }
-  });
-
-  return out;
-}
-
-} // namespace impl
-
-Tensor& stack_out(
-    KernelRuntimeContext& ctx,
-    executorch::aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  return impl::stack_out(ctx, tensors, dim, out);
-}
-
-namespace utils {
-
-Tensor& stack_out(
-    KernelRuntimeContext& ctx,
-    executorch::aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  return impl::stack_out(ctx, tensors, dim, out);
-}
-
-std::tuple<
-    Error,
-    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
-    size_t>
-stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim) {
-  std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
-  size_t out_dim = 0;
-
-  // Check if tensors array is empty
-  if (tensors.size() == 0) {
-    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-  }
-
-  // Normalize negative dimension
-  int64_t normalized_dim = dim;
-  if (normalized_dim < 0) {
-    normalized_dim += tensors[0].dim() + 1;
-  }
-
-  // Check if dimension is valid
-  if (normalized_dim < 0 || normalized_dim > tensors[0].dim()) {
-    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-  }
-
-  // Check that all tensors have the same shape
-  for (size_t i = 1; i < tensors.size(); ++i) {
-    if (tensors[i].dim() != tensors[0].dim()) {
-      return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-    }
-    for (const auto d : c10::irange(tensors[0].dim())) {
-      if (tensors[i].size(d) != tensors[0].size(d)) {
-        return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-      }
-    }
-  }
-
-  // Compute output shape using the existing utility
-  ::torch::executor::get_stack_out_target_size(
-      tensors, normalized_dim, out_sizes.data(), &out_dim);
-
-  return std::make_tuple(Error::Ok, out_sizes, out_dim);
+  return utils::stack_out_impl(ctx, tensors, dim, out);
 }

-} // namespace utils
 } // namespace native
 } // namespace executor
 } // namespace torch
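
The new header is not shown in this diff; the sketch below is inferred from the call site `utils::stack_out_impl(ctx, tensors, dim, out)` and from the `stack_out_shape` helper removed above. Treat the namespaces and exact signatures as assumptions about what `executorch/kernels/portable/cpu/util/stack_util.h` plausibly declares, not as its actual contents.

// Hypothetical sketch of stack_util.h, inferred from this diff's call site
// and the removed code; the real header may differ.
#pragma once

#include <array>
#include <tuple>

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace utils {

// Shared stack.out implementation: validates the inputs, resizes `out`,
// and interleaves slices of each input along `dim`.
Tensor& stack_out_impl(
    KernelRuntimeContext& ctx,
    executorch::aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

// Computes the expected output shape for stack without touching `out`.
// Returns InvalidArgument if `tensors` is empty, `dim` is out of range,
// or the inputs do not all share the same shape.
std::tuple<
    Error,
    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
    size_t>
stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim);

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch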
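For context on what the consolidated implementation has to do, the removed kernel copied data in outer * ninputs * inner contiguous chunks, where outer = getLeadingDims(out, dim) and inner = getTrailingDims(out, dim). The standalone program below reproduces that copy pattern on plain vectors purely as an illustration (it does not use ExecuTorch types): stacking two 2x3 inputs along dim 1 writes one 3-element slice from each input per outer index.

// Illustration of the stack copy pattern: for each of `outer` leading
// positions, append an `inner`-sized slice from every input in order.
#include <cstdio>
#include <cstddef>
#include <vector>

int main() {
  // Two 2x3 inputs, stacked along dim = 1 -> output shape 2x2x3.
  const std::vector<std::vector<float>> inputs = {
      {0, 1, 2, 3, 4, 5},
      {10, 11, 12, 13, 14, 15}};

  const size_t outer = 2; // product of output dims before the stack dim
  const size_t inner = 3; // product of output dims after the stack dim

  std::vector<float> out(outer * inputs.size() * inner);
  float* out_ptr = out.data();
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < inputs.size(); ++j) {
      const float* in_ptr = inputs[j].data() + i * inner;
      for (size_t k = 0; k < inner; ++k) {
        out_ptr[k] = in_ptr[k];
      }
      out_ptr += inner;
    }
  }

  // Prints: 0 1 2 10 11 12 3 4 5 13 14 15
  for (float v : out) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}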