
Commit d6be5a2

Update base for Update on "[ExecuTorch] Add broadcasting support to optimized op_div"
Summary: Similar to broadcast support in op_mul
Test Plan: Tests added
Reviewers:
Subscribers:
Tasks:
Tags:

cc larryliu0820 manuelcandales

[ghstack-poisoned]
1 parent 6c722bb commit d6be5a2

File tree

8 files changed: +188 -259 lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 10 additions & 40 deletions
@@ -49,38 +49,8 @@ enum class ElementwiseOptimizedPath {
   kBroadcastLastDimReverseArguments,
 };
 
-enum class BinaryOpType {
-  kAdd,
-  kSub,
-  kMul,
-  kDiv,
-};
-
 namespace internal {
 
-template <BinaryOpType op_type>
-struct BinaryOpTypeName;
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kAdd> {
-  static constexpr char kName[] = "add.out";
-};
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kSub> {
-  static constexpr char kName[] = "sub.out";
-};
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kMul> {
-  static constexpr char kName[] = "mul.out";
-};
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kDiv> {
-  static constexpr char kName[] = "div.out";
-};
-
 /*
 Given two tensors, this function returns the broadcast dim if it exists.
 Returns 0 if no broadcast dim is found.
@@ -222,15 +192,15 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }
 
-template <BinaryOpType op_type, typename Op>
+template <const char* op_name, typename Op>
 Tensor& handle_last_dim_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
     const ElementwiseOptimizedPath selected_optimized_path,
-    executorch::aten::optional<Scalar>& alpha = {}) {
+    const executorch::aten::optional<Scalar>& alpha = {}) {
   ScalarType out_type = out.scalar_type();
   const Tensor* lhs;
   const Tensor* rhs;
@@ -251,11 +221,11 @@ Tensor& handle_last_dim_broadcast_elementwise(
       "Failed to resize output tensor.");
   const size_t outer_size = getLeadingDims(out, out.dim() - 1);
   const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
+  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
     using Vec = executorch::vec::Vectorized<CTYPE>;
-    CTYPE alpha_val;
-    Vec alpha_val_vec(alpha_val);
+    Vec alpha_val_vec;
     if (alpha.has_value()) {
+      CTYPE alpha_val;
       ET_KERNEL_CHECK(
           ctx,
           native::utils::extract_scalar(alpha.value(), &alpha_val),
@@ -276,20 +246,20 @@ Tensor& handle_last_dim_broadcast_elementwise(
   return out;
 }
 
-template <BinaryOpType op_type, typename Op>
+template <const char* op_name, typename Op>
 Tensor& handle_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
     const ElementwiseOptimizedPath selected_optimized_path,
-    executorch::aten::optional<Scalar> alpha = {}) {
+    const executorch::aten::optional<Scalar>& alpha = {}) {
   if ((selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastLastDim) ||
       (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast_elementwise<op_type>(
+    return handle_last_dim_broadcast_elementwise<op_name>(
        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
   }
 
@@ -336,11 +306,11 @@ Tensor& handle_broadcast_elementwise(
     broadcast_size = lhs->sizes()[lhs->dim() - 2];
     inner_size = lhs->sizes()[lhs->dim() - 1];
   }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
+  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
     using Vec = executorch::vec::Vectorized<CTYPE>;
-    CTYPE alpha_val;
     Vec alpha_val_vec;
     if (alpha.has_value()) {
+      CTYPE alpha_val;
       ET_KERNEL_CHECK(
           ctx,
           native::utils::extract_scalar(alpha.value(), &alpha_val),
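
The substantive change in this header swaps the BinaryOpType enum and its BinaryOpTypeName trait for a const char* non-type template parameter, so the op name string consumed by ET_SWITCH_REALB_TYPES is supplied directly by the caller. A side effect visible in the hunks above: CTYPE alpha_val is now declared inside if (alpha.has_value()), so alpha_val_vec is no longer constructed from an uninitialized scalar when no alpha is passed. The sketch below is a minimal, self-contained illustration of the template-parameter pattern, not ExecuTorch code; the function and names are made up, but the constraint it demonstrates (the name must be an array with static storage duration, C++17) is why the call sites declare static constexpr const char op_name[].

    // Minimal stand-alone sketch of the pattern adopted here (illustrative names,
    // not ExecuTorch code). Compile with -std=c++17.
    #include <cstdio>

    // The op name is a const char* non-type template parameter; the argument must
    // point at an object with static storage duration.
    template <const char* op_name, typename Op>
    void apply_elementwise(
        const Op& fun, const float* a, const float* b, float* out, int n) {
      // Stand-in for where ET_SWITCH_REALB_TYPES would consume op_name.
      std::printf("running %s\n", op_name);
      for (int i = 0; i < n; ++i) {
        out[i] = fun(a[i], b[i]);
      }
    }

    int main() {
      float a[] = {6.f, 8.f}, b[] = {2.f, 4.f}, out[2];
      // Mirrors the call-site change in op_add.cpp below: a function-local
      // static constexpr char array is a valid template argument in C++17.
      static constexpr const char op_name[] = "div.out";
      apply_elementwise<op_name>(
          [](float x, float y) { return x / y; }, a, b, out, 2);
      std::printf("%g %g\n", out[0], out[1]);  // prints: 3 2
      return 0;
    }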

kernels/optimized/cpu/op_add.cpp

Lines changed: 2 additions & 1 deletion
@@ -67,7 +67,8 @@ Tensor& opt_add_out(
     return opt_add_out(ctx, b, a, alpha, out);
   }
 
-  return torch::executor::kernels::impl::opt_add_sub_out_impl(
+  static constexpr const char op_name[] = "add.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
       ctx, a, b, alpha, out);
 }
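
op_div itself is not touched in this base-update diff, but given the templated handle_broadcast_elementwise above, its broadcast fast path would plug in roughly as follows. This is a hypothetical call-site sketch, not the contents of op_div.cpp from this commit; it assumes the ExecuTorch headers shown above, and namespaces are abbreviated.

    // Hypothetical fragment inside a div out-variant kernel, once one of the
    // broadcast fast paths has been selected (assumption: not taken from this commit).
    static constexpr const char op_name[] = "div.out";  // static storage duration, as in op_add.cpp
    auto div_lambda = [](auto x, auto y) { return x / y; };  // applied to Vectorized<CTYPE> chunks
    return handle_broadcast_elementwise<op_name>(
        ctx, div_lambda, a, b, out, selected_optimized_path);
    // No alpha is passed: the trailing optional<Scalar> parameter defaults to empty,
    // so the alpha branch inside the helper is skipped.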

kernels/optimized/cpu/op_add_sub_impl.cpp

Lines changed: 0 additions & 200 deletions
This file was deleted.
