Commit a91eef8

Update on "[ExecuTorch] Add broadcast support for optimized add op"
Summary: This brings the add op to feature parity with the mul op, with respect to broadcasting, in the optimized kernels lib.

Test Plan: tests added

cc larryliu0820 manuelcandales

[ghstack-poisoned]
2 parents: e9fe6af + e53eb97
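
At a glance, the refactor moves dtype dispatch and alpha handling out of the shared broadcast helper in binary_ops.h and into each op kernel. A condensed sketch of the new caller pattern, pieced together from the op_add.cpp hunk below (error checks and surrounding control flow elided; not a drop-in snippet):

    // The op does the dtype dispatch itself and bakes alpha into a
    // two-argument lambda; the shared helper is now templated on CTYPE
    // instead of an op-name string.
    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      Vec alpha_val_vec(alpha_val); // alpha_val extracted via utils::extract_scalar
      auto add_lambda = [&alpha_val_vec](auto x, auto y) {
        return x + alpha_val_vec * y;
      };
      return torch::executor::handle_broadcast_elementwise<CTYPE>(
          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
    });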

File tree

kernels/optimized/cpu/binary_ops.h
kernels/optimized/cpu/op_add.cpp
kernels/optimized/cpu/op_mul.cpp

3 files changed: +62 −86 lines


kernels/optimized/cpu/binary_ops.h

Lines changed: 19 additions & 51 deletions
@@ -192,7 +192,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }
 
-template <const char* op_name, typename Op>
+template <typename CTYPE, typename Op>
 Tensor& handle_last_dim_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
@@ -221,32 +221,17 @@ Tensor& handle_last_dim_broadcast_elementwise(
       "Failed to resize output tensor.");
   const size_t outer_size = getLeadingDims(out, out.dim() - 1);
   const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    Vec alpha_val_vec;
-    if (alpha.has_value()) {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx,
-          native::utils::extract_scalar(alpha.value(), &alpha_val),
-          InvalidArgument, );
-      alpha_val_vec = Vec(alpha_val);
-    }
-    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
-      return vec_fun(a, b, alpha_val_vec);
-    };
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
-        vec_fun_alpha,
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size);
-  });
+  executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size);
   return out;
 }
 
-template <const char* op_name, typename Op>
+template <typename CTYPE, typename Op>
 Tensor& handle_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
@@ -259,11 +244,10 @@ Tensor& handle_broadcast_elementwise(
           ElementwiseOptimizedPath::kBroadcastLastDim) ||
       (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast_elementwise<op_name>(
-        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
+    return handle_last_dim_broadcast_elementwise<CTYPE>(
+        ctx, vec_fun, a, b, out, selected_optimized_path);
   }
 
-  ScalarType out_type = out.scalar_type();
   const Tensor* lhs;
   const Tensor* rhs;
   if ((selected_optimized_path ==
@@ -306,30 +290,14 @@ Tensor& handle_broadcast_elementwise(
     broadcast_size = lhs->sizes()[lhs->dim() - 2];
     inner_size = lhs->sizes()[lhs->dim() - 1];
   }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    Vec alpha_val_vec;
-    if (alpha.has_value()) {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx,
-          native::utils::extract_scalar(alpha.value(), &alpha_val),
-          InvalidArgument, );
-      alpha_val_vec = Vec(alpha_val);
-    }
-    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
-      return vec_fun(a, b, alpha_val_vec);
-    };
-    executorch::vec::
-        broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
-            vec_fun_alpha,
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            outer_size,
-            broadcast_size,
-            inner_size);
-  });
+  executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size,
+      inner_size);
   return out;
 }
 } // namespace executor
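
For orientation, a scalar reference loop for what the vectorized broadcasting_map_broadcast_last_dim<CTYPE, Op> call above is assumed to compute, written from its call site here rather than from the vec library's source (the helper name broadcast_last_dim_reference and the exact operand-order semantics are illustrative assumptions):

    #include <cstddef>

    // Assumed semantics: rhs has size 1 in its last dimension, so each rhs
    // element is reused across the trailing dimension of length
    // broadcast_size; the real helper applies vec_fun over Vectorized<CTYPE>
    // lanes instead of scalars.
    template <typename CTYPE, typename Op>
    void broadcast_last_dim_reference(
        const Op& fun,
        CTYPE* out,
        const CTYPE* lhs,
        const CTYPE* rhs,
        std::size_t outer_size,
        std::size_t broadcast_size) {
      for (std::size_t i = 0; i < outer_size; ++i) {
        for (std::size_t j = 0; j < broadcast_size; ++j) {
          out[i * broadcast_size + j] = fun(lhs[i * broadcast_size + j], rhs[i]);
        }
      }
    }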

kernels/optimized/cpu/op_add.cpp

Lines changed: 37 additions & 26 deletions
@@ -140,32 +140,43 @@ Tensor& opt_add_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    static constexpr const char op_name[] = "add.out";
-    if (selected_optimized_path ==
-            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-        selected_optimized_path ==
-            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-        selected_optimized_path ==
-            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-      // Reason we swap out args here is because handle_broadcast_elementwise
-      // handles this selected_optimized_path option a bit differently.
-      // This should really be resolved in handle_broadcast_elementwise.
-      // However, the current blocker is that handle_broadcast_elementwise tries
-      // to be agnostic of op. This should be fixed, likely by moving lambda
-      // creation to handle_broadcast_elementwise and it be aware of which op is
-      // being executed.
-      auto add_lambda = [](auto x, auto y, auto alpha_val) {
-        return y + alpha_val * x;
-      };
-      return torch::executor::handle_broadcast_elementwise<op_name>(
-          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-    } else {
-      auto add_lambda = [](auto x, auto y, auto alpha_val) {
-        return x + alpha_val * y;
-      };
-      return torch::executor::handle_broadcast_elementwise<op_name>(
-          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-    }
+    ScalarType out_type = out.scalar_type();
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK_MSG(
+          ctx,
+          utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument,
+          out,
+          "Failed to extract scalar alpha.");
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      Vec alpha_val_vec(alpha_val);
+      if (selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+        // Reason we swap out args here is because handle_broadcast_elementwise
+        // handles this selected_optimized_path option a bit differently.
+        // This should really be resolved in handle_broadcast_elementwise.
+        // However, the current blocker is that handle_broadcast_elementwise
+        // tries to be agnostic of op. This should be fixed, likely by moving
+        // lambda creation to handle_broadcast_elementwise and it be aware of
+        // which op is being executed.
+        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+          return y + alpha_val_vec * x;
+        };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+      } else {
+        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+          return x + alpha_val_vec * y;
+        };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+      }
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
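
A note on the two add lambdas above: add.out computes out = a + alpha * b, so alpha must stay attached to b even when one of the *ReverseArguments paths hands the operands to the lambda in swapped order, which is what the reversed lambda (y + alpha_val_vec * x) preserves. A small standalone check of that invariant (plain floats stand in for Vectorized<CTYPE>; the values are made up):

    #include <cassert>

    int main() {
      float a = 2.f, b = 3.f, alpha = 10.f;
      auto fwd = [&](auto x, auto y) { return x + alpha * y; }; // called as (a, b)
      auto rev = [&](auto x, auto y) { return y + alpha * x; }; // called as (b, a)
      assert(fwd(a, b) == rev(b, a)); // both evaluate to 2 + 10 * 3 = 32
      return 0;
    }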

kernels/optimized/cpu/op_mul.cpp

Lines changed: 6 additions & 9 deletions
@@ -130,15 +130,12 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    // Reason for using alpha even when used for mul is becasuse
-    // handle_broadcast_elementwise is used for add and sub as well
-    // and it uses alpha.
-    auto mul_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
-      return x * y;
-    };
-    static constexpr const char op_name[] = "mul.out";
-    return torch::executor::handle_broadcast_elementwise<op_name>(
-        ctx, mul_lambda, a, b, out, selected_optimized_path);
+    ScalarType out_type = out.scalar_type();
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      auto mul_lambda = [](auto x, auto y) { return x * y; };
+      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+          ctx, mul_lambda, a, b, out, selected_optimized_path);
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
