@@ -6,142 +6,22 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <cstring>
 
-#include <c10/util/irange.h>
-#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/stack_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
 namespace native {
-namespace impl {
-
-using Tensor = executorch::aten::Tensor;
-
-Tensor& stack_out(
-    KernelRuntimeContext& ctx,
-    executorch::aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  (void)ctx;
-
-  if (dim < 0) {
-    dim += out.dim();
-  }
-
-  ET_KERNEL_CHECK(
-      ctx, check_stack_args(tensors, dim, out), InvalidArgument, out);
-
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    ET_KERNEL_CHECK(
-        ctx,
-        tensors_have_same_dim_order(tensors[i], out),
-        InvalidArgument,
-        out);
-  }
-
-  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
-
-  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
-  size_t expected_out_dim = 0;
-  get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  const size_t outer = getLeadingDims(out, dim);
-  const size_t inner = getTrailingDims(out, dim);
-  const size_t ninputs = tensors.size();
-
-  const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "stack.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "stack.out", CTYPE_IN, [&] {
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
-          out_ptr += inner;
-        });
-      }
-    }
-  });
-
-  return out;
-}
-
-} // namespace impl
-
-Tensor& stack_out(
-    KernelRuntimeContext& ctx,
-    executorch::aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  return impl::stack_out(ctx, tensors, dim, out);
-}
-
-namespace utils {
 
 Tensor& stack_out(
     KernelRuntimeContext& ctx,
     executorch::aten::ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
-  return impl::stack_out(ctx, tensors, dim, out);
-}
-
-std::tuple<
-    Error,
-    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
-    size_t>
-stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim) {
-  std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
-  size_t out_dim = 0;
-
-  // Check if tensors array is empty
-  if (tensors.size() == 0) {
-    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-  }
-
-  // Normalize negative dimension
-  int64_t normalized_dim = dim;
-  if (normalized_dim < 0) {
-    normalized_dim += tensors[0].dim() + 1;
-  }
-
-  // Check if dimension is valid
-  if (normalized_dim < 0 || normalized_dim > tensors[0].dim()) {
-    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-  }
-
-  // Check that all tensors have the same shape
-  for (size_t i = 1; i < tensors.size(); ++i) {
-    if (tensors[i].dim() != tensors[0].dim()) {
-      return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-    }
-    for (const auto d : c10::irange(tensors[0].dim())) {
-      if (tensors[i].size(d) != tensors[0].size(d)) {
-        return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
-      }
-    }
-  }
-
-  // Compute output shape using the existing utility
-  ::torch::executor::get_stack_out_target_size(
-      tensors, normalized_dim, out_sizes.data(), &out_dim);
-
-  return std::make_tuple(Error::Ok, out_sizes, out_dim);
+  return utils::stack_out_impl(ctx, tensors, dim, out);
 }
 
-} // namespace utils
 } // namespace native
 } // namespace executor
 } // namespace torch
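For reference, the body deleted above is the generic stack copy loop: after the output is resized to the stacked shape, each of the `outer` leading slices receives one `inner`-sized contiguous block from every input tensor in order. The sketch below reproduces that access pattern on plain buffers, without ExecuTorch's tensor types, dtype-dispatch macros, or ET_KERNEL_CHECK validation; the names `outer`, `inner`, and `ninputs` mirror the removed code, while everything else (std::vector buffers, the float element type, the stack_copy name) is illustrative and assumed for this example only.

#include <cstddef>
#include <vector>

// Minimal sketch of the stack copy pattern: every input holds outer * inner
// elements, and the output is laid out as [outer, ninputs, inner].
std::vector<float> stack_copy(
    const std::vector<std::vector<float>>& inputs,
    std::size_t outer,
    std::size_t inner) {
  const std::size_t ninputs = inputs.size();
  std::vector<float> out(outer * ninputs * inner);
  float* out_ptr = out.data();
  for (std::size_t i = 0; i < outer; ++i) {
    for (std::size_t j = 0; j < ninputs; ++j) {
      // Copy the i-th inner-sized block of input j into the output.
      const float* in_ptr = inputs[j].data() + i * inner;
      for (std::size_t k = 0; k < inner; ++k) {
        out_ptr[k] = in_ptr[k];
      }
      out_ptr += inner;
    }
  }
  return out;
}

For example, stacking two 2x3 inputs along dim = 1 gives outer = 2 and inner = 3, so the output holds 2 * 2 * 3 elements and can be viewed as a 2x2x3 tensor; stacking the same inputs along dim = 0 gives outer = 1 and inner = 6.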