Merged
Changes from 9 commits
26 commits
27a79c4
[Executorch] Refactor op_mul's broadcasting utils
kimishpatel Feb 5, 2025
dbe3e8a
[ExecuTorch] Add broadcast support for optimized add op
kimishpatel Feb 5, 2025
bf761db
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 6, 2025
0e1cfc7
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 6, 2025
0ce8fd7
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 6, 2025
00e54b8
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 7, 2025
7ea55eb
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 7, 2025
ffb6903
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 7, 2025
e9fe6af
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 7, 2025
e53eb97
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 11, 2025
a91eef8
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 11, 2025
f565c3b
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 11, 2025
656873f
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 11, 2025
8ecbd04
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 12, 2025
2804f70
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 12, 2025
f3406bf
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 12, 2025
132d2f5
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 12, 2025
216c4be
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 12, 2025
bde7998
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 12, 2025
110a932
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 13, 2025
7ebd165
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 13, 2025
5fb4107
Merge branch 'main' into gh/kimishpatel/154/head
kimishpatel Feb 13, 2025
9e0855b
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 13, 2025
0d19ade
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 13, 2025
8955d90
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 15, 2025
6f2f01a
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 15, 2025
142 changes: 142 additions & 0 deletions kernels/optimized/cpu/binary_ops.h
@@ -8,6 +8,8 @@

#pragma once

#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
@@ -190,5 +192,145 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
return normalized_tensor_size;
}

template <const char* op_name, typename Op>
Tensor& handle_last_dim_broadcast_elementwise(
KernelRuntimeContext& ctx,
const Op& vec_fun,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path,
const executorch::aten::optional<Scalar>& alpha = {}) {
ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
lhs = &b;
rhs = &a;
} else {
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
const auto broadcast_size = out.size(out.dim() - 1);
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
Vec alpha_val_vec;
if (alpha.has_value()) {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx,
native::utils::extract_scalar(alpha.value(), &alpha_val),
InvalidArgument, );
alpha_val_vec = Vec(alpha_val);
}
auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
return vec_fun(a, b, alpha_val_vec);
};
Contributor

oh, I see, you're having problems with the lambda because of this part. you can solve this by factoring the code differently.

the end result at the callsite could look something like

auto broadcast_op_plan_opt = plan_broadcast_elementwise(...); // broadcast_op_plan is a struct containing all the stuff you work out that isn't dependent on the dtype, like lhs, rhs. It does ET_KERNEL_CHECKs internally and returns nullopt if they fail.
if (!broadcast_op_plan_opt) {
  // a check already failed
  return;
}
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
  auto alpha_val_vec_opt = extract_scalar_to_vector<CTYPE>(); // wrap up the bit that 
  if (!alpha_val_vec_opt) {
    // awkward that this only returns from the lambda, but this is a generic ET_KERNEL_CHECK problem
    return;
  }
  auto add_lambda = [alpha_val_vec = *alpha_val_vec_opt](auto x, auto y) {
      return y + alpha_val_vec * x;
      };
  execute_broadcast_elementwise_plan<CTYPE>(*broadcast_op_plan_opt, add_lambda, ...);
});

disclaimer: this is off the top of my head and it may be possible to unify some of this stuff with dtype_util.h for further simplification, though dtype_util is mostly intended to cut size/build time of portable ops

Contributor Author

Good point. Let me see if I don't run into other issues enabling such a refactor.

Contributor Author

OK, so I looked at the refactor required. I think it is doable, at the cost of moving the ET_SWITCH_REALB_TYPES macros to the callsites in the respective ops. The downside is that if you later enable a new dtype for the optimized path, you have to change all the callsites.

So I am not fully convinced that it is better to go down that route, but I want to hear your reasoning.

Contributor

> you have to change all the callsites.

that's just a matter of typing, right? if you plan to do it (I suppose optimizing Half/BFloat16 should be on our TODO list if the hardware supports the relevant instructions) and you really don't want to change 4-5 files later (you'll have to change them anyway for specifically Half/BFloat16 because there are opt-outs), you could always #define ET_SWITCH_OPTIMIZED_ELEMENTWISE_BROADCAST_OP_TYPES ET_SWITCH_REALB_TYPES pre-emptively.

Contributor Author

OK, that's fair. But is your reasoning for this change simpler code, or do you see a perf impact?

I am not too attached to it, so I will just go ahead and do it, but I wanted to understand your reasoning.

Contributor

Simpler, less repetitive code.

Contributor Author

OK, I will make the change, but it will likely marginally increase size since the whole handle_broadcast_elementwise function is now dtype-specialized.

executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
vec_fun_alpha,
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size);
});
return out;
}

template <const char* op_name, typename Op>
Tensor& handle_broadcast_elementwise(
KernelRuntimeContext& ctx,
const Op& vec_fun,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path,
const executorch::aten::optional<Scalar>& alpha = {}) {
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDim) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
return handle_last_dim_broadcast_elementwise<op_name>(
ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
}

ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNd));
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
int64_t outer_size = 1;
int64_t broadcast_size;
int64_t inner_size;
if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
auto normalized_tensor_size_lhs =
get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
outer_size = normalized_tensor_size_lhs[0];
broadcast_size = normalized_tensor_size_lhs[1];
inner_size = normalized_tensor_size_lhs[2];
} else {
broadcast_size = lhs->sizes()[lhs->dim() - 2];
inner_size = lhs->sizes()[lhs->dim() - 1];
}
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
Vec alpha_val_vec;
if (alpha.has_value()) {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx,
native::utils::extract_scalar(alpha.value(), &alpha_val),
InvalidArgument, );
alpha_val_vec = Vec(alpha_val);
}
auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
return vec_fun(a, b, alpha_val_vec);
};
executorch::vec::
broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
vec_fun_alpha,
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size,
inner_size);
});
return out;
}
} // namespace executor
} // namespace torch
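As a side note on the review thread above: the macro alias the reviewer floated would look roughly like the sketch below. This is hypothetical and not part of this PR's diff; it only illustrates how a dtype switch hoisted to the op callsites could stay easy to extend later.

// Hypothetical alias from the review discussion (not in this PR): route the
// per-op dtype switch through a single define so that enabling new dtypes for
// the optimized broadcast path (e.g. Half/BFloat16) only touches this line.
#define ET_SWITCH_OPTIMIZED_ELEMENTWISE_BROADCAST_OP_TYPES ET_SWITCH_REALB_TYPES

Each op would then invoke this macro at its callsite instead of ET_SWITCH_REALB_TYPES directly, so a later dtype expansion would not have to edit every op file.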
55 changes: 23 additions & 32 deletions kernels/optimized/cpu/op_add.cpp
@@ -140,41 +140,32 @@ Tensor& opt_add_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
static constexpr const char op_name[] = "add.out";
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
lhs = &b;
rhs = &a;
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
// We swap the lambda's operands here because handle_broadcast_elementwise
// handles these ReverseArguments paths a bit differently (it exchanges lhs
// and rhs internally). This should really be resolved inside
// handle_broadcast_elementwise itself. However, the current blocker is that
// handle_broadcast_elementwise tries to be agnostic of the op. This should
// be fixed, likely by moving lambda creation into
// handle_broadcast_elementwise and making it aware of which op is being
// executed.
Contributor

not sure I agree, but we can settle that on a review of a proposed change

auto add_lambda = [](auto x, auto y, auto alpha_val) {
return y + alpha_val * x;
};
return torch::executor::handle_broadcast_elementwise<op_name>(
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d);
lhs = &a;
rhs = &b;
auto add_lambda = [](auto x, auto y, auto alpha_val) {
return x + alpha_val * y;
};
return torch::executor::handle_broadcast_elementwise<op_name>(
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
});
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
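A scalar reference may help clarify what the two add_lambda variants above compute. This is an illustrative sketch, not code from the PR; it assumes plain contiguous float buffers and ignores the vectorization and broadcasting the real kernel performs.

#include <cstddef>

// Both paths end up computing out[i] = a[i] + alpha * b[i]. On the
// ReverseArguments paths the helper exchanges lhs and rhs, so the swapped
// lambda (y + alpha_val * x) restores the original operand order.
void add_out_reference(
    const float* a, const float* b, float* out, std::size_t n, float alpha) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + alpha * b[i];
  }
}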
118 changes: 9 additions & 109 deletions kernels/optimized/cpu/op_mul.cpp
@@ -68,114 +68,6 @@ template <
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

Tensor& handle_last_dim_broadcast(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {
ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
lhs = &b;
rhs = &a;
} else {
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
const auto broadcast_size = out.size(out.dim() - 1);
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size);
});
return out;
}

Tensor& handle_broadcast_mul(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDim) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
}

ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNd));
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
int64_t outer_size = 1;
int64_t broadcast_size;
int64_t inner_size;
if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
auto normalized_tensor_size_lhs =
get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
outer_size = normalized_tensor_size_lhs[0];
broadcast_size = normalized_tensor_size_lhs[1];
inner_size = normalized_tensor_size_lhs[2];
} else {
broadcast_size = lhs->sizes()[lhs->dim() - 2];
inner_size = lhs->sizes()[lhs->dim() - 1];
}
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size,
inner_size);
});
return out;
}
} // namespace

Tensor& opt_mul_out(
@@ -238,7 +130,15 @@ Tensor& opt_mul_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
// The lambda takes an alpha parameter even for mul because
// handle_broadcast_elementwise is shared with add and sub, which do use
// alpha; mul simply ignores it.
auto mul_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
return x * y;
};
static constexpr const char op_name[] = "mul.out";
return torch::executor::handle_broadcast_elementwise<op_name>(
ctx, mul_lambda, a, b, out, selected_optimized_path);
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
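To illustrate the comment above about alpha: a hypothetical sketch (not part of this PR) of how a sub kernel could plug into the same helper, which is why the shared lambda signature carries an alpha parameter that mul ignores.

// Hypothetical sub.out lambda: out = a - alpha * b.
auto sub_lambda = [](auto x, auto y, auto alpha_val) {
  return x - alpha_val * y;
};
// As with add above, the ReverseArguments paths exchange lhs and rhs, so a
// reversed variant would use (y - alpha_val * x) to keep sub.out semantics.
// return torch::executor::handle_broadcast_elementwise<op_name>(
//     ctx, sub_lambda, a, b, out, selected_optimized_path, alpha);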