8 changes: 8 additions & 0 deletions kernels/optimized/cpu/binary_ops.h
@@ -43,6 +43,8 @@ enum class ElementwiseOptimizedPath {
kBroadcast2dBy1dReverseArguments,
kBroadcastNdByNd,
kBroadcastNdByNdReverseArguments,
kBroadcastLastDim,
kBroadcastLastDimReverseArguments,
};

namespace internal {
@@ -117,6 +119,12 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
} else {
return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
}
} else if (broadcast_dim == -1) {
if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) { return x == 1; }) == 1) {
return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
} else {
return ElementwiseOptimizedPath::kBroadcastLastDim;
}
}
return ElementwiseOptimizedPath::kNone;
}
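For context on the new enum values: a minimal standalone sketch (plain C++, not the ExecuTorch API; `Path` and `classify_last_dim` are hypothetical names) of how a pair of shapes would map onto the two new paths, under the simplifying assumption that the operands match everywhere except a trailing dimension of 1 on one side.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

enum class Path { None, LastDim, LastDimReverseArguments };

// Hypothetical helper for illustration only: decide which of the two new
// paths a pair of shapes would take when only the last dimension differs.
Path classify_last_dim(const std::vector<int64_t>& a,
                       const std::vector<int64_t>& b) {
  if (a.size() != b.size() || a.empty()) {
    return Path::None;
  }
  for (size_t i = 0; i + 1 < a.size(); ++i) {
    if (a[i] != b[i]) {
      return Path::None;  // only the last dimension may differ here
    }
  }
  if (a.back() == b.back()) {
    return Path::None;  // same trailing dim: no broadcast needed
  }
  if (a.back() == 1) {
    return Path::LastDimReverseArguments;  // a is the broadcast operand
  }
  if (b.back() == 1) {
    return Path::LastDim;  // b is the broadcast operand
  }
  return Path::None;
}

int main() {
  // [4, 5, 6] * [4, 5, 1] -> kBroadcastLastDim
  std::printf("%d\n", static_cast<int>(classify_last_dim({4, 5, 6}, {4, 5, 1})));
  // [4, 5, 1] * [4, 5, 6] -> kBroadcastLastDimReverseArguments
  std::printf("%d\n", static_cast<int>(classify_last_dim({4, 5, 1}, {4, 5, 6})));
  return 0;
}
```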
163 changes: 113 additions & 50 deletions kernels/optimized/cpu/op_mul.cpp
@@ -11,6 +11,7 @@
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

@@ -66,6 +67,117 @@ template <
typename CTYPE_OUT>
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

Tensor& handle_last_dim_broadcast(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {
ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
lhs = &b;
rhs = &a;
} else {
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
const auto broadcast_size = out.size(out.dim() - 1);
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size);
});
return out;
}
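A scalar reference for the layout this helper hands to the vectorized kernel (a hedged sketch with hypothetical names and no ExecuTorch types): `outer_size` collapses all leading dimensions of the output, `broadcast_size` is its last dimension, and the broadcast operand supplies exactly one value per output row.

```cpp
#include <cstdint>
#include <vector>

// lhs is contiguous with shape [outer_size, broadcast_size]; rhs contributes
// one value per output row (its last dimension is 1).
std::vector<float> mul_last_dim_reference(const std::vector<float>& lhs,
                                          const std::vector<float>& rhs,
                                          int64_t outer_size,
                                          int64_t broadcast_size) {
  std::vector<float> out(static_cast<size_t>(outer_size * broadcast_size));
  for (int64_t o = 0; o < outer_size; ++o) {
    const float scale = rhs[static_cast<size_t>(o)];
    for (int64_t i = 0; i < broadcast_size; ++i) {
      out[static_cast<size_t>(o * broadcast_size + i)] =
          lhs[static_cast<size_t>(o * broadcast_size + i)] * scale;
    }
  }
  return out;
}

// Example: a [2, 3, 8] tensor times a [2, 3, 1] tensor gives outer_size = 6
// and broadcast_size = 8; rhs holds 6 scalars, each applied across one row
// of 8 outputs.
```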

Tensor& handle_broadcast_mul(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {

if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDim) ||
(selected_optimized_path == ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
return handle_last_dim_broadcast(
ctx, a, b, out, selected_optimized_path);
}

ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNd));
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
int64_t outer_size = 1;
int64_t broadcast_size;
int64_t inner_size;
if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
auto normalized_tensor_size_lhs =
get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
outer_size = normalized_tensor_size_lhs[0];
broadcast_size = normalized_tensor_size_lhs[1];
inner_size = normalized_tensor_size_lhs[2];
} else {
broadcast_size = lhs->sizes()[lhs->dim() - 2];
inner_size = lhs->sizes()[lhs->dim() - 1];
}
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size,
inner_size);
});
return out;
}
} // namespace
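To make the (outer, broadcast, inner) flattening used above concrete, here is a hedged sketch against plain `std::vector` shapes; `normalize_for_broadcast` is a hypothetical stand-in for `get_normalized_tensor_size`, not the real helper.

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Collapse everything before the broadcast dim into `outer`, keep the
// broadcast dim itself, and collapse everything after it into `inner`.
std::array<int64_t, 3> normalize_for_broadcast(
    const std::vector<int64_t>& sizes, size_t broadcast_dim) {
  const int64_t outer = std::accumulate(
      sizes.begin(), sizes.begin() + broadcast_dim, int64_t{1},
      std::multiplies<int64_t>());
  const int64_t broadcast = sizes[broadcast_dim];
  const int64_t inner = std::accumulate(
      sizes.begin() + broadcast_dim + 1, sizes.end(), int64_t{1},
      std::multiplies<int64_t>());
  return {outer, broadcast, inner};
}

// Example: lhs [2, 3, 4, 5] * rhs [2, 1, 4, 5] broadcasts over dim 1
// (broadcast_dim == -3, so broadcast_dim_lhs = 4 + (-3) = 1), and the lhs
// normalizes to {outer = 2, broadcast = 3, inner = 20}.
// auto n = normalize_for_broadcast({2, 3, 4, 5}, 1);
```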

Tensor& opt_mul_out(
@@ -128,56 +240,7 @@ Tensor& opt_mul_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
(selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
(selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd));
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
int64_t outer_size = 1;
int64_t broadcast_size;
int64_t inner_size;
if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
(selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
auto normalized_tensor_size_lhs = get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
outer_size = normalized_tensor_size_lhs[0];
broadcast_size = normalized_tensor_size_lhs[1];
inner_size = normalized_tensor_size_lhs[2];
} else {
broadcast_size = lhs->sizes()[lhs->dim() - 2];
inner_size = lhs->sizes()[lhs->dim() - 1];
}
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size,
inner_size);
});
return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
1 change: 1 addition & 0 deletions kernels/optimized/cpu/targets.bzl
@@ -72,6 +72,7 @@ _OPTIMIZED_ATEN_OPS = (
":binary_ops",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/runtime/core/exec_aten/util:tensor_util",
],
),
op_target(
29 changes: 29 additions & 0 deletions kernels/optimized/vec/functional_base.h
@@ -378,5 +378,34 @@ inline void broadcasting_map_2d_by_1d(
broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
}

template <typename scalar_t, typename Op>
inline void broadcasting_map_broadcast_last_dim(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* lhs,
const scalar_t* rhs,
int64_t outer_size,
int64_t broadcast_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t outer_stride_lhs = broadcast_size;
int64_t outer_stride_rhs = 1;
for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
int64_t inner_idx = 0;
Vec data_vec2 = Vec(rhs[outer_idx]);
for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx);
}
if (broadcast_size - inner_idx > 0) {
Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
}
}
}

} // namespace vec
} // namespace executorch
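A minimal usage sketch for the new `broadcasting_map_broadcast_last_dim` helper, assuming the executorch vec headers build for the target and `Vectorized<float>` is available there; the include paths and build wiring are assumptions, not a verified setup.

```cpp
#include <cstdio>
#include <vector>

#include <executorch/kernels/optimized/vec/functional_base.h>
#include <executorch/kernels/optimized/vec/vec.h>

int main() {
  // lhs viewed as [2, 4] (outer_size = 2, broadcast_size = 4); rhs as [2, 1].
  std::vector<float> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> rhs = {10, 100};
  std::vector<float> out(lhs.size());

  using Vec = executorch::vec::Vectorized<float>;
  executorch::vec::broadcasting_map_broadcast_last_dim<float>(
      [](Vec x, Vec y) { return x * y; },
      out.data(),
      lhs.data(),
      rhs.data(),
      /*outer_size=*/2,
      /*broadcast_size=*/4);

  for (float v : out) {
    std::printf("%g ", v);  // expected: 10 20 30 40 500 600 700 800
  }
  std::printf("\n");
  return 0;
}
```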