diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index 368b1b0d0ea..7dead2bf5a7 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -192,6 +192,38 @@ Tensor& add_scalar_out(
   return impl::add_scalar_out(ctx, a, b, alpha, out);
 }
 
+std::tuple<
+    Error,
+    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
+    size_t>
+add_out_shape(const Tensor& a, const Tensor& b, ET_UNUSED const Scalar& alpha) {
+  std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
+  size_t out_dim = 0;
+
+  // The output shape is the broadcast of a's and b's shapes.
+  Error err = get_broadcast_target_size(
+      a, b, out_sizes.data(), kTensorDimensionLimit, &out_dim);
+
+  return std::make_tuple(err, out_sizes, out_dim);
+}
+
+std::tuple<
+    Error,
+    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
+    size_t>
+add_scalar_out_shape(
+    const Tensor& a,
+    ET_UNUSED const Scalar& b,
+    ET_UNUSED const Scalar& alpha) {
+  std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
+  size_t out_dim = static_cast<size_t>(a.dim());
+
+  // A scalar operand does not broadcast, so the output shape is a's shape.
+  std::copy(a.sizes().begin(), a.sizes().end(), out_sizes.begin());
+
+  return std::make_tuple(Error::Ok, out_sizes, out_dim);
+}
+
 } // namespace utils
 } // namespace native
 } // namespace executor
diff --git a/kernels/portable/cpu/op_add.h b/kernels/portable/cpu/op_add.h
index 3544c7a2e6e..f19d7e98b12 100644
--- a/kernels/portable/cpu/op_add.h
+++ b/kernels/portable/cpu/op_add.h
@@ -29,6 +29,36 @@ Tensor& add_scalar_out(
     const Scalar& alpha,
     Tensor& out);
 
+/**
+ * Computes the output shape for tensor addition with broadcasting.
+ *
+ * @param[in] a First input tensor
+ * @param[in] b Second input tensor
+ * @param[in] alpha Scalar multiplier for b (unused for shape computation)
+ * @return Tuple containing the Error, output shape array, and number of
+ * dimensions
+ */
+std::tuple<
+    Error,
+    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
+    size_t>
+add_out_shape(const Tensor& a, const Tensor& b, const Scalar& alpha);
+
+/**
+ * Computes the output shape for tensor-scalar addition.
+ *
+ * @param[in] a Input tensor
+ * @param[in] b Scalar value (unused for shape computation)
+ * @param[in] alpha Scalar multiplier for b (unused for shape computation)
+ * @return Tuple containing the Error, output shape array, and number of
+ * dimensions
+ */
+std::tuple<
+    Error,
+    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
+    size_t>
+add_scalar_out_shape(const Tensor& a, const Scalar& b, const Scalar& alpha);
+
 } // namespace utils
 } // namespace native
 } // namespace executor
diff --git a/kernels/portable/cpu/op_stack.cpp b/kernels/portable/cpu/op_stack.cpp
index 436638b8680..87d419483c0 100644
--- a/kernels/portable/cpu/op_stack.cpp
+++ b/kernels/portable/cpu/op_stack.cpp
@@ -97,6 +97,59 @@ Tensor& stack_out(
   return impl::stack_out(ctx, tensors, dim, out);
 }
 
+std::tuple<
+    Error,
+    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
+    size_t>
+stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim) {
+  std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
+  size_t out_dim = 0;
+
+  // Check if tensors array is empty
+  if (tensors.size() == 0) {
+    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
+  }
+
+  const int64_t in_rank = tensors[0].dim();
+
+  // The stacked result has in_rank + 1 dimensions, and the helper below is
+  // handed out_sizes.data() with no capacity argument. Reject inputs whose
+  // result would not fit in out_sizes rather than risk writing past its end.
+  if (static_cast<size_t>(in_rank) >= out_sizes.size()) {
+    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
+  }
+
+  // Normalize negative dimension
+  int64_t normalized_dim = dim;
+  if (normalized_dim < 0) {
+    normalized_dim += in_rank + 1;
+  }
+
+  // Check if dimension is valid
+  if (normalized_dim < 0 || normalized_dim > in_rank) {
+    return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
+  }
+
+  // Check that all tensors have the same shape
+  for (size_t i = 1; i < tensors.size(); ++i) {
+    if (tensors[i].dim() != in_rank) {
+      return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
+    }
+    // Signed counter so the comparison against in_rank is same-signed.
+    for (int64_t d = 0; d < in_rank; ++d) {
+      if (tensors[i].size(d) != tensors[0].size(d)) {
+        return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
+      }
+    }
+  }
+
+  // Compute output shape using the existing utility
+  ::torch::executor::get_stack_out_target_size(
+      tensors, normalized_dim, out_sizes.data(), &out_dim);
+
+  return std::make_tuple(Error::Ok, out_sizes, out_dim);
+}
+
 } // namespace utils
 } // namespace native
 } // namespace executor
diff --git a/kernels/portable/cpu/op_stack.h b/kernels/portable/cpu/op_stack.h
index e1e09d2608a..6a507b7dcd5 100644
--- a/kernels/portable/cpu/op_stack.h
+++ b/kernels/portable/cpu/op_stack.h
@@ -21,6 +21,23 @@ Tensor& stack_out(
     int64_t dim,
     Tensor& out);
 
+/**
+ * Computes the output shape for tensor stacking.
+ *
+ * @param[in] tensors Array of input tensors to stack; must be non-empty and
+ * every tensor must have the same shape as the first
+ * @param[in] dim Dimension along which to stack; a negative value has
+ * (input rank + 1) added to it, and the result must lie in [0, input rank]
+ * @return Tuple containing the Error, output shape array, and number of
+ * dimensions. On rejected input the Error is InvalidArgument, the shape
+ * array is zero-filled, and the number of dimensions is 0.
+ */
+std::tuple<
+    Error,
+    std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
+    size_t>
+stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim);
+
 } // namespace utils
 } // namespace native
 } // namespace executor