
Commit 7ea55eb

Update on "[ExecuTorch] Add broadcast support for optimized add op"

Summary: This brings the add op to feature parity with the mul op, with respect to broadcasting, in the optimized kernels lib.

Test Plan: tests added.

cc larryliu0820 manuelcandales

[ghstack-poisoned]

2 parents: 0ce8fd7 + 00e54b8
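Editor's note: the core of this refactor is replacing the BinaryOpType enum and its BinaryOpTypeName trait with a const char* non-type template parameter carrying the op name. A minimal standalone sketch of that C++ mechanism follows — illustrative names (kAddName, log_op), not the ExecuTorch source:

#include <cstdio>

// A string literal cannot be a template argument directly, but an array
// with static storage duration can: a pointer to its first element is a
// valid non-type template argument. Function-local "static constexpr"
// arrays (as op_add.cpp uses below) also qualify since C++17.
static constexpr const char kAddName[] = "add.out";

template <const char* op_name>
void log_op() {
  // op_name is usable wherever a compile-time const char* is expected.
  std::printf("running %s\n", op_name);
}

int main() {
  log_op<kAddName>();  // prints "running add.out"
  return 0;
}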

File tree

4 files changed: +24 -60 lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 10 additions & 40 deletions
@@ -49,38 +49,8 @@ enum class ElementwiseOptimizedPath {
   kBroadcastLastDimReverseArguments,
 };
 
-enum class BinaryOpType {
-  kAdd,
-  kSub,
-  kMul,
-  kDiv,
-};
-
 namespace internal {
 
-template <BinaryOpType op_type>
-struct BinaryOpTypeName;
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kAdd> {
-  static constexpr char kName[] = "add.out";
-};
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kSub> {
-  static constexpr char kName[] = "sub.out";
-};
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kMul> {
-  static constexpr char kName[] = "mul.out";
-};
-
-template <>
-struct BinaryOpTypeName<BinaryOpType::kDiv> {
-  static constexpr char kName[] = "div.out";
-};
-
 /*
 Given two tensors, this function returns the broadcast dim if it exists.
 Returns 0 if no broadcast dim is found.
@@ -222,15 +192,15 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }
 
-template <BinaryOpType op_type, typename Op>
+template <const char* op_name, typename Op>
 Tensor& handle_last_dim_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
     const ElementwiseOptimizedPath selected_optimized_path,
-    executorch::aten::optional<Scalar>& alpha = {}) {
+    const executorch::aten::optional<Scalar>& alpha = {}) {
   ScalarType out_type = out.scalar_type();
   const Tensor* lhs;
   const Tensor* rhs;
@@ -251,11 +221,11 @@ Tensor& handle_last_dim_broadcast_elementwise(
       "Failed to resize output tensor.");
   const size_t outer_size = getLeadingDims(out, out.dim() - 1);
   const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
+  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
     using Vec = executorch::vec::Vectorized<CTYPE>;
-    CTYPE alpha_val;
-    Vec alpha_val_vec(alpha_val);
+    Vec alpha_val_vec;
     if (alpha.has_value()) {
+      CTYPE alpha_val;
       ET_KERNEL_CHECK(
           ctx,
           native::utils::extract_scalar(alpha.value(), &alpha_val),
@@ -276,20 +246,20 @@ Tensor& handle_last_dim_broadcast_elementwise(
   return out;
 }
 
-template <BinaryOpType op_type, typename Op>
+template <const char* op_name, typename Op>
 Tensor& handle_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
     const Tensor& a,
     const Tensor& b,
     Tensor& out,
     const ElementwiseOptimizedPath selected_optimized_path,
-    executorch::aten::optional<Scalar> alpha = {}) {
+    const executorch::aten::optional<Scalar>& alpha = {}) {
   if ((selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastLastDim) ||
       (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast_elementwise<op_type>(
+    return handle_last_dim_broadcast_elementwise<op_name>(
        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
  }
 
@@ -336,11 +306,11 @@ Tensor& handle_broadcast_elementwise(
     broadcast_size = lhs->sizes()[lhs->dim() - 2];
     inner_size = lhs->sizes()[lhs->dim() - 1];
   }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
+  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
     using Vec = executorch::vec::Vectorized<CTYPE>;
-    CTYPE alpha_val;
     Vec alpha_val_vec;
     if (alpha.has_value()) {
+      CTYPE alpha_val;
       ET_KERNEL_CHECK(
           ctx,
           native::utils::extract_scalar(alpha.value(), &alpha_val),

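Editor's note: besides the template change, these hunks fix a latent bug — the old code constructed Vec alpha_val_vec(alpha_val) from an alpha_val that was never initialized when alpha was empty. An illustrative sketch of the before/after pattern, using a stand-in Vec type rather than the real executorch::vec::Vectorized:

#include <optional>

struct Vec {
  float v;
  Vec() : v(0.0f) {}
  explicit Vec(float x) : v(x) {}  // broadcasts a scalar across all lanes
};

void sketch(std::optional<float> alpha) {
  // Before: alpha_val was read while still indeterminate whenever
  // alpha was empty:
  //   CTYPE alpha_val;
  //   Vec alpha_val_vec(alpha_val);  // reads an uninitialized value
  //
  // After: default-construct the vector, and only declare and extract
  // alpha_val on the branch that actually has a value.
  Vec alpha_val_vec;
  if (alpha.has_value()) {
    float alpha_val = *alpha;
    alpha_val_vec = Vec(alpha_val);
  }
}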
kernels/optimized/cpu/op_add.cpp

Lines changed: 8 additions & 7 deletions
@@ -140,29 +140,30 @@ Tensor& opt_add_out(
         out.numel());
       });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
+    static constexpr const char op_name[] = "add.out";
     if (selected_optimized_path ==
             ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
         selected_optimized_path ==
             ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
         selected_optimized_path ==
             ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-      // This behavior is a bit confusing.
       // Reason we swap out args here is because handle_broadcast_elementwise
       // handles this selected_optimized_path option a bit differently.
-      // This should really be resoled in handle_broadcast_elementwise.
-      // However, the current blocker is that handle_broadcast_elementwise tries to
-      // be agnostic of op. This should be fixed, likely by moving lambda creation
-      // to handle_broadcast_elementwise and it be aware of which op is being executed.
+      // This should really be resolved in handle_broadcast_elementwise.
+      // However, the current blocker is that handle_broadcast_elementwise tries
+      // to be agnostic of op. This should be fixed, likely by moving lambda
+      // creation to handle_broadcast_elementwise and it be aware of which op is
+      // being executed.
       auto add_lambda = [](auto x, auto y, auto alpha_val) {
         return y + alpha_val * x;
       };
-      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
+      return torch::executor::handle_broadcast_elementwise<op_name>(
           ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
     } else {
       auto add_lambda = [](auto x, auto y, auto alpha_val) {
         return x + alpha_val * y;
       };
-      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
+      return torch::executor::handle_broadcast_elementwise<op_name>(
          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
    }
  } else {

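Editor's note: the comment in this hunk explains the argument swap; here is a small self-contained sketch of why the two lambdas differ, assuming add.out computes out = a + alpha * b, as the lambdas above imply:

#include <cassert>

int main() {
  int a = 10, b = 3, alpha = 2;

  // Forward path: the helper passes (x = a, y = b).
  auto add_lambda = [](auto x, auto y, auto alpha_val) {
    return x + alpha_val * y;
  };

  // Reverse-arguments path: the helper passes (x = b, y = a), so the
  // lambda swaps the roles of x and y to keep alpha attached to b.
  auto add_lambda_rev = [](auto x, auto y, auto alpha_val) {
    return y + alpha_val * x;
  };

  assert(add_lambda(a, b, alpha) == a + alpha * b);      // 10 + 2*3 = 16
  assert(add_lambda_rev(b, a, alpha) == a + alpha * b);  // 10 + 2*3 = 16
  return 0;
}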
kernels/optimized/cpu/op_mul.cpp

Lines changed: 6 additions & 3 deletions
@@ -130,12 +130,15 @@ Tensor& opt_mul_out(
         out.numel());
       });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    // Reason for using alpha:
+    // Reason for using alpha even when used for mul is because
+    // handle_broadcast_elementwise is used for add and sub as well
+    // and it uses alpha.
     auto mul_lambda = [](auto x, auto y, auto alpha) {
-      (void)alpha;
+      [[maybe_unused]] alpha;
       return x * y;
     };
-    return torch::executor::handle_broadcast_elementwise<BinaryOpType::kMul>(
+    static constexpr const char op_name[] = "mul.out";
+    return torch::executor::handle_broadcast_elementwise<op_name>(
        ctx, mul_lambda, a, b, out, selected_optimized_path);
  } else {
    ScalarType common_type =

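Editor's note: because handle_broadcast_elementwise is shared with add and sub, every op lambda takes an alpha argument even when, as for mul, it is ignored. A sketch of the more idiomatic placement for the attribute — on the parameter itself — since an attribute placed on an expression statement, as in the diff's added line, is typically ignored by compilers:

// Mark the unused parameter directly so the compiler knows it is
// intentionally ignored; the body stays a plain multiply.
auto mul_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
  return x * y;
};

int main() {
  return mul_lambda(6, 7, 0) == 42 ? 0 : 1;
}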
kernels/test/op_mul_test.cpp

Lines changed: 0 additions & 10 deletions
@@ -417,16 +417,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
   test_broadcast_a2b<ScalarType::Int>();
   test_broadcast_a2b<ScalarType::Half>();
   test_broadcast_a2b<ScalarType::BFloat16>();
-
-  // Test 3D tensors
-  test_broadcast_3D<ScalarType::Float>();
-  test_broadcast_3D<ScalarType::Half>();
-  test_broadcast_3D<ScalarType::BFloat16>();
-
-  // Test 4D tensors
-  test_broadcast_4D<ScalarType::Float>();
-  test_broadcast_4D<ScalarType::Half>();
-  test_broadcast_4D<ScalarType::BFloat16>();
 }
 
 // Broadcast tensor a's size to tensor b's size
