
Commit 3666bd0

Update on "[Executorch] Add broadcasting support to optimized op_sub"
Summary: This diff builds on top of the previous one to add support for limited handling of broadcasting in the optimized sub op.

Test Plan: Tests added.

Reviewers:

Subscribers:

Tasks:

Tags:

cc larryliu0820 manuelcandales

[ghstack-poisoned]
2 parents 32a010f + 80e39bb

3 files changed: +82 -97 lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 19 additions & 51 deletions
@@ -192,7 +192,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
   return normalized_tensor_size;
 }
 
-template <const char* op_name, typename Op>
+template <typename CTYPE, typename Op>
 Tensor& handle_last_dim_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
@@ -221,32 +221,17 @@ Tensor& handle_last_dim_broadcast_elementwise(
       "Failed to resize output tensor.");
   const size_t outer_size = getLeadingDims(out, out.dim() - 1);
   const auto broadcast_size = out.size(out.dim() - 1);
-  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    Vec alpha_val_vec;
-    if (alpha.has_value()) {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx,
-          native::utils::extract_scalar(alpha.value(), &alpha_val),
-          InvalidArgument, );
-      alpha_val_vec = Vec(alpha_val);
-    }
-    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
-      return vec_fun(a, b, alpha_val_vec);
-    };
-    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
-        vec_fun_alpha,
-        out.mutable_data_ptr<CTYPE>(),
-        lhs->const_data_ptr<CTYPE>(),
-        rhs->const_data_ptr<CTYPE>(),
-        outer_size,
-        broadcast_size);
-  });
+  executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size);
   return out;
 }
 
-template <const char* op_name, typename Op>
+template <typename CTYPE, typename Op>
 Tensor& handle_broadcast_elementwise(
     KernelRuntimeContext& ctx,
     const Op& vec_fun,
@@ -259,11 +244,10 @@ Tensor& handle_broadcast_elementwise(
           ElementwiseOptimizedPath::kBroadcastLastDim) ||
       (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
-    return handle_last_dim_broadcast_elementwise<op_name>(
-        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
+    return handle_last_dim_broadcast_elementwise<CTYPE>(
+        ctx, vec_fun, a, b, out, selected_optimized_path);
   }
 
-  ScalarType out_type = out.scalar_type();
   const Tensor* lhs;
   const Tensor* rhs;
   if ((selected_optimized_path ==
@@ -306,30 +290,14 @@ Tensor& handle_broadcast_elementwise(
     broadcast_size = lhs->sizes()[lhs->dim() - 2];
     inner_size = lhs->sizes()[lhs->dim() - 1];
   }
-  ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
-    using Vec = executorch::vec::Vectorized<CTYPE>;
-    Vec alpha_val_vec;
-    if (alpha.has_value()) {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx,
-          native::utils::extract_scalar(alpha.value(), &alpha_val),
-          InvalidArgument, );
-      alpha_val_vec = Vec(alpha_val);
-    }
-    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
-      return vec_fun(a, b, alpha_val_vec);
-    };
-    executorch::vec::
-        broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
-            vec_fun_alpha,
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            outer_size,
-            broadcast_size,
-            inner_size);
-  });
+  executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE, Op>(
+      vec_fun,
+      out.mutable_data_ptr<CTYPE>(),
+      lhs->const_data_ptr<CTYPE>(),
+      rhs->const_data_ptr<CTYPE>(),
+      outer_size,
+      broadcast_size,
+      inner_size);
   return out;
 }
 } // namespace executor
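With this hunk, handle_broadcast_elementwise and handle_last_dim_broadcast_elementwise are templated on the element type CTYPE rather than on op_name, and they no longer do their own ET_SWITCH_REALB_TYPES dispatch or alpha extraction; the caller owns both and bakes alpha into a two-argument lambda. A minimal sketch of the resulting calling pattern, assembled from the op_add_sub_impl.h hunk further down (this fragment is not compilable on its own; identifiers such as alpha_val, a, b, out, and selected_optimized_path are assumed to be in scope at the call site):

// Caller-side sketch: dtype switch and alpha handling live at the call site,
// and the helper only sees a two-argument vectorized functor.
ET_SWITCH_REALB_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() {
  using Vec = executorch::vec::Vectorized<CTYPE>;
  Vec alpha_val_vec(alpha_val);  // alpha_val extracted by the caller
  auto sub_lambda = [&alpha_val_vec](auto x, auto y) {
    return x - alpha_val_vec * y;  // vectorized a - alpha * b
  };
  return torch::executor::handle_broadcast_elementwise<CTYPE>(
      ctx, sub_lambda, a, b, out, selected_optimized_path, alpha);
});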

kernels/optimized/cpu/op_add_sub_impl.h

Lines changed: 57 additions & 37 deletions
@@ -115,45 +115,65 @@ Tensor& opt_add_sub_out_impl(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     // Cannot apply the trick of -alpha here because alpha is Scalar without
     // support for - operator. At least not right now.
-    if constexpr (is_sub) {
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return y - alpha_val * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      } else {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return x - alpha_val * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      }
-    } else {
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return y + alpha_val * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+    ScalarType out_type = out.scalar_type();
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK_MSG(
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument,
+          out,
+          "Failed to extract scalar alpha.");
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      Vec alpha_val_vec(alpha_val);
+      if constexpr (is_sub) {
+        if (selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return y - alpha_val_vec * x;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        } else {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return x - alpha_val_vec * y;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        }
       } else {
-        auto add_lambda = [](auto x, auto y, auto alpha_val) {
-          return x + alpha_val * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<op_name>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        if (selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+            selected_optimized_path ==
+                ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+          // Reason we swap out args here is because
+          // handle_broadcast_elementwise handles this selected_optimized_path
+          // option a bit differently. This should really be resolved in
+          // handle_broadcast_elementwise. However, the current blocker is that
+          // handle_broadcast_elementwise tries to be agnostic of op. This
+          // should be fixed, likely by moving lambda creation to
+          // handle_broadcast_elementwise and it be aware of which op is being
+          // executed.
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return y + alpha_val_vec * x;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        } else {
+          auto add_lambda = [&alpha_val_vec](auto x, auto y) {
+            return x + alpha_val_vec * y;
+          };
+          return torch::executor::handle_broadcast_elementwise<CTYPE>(
+              ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+        }
       }
-    }
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
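The *ReverseArguments branches exist because, as the comment in the hunk notes, handle_broadcast_elementwise hands the operands to the functor in swapped order on those paths. A standalone scalar-level sketch (values made up) of why the sub lambda then computes y - alpha * x instead of x - alpha * y:

#include <cassert>

int main() {
  const float a = 10.f, b = 3.f, alpha = 2.f;

  // Normal path: the lambda receives the operands as (x = a, y = b).
  auto sub = [alpha](float x, float y) { return x - alpha * y; };
  // Reverse-arguments path: the helper passes them swapped, (x = b, y = a),
  // so the lambda compensates by computing y - alpha * x.
  auto sub_reversed = [alpha](float x, float y) { return y - alpha * x; };

  assert(sub(a, b) == sub_reversed(b, a));  // both equal a - alpha * b
  return 0;
}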

kernels/optimized/cpu/op_mul.cpp

Lines changed: 6 additions & 9 deletions
@@ -130,15 +130,12 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    // Reason for using alpha even when used for mul is becasuse
-    // handle_broadcast_elementwise is used for add and sub as well
-    // and it uses alpha.
-    auto mul_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
-      return x * y;
-    };
-    static constexpr const char op_name[] = "mul.out";
-    return torch::executor::handle_broadcast_elementwise<op_name>(
-        ctx, mul_lambda, a, b, out, selected_optimized_path);
+    ScalarType out_type = out.scalar_type();
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      auto mul_lambda = [](auto x, auto y) { return x * y; };
+      return torch::executor::handle_broadcast_elementwise<CTYPE>(
+          ctx, mul_lambda, a, b, out, selected_optimized_path);
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
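For mul the change is mostly mechanical: the functor drops the alpha parameter it never used, and the dtype switch moves to the call site. A standalone sketch contrasting the old and new functor shapes (values made up, not part of the diff):

#include <cstdio>

int main() {
  // Old form: carried an unused alpha only because handle_broadcast_elementwise
  // also served add/sub and fed alpha to every op.
  auto mul_lambda_old = [](auto x, auto y, [[maybe_unused]] auto alpha) {
    return x * y;
  };
  // New form: the plain two-argument lambda the reworked helper expects.
  auto mul_lambda_new = [](auto x, auto y) { return x * y; };

  std::printf("%f %f\n", mul_lambda_old(2.0, 3.0, 99.0), mul_lambda_new(2.0, 3.0));
  return 0;
}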
