diff --git a/kernels/optimized/cpu/binary_ops.cpp b/kernels/optimized/cpu/binary_ops.cpp new file mode 100644 index 00000000000..46d8e473051 --- /dev/null +++ b/kernels/optimized/cpu/binary_ops.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <executorch/kernels/optimized/cpu/binary_ops.h> + +namespace torch::executor::internal { +std::optional<BroadcastElementwisePlan> plan_broadcast_elementwise( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out, + const ElementwiseOptimizedPath selected_optimized_path) { + BroadcastElementwisePlan plan; + if ((selected_optimized_path == + ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { + plan.lhs = &b; + plan.rhs = &a; + } else { + // Catch failure to update logic when adding new broadcasting possibility.
+ ET_DCHECK( + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcast2dBy1d) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNd)); + plan.lhs = &a; + plan.rhs = &b; + } + auto error = resize_tensor(out, plan.lhs->sizes()); + ET_KERNEL_CHECK_MSG( + ctx, + error == Error::Ok, + InvalidArgument, + std::nullopt, + "Failed to resize output tensor."); + plan.outer_size = 1; + if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { + int32_t broadcast_dim = internal::get_broadcast_dim(*plan.lhs, *plan.rhs); + int32_t broadcast_dim_lhs = plan.lhs->dim() + broadcast_dim; + auto normalized_tensor_size_lhs = + get_normalized_tensor_size(*plan.lhs, broadcast_dim_lhs); + plan.outer_size = normalized_tensor_size_lhs[0]; + plan.broadcast_size = normalized_tensor_size_lhs[1]; + plan.inner_size = normalized_tensor_size_lhs[2]; + } else { + plan.broadcast_size = plan.lhs->sizes()[plan.lhs->dim() - 2]; + plan.inner_size = plan.lhs->sizes()[plan.lhs->dim() - 1]; + } + return plan; +} +} // namespace torch::executor::internal diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h index dd4425e4ce6..acd90e6f86f 100644 --- a/kernels/optimized/cpu/binary_ops.h +++ b/kernels/optimized/cpu/binary_ops.h @@ -13,6 +13,8 @@ #include <executorch/kernels/portable/cpu/scalar_utils.h> #include <executorch/runtime/kernel/kernel_includes.h> +#include <optional> + namespace torch { namespace executor { enum class ElementwiseOptimizedPath { @@ -206,6 +208,23 @@ Tensor& handle_last_dim_broadcast_elementwise( return out; } +namespace internal { +struct BroadcastElementwisePlan { + const Tensor* lhs; + const Tensor* rhs; + int64_t outer_size; + int64_t broadcast_size; + int64_t inner_size; +}; + +std::optional<BroadcastElementwisePlan> plan_broadcast_elementwise( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out, + const ElementwiseOptimizedPath selected_optimized_path); +} // namespace internal + template <typename CTYPE, typename Op> Tensor&
handle_broadcast_elementwise( KernelRuntimeContext& ctx, @@ -223,56 +242,19 @@ Tensor& handle_broadcast_elementwise( ctx, vec_fun, a, b, out, selected_optimized_path); } - const Tensor* lhs; - const Tensor* rhs; - if ((selected_optimized_path == - ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { - lhs = &b; - rhs = &a; - } else { - // Catch failure to update logic when adding new broadcasting possibility. - ET_DCHECK( - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcast2dBy1d) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastNdByNd)); - lhs = &a; - rhs = &b; - } - auto error = resize_tensor(out, lhs->sizes()); - ET_KERNEL_CHECK_MSG( - ctx, - error == Error::Ok, - InvalidArgument, - out, - "Failed to resize output tensor."); - int64_t outer_size = 1; - int64_t broadcast_size; - int64_t inner_size; - if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { - int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); - int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; - auto normalized_tensor_size_lhs = - get_normalized_tensor_size(*lhs, broadcast_dim_lhs); - outer_size = normalized_tensor_size_lhs[0]; - broadcast_size = normalized_tensor_size_lhs[1]; - inner_size = normalized_tensor_size_lhs[2]; - } else { - broadcast_size = lhs->sizes()[lhs->dim() - 2]; - inner_size = lhs->sizes()[lhs->dim() - 1]; + auto opt_plan = internal::plan_broadcast_elementwise( + ctx, a, b, out, selected_optimized_path); + if (!opt_plan) { + return out; } executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>( vec_fun, out.mutable_data_ptr<CTYPE>(), - lhs->const_data_ptr<CTYPE>(), - rhs->const_data_ptr<CTYPE>(), - outer_size, - broadcast_size, - inner_size); + opt_plan->lhs->const_data_ptr<CTYPE>(), + opt_plan->rhs->const_data_ptr<CTYPE>(), +
opt_plan->outer_size, + opt_plan->broadcast_size, + opt_plan->inner_size); return out; } } // namespace executor diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 62c47c6256f..12e20edaaad 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -42,6 +42,7 @@ def define_common_targets(): runtime.cxx_library( name = "binary_ops", + srcs = ["binary_ops.cpp"], exported_headers = ["binary_ops.h"], visibility = ["//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS",], exported_deps = ["//executorch/runtime/core:core"],