Skip to content

Commit ebfdcb9

Browse files
Introduce out shape utils (add/stack) (#13199)
Reviewed By: georgehong Differential Revision: D79826201
1 parent bc5186c commit ebfdcb9

File tree

4 files changed

+117
-0
lines changed

4 files changed

+117
-0
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,36 @@ Tensor& add_scalar_out(
192192
return impl::add_scalar_out(ctx, a, b, alpha, out);
193193
}
194194

195+
std::tuple<
196+
Error,
197+
std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
198+
size_t>
199+
add_out_shape(const Tensor& a, const Tensor& b, ET_UNUSED const Scalar& alpha) {
200+
std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
201+
size_t out_dim = 0;
202+
203+
Error err = get_broadcast_target_size(
204+
a, b, out_sizes.data(), kTensorDimensionLimit, &out_dim);
205+
206+
return std::make_tuple(err, out_sizes, out_dim);
207+
}
208+
209+
std::tuple<
210+
Error,
211+
std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
212+
size_t>
213+
add_scalar_out_shape(
214+
const Tensor& a,
215+
ET_UNUSED const Scalar& b,
216+
ET_UNUSED const Scalar& alpha) {
217+
std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
218+
size_t out_dim = a.dim();
219+
220+
std::copy(a.sizes().begin(), a.sizes().end(), out_sizes.begin());
221+
222+
return std::make_tuple(Error::Ok, out_sizes, out_dim);
223+
}
224+
195225
} // namespace utils
196226
} // namespace native
197227
} // namespace executor

kernels/portable/cpu/op_add.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,36 @@ Tensor& add_scalar_out(
2929
const Scalar& alpha,
3030
Tensor& out);
3131

32+
/**
 * Computes the output shape for tensor addition with broadcasting.
 *
 * @param[in] a First input tensor
 * @param[in] b Second input tensor
 * @param[in] alpha Scalar multiplier for b (unused for shape computation)
 * @return Tuple containing the Error (non-Ok if a and b cannot be
 * broadcast together), the output shape array, and the number of
 * dimensions actually used in that array
 */
std::tuple<
    Error,
    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
    size_t>
add_out_shape(const Tensor& a, const Tensor& b, const Scalar& alpha);

/**
 * Computes the output shape for tensor-scalar addition.
 *
 * The output always has exactly the shape of the input tensor `a`, so the
 * returned Error is always Error::Ok.
 *
 * @param[in] a Input tensor
 * @param[in] b Scalar value (unused for shape computation)
 * @param[in] alpha Scalar multiplier for b (unused for shape computation)
 * @return Tuple containing the Error, output shape array, and number of
 * dimensions
 */
std::tuple<
    Error,
    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
    size_t>
add_scalar_out_shape(const Tensor& a, const Scalar& b, const Scalar& alpha);
61+
3262
} // namespace utils
3363
} // namespace native
3464
} // namespace executor

kernels/portable/cpu/op_stack.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,49 @@ Tensor& stack_out(
9797
return impl::stack_out(ctx, tensors, dim, out);
9898
}
9999

100+
std::tuple<
101+
Error,
102+
std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
103+
size_t>
104+
stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim) {
105+
std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
106+
size_t out_dim = 0;
107+
108+
// Check if tensors array is empty
109+
if (tensors.size() == 0) {
110+
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
111+
}
112+
113+
// Normalize negative dimension
114+
int64_t normalized_dim = dim;
115+
if (normalized_dim < 0) {
116+
normalized_dim += tensors[0].dim() + 1;
117+
}
118+
119+
// Check if dimension is valid
120+
if (normalized_dim < 0 || normalized_dim > tensors[0].dim()) {
121+
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
122+
}
123+
124+
// Check that all tensors have the same shape
125+
for (size_t i = 1; i < tensors.size(); ++i) {
126+
if (tensors[i].dim() != tensors[0].dim()) {
127+
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
128+
}
129+
for (size_t d = 0; d < tensors[0].dim(); ++d) {
130+
if (tensors[i].size(d) != tensors[0].size(d)) {
131+
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
132+
}
133+
}
134+
}
135+
136+
// Compute output shape using the existing utility
137+
::torch::executor::get_stack_out_target_size(
138+
tensors, normalized_dim, out_sizes.data(), &out_dim);
139+
140+
return std::make_tuple(Error::Ok, out_sizes, out_dim);
141+
}
142+
100143
} // namespace utils
101144
} // namespace native
102145
} // namespace executor

kernels/portable/cpu/op_stack.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ Tensor& stack_out(
2121
int64_t dim,
2222
Tensor& out);
2323

24+
/**
 * Computes the output shape for tensor stacking.
 *
 * All input tensors must share the same shape; stacking inserts one new
 * dimension at `dim`, which may be negative (counted from the end of the
 * output shape).
 *
 * @param[in] tensors Array of input tensors to stack (must be non-empty)
 * @param[in] dim Dimension along which to stack
 * @return Tuple containing the Error (InvalidArgument on an empty list,
 * out-of-range dim, or mismatched input shapes), the output shape array,
 * and the number of dimensions
 */
std::tuple<
    Error,
    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
    size_t>
stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim);
37+
2438
} // namespace utils
2539
} // namespace native
2640
} // namespace executor

0 commit comments

Comments
 (0)