Skip to content

Commit bf761db

Browse files
committed
Update on "[ExecuTorch] Add broadcast support for optimized add op"
Summary: This brings the add op to feature parity, with respect to broadcasting, with the mul op in the optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent dbe3e8a commit bf761db

File tree

4 files changed

+65
-12
lines changed

4 files changed

+65
-12
lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,38 @@ enum class ElementwiseOptimizedPath {
4949
kBroadcastLastDimReverseArguments,
5050
};
5151

52+
enum class BinaryOpType {
53+
kAdd,
54+
kSub,
55+
kMul,
56+
kDiv,
57+
};
58+
5259
namespace internal {
5360

61+
template <BinaryOpType op_type>
62+
struct BinaryOpTypeName;
63+
64+
template <>
65+
struct BinaryOpTypeName<BinaryOpType::kAdd> {
66+
static constexpr char kName[] = "add.out";
67+
};
68+
69+
template <>
70+
struct BinaryOpTypeName<BinaryOpType::kSub> {
71+
static constexpr char kName[] = "sub.out";
72+
};
73+
74+
template <>
75+
struct BinaryOpTypeName<BinaryOpType::kMul> {
76+
static constexpr char kName[] = "mul.out";
77+
};
78+
79+
template <>
80+
struct BinaryOpTypeName<BinaryOpType::kDiv> {
81+
static constexpr char kName[] = "div.out";
82+
};
83+
5484
/*
5585
Given two tensors, this function returns the broadcast dim if it exists.
5686
Returns 0 if no broadcast dim is found.
@@ -192,7 +222,7 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
192222
return normalized_tensor_size;
193223
}
194224

195-
template <typename Op>
225+
template <BinaryOpType op_type, typename Op>
196226
Tensor& handle_last_dim_broadcast_elementwise(
197227
KernelRuntimeContext& ctx,
198228
const Op& vec_fun,
@@ -221,7 +251,7 @@ Tensor& handle_last_dim_broadcast_elementwise(
221251
"Failed to resize output tensor.");
222252
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
223253
const auto broadcast_size = out.size(out.dim() - 1);
224-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
254+
ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
225255
using Vec = executorch::vec::Vectorized<CTYPE>;
226256
CTYPE alpha_val;
227257
Vec alpha_val_vec(alpha_val);
@@ -246,7 +276,7 @@ Tensor& handle_last_dim_broadcast_elementwise(
246276
return out;
247277
}
248278

249-
template <typename Op>
279+
template <BinaryOpType op_type, typename Op>
250280
Tensor& handle_broadcast_elementwise(
251281
KernelRuntimeContext& ctx,
252282
const Op& vec_fun,
@@ -259,7 +289,7 @@ Tensor& handle_broadcast_elementwise(
259289
ElementwiseOptimizedPath::kBroadcastLastDim) ||
260290
(selected_optimized_path ==
261291
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
262-
return handle_last_dim_broadcast_elementwise(
292+
return handle_last_dim_broadcast_elementwise<op_type>(
263293
ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
264294
}
265295

@@ -306,7 +336,7 @@ Tensor& handle_broadcast_elementwise(
306336
broadcast_size = lhs->sizes()[lhs->dim() - 2];
307337
inner_size = lhs->sizes()[lhs->dim() - 1];
308338
}
309-
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
339+
ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
310340
using Vec = executorch::vec::Vectorized<CTYPE>;
311341
CTYPE alpha_val;
312342
Vec alpha_val_vec;

kernels/optimized/cpu/op_add.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,31 @@ Tensor& opt_add_out(
140140
out.numel());
141141
});
142142
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
143-
auto add_lambda = [](auto x, auto y, auto alpha_val) {
144-
return x + alpha_val * y;
145-
};
146-
return torch::executor::handle_broadcast_elementwise(
147-
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
143+
if (selected_optimized_path ==
144+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
145+
selected_optimized_path ==
146+
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
147+
selected_optimized_path ==
148+
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
149+
// This behavior is a bit confusing.
150+
// The reason we swap the args here is that handle_broadcast_elementwise
151+
// handles this selected_optimized_path option a bit differently.
152+
// This should really be resolved in handle_broadcast_elementwise.
153+
// However, the current blocker is that handle_broadcast_elementwise tries to
154+
// be agnostic of op. This should be fixed, likely by moving lambda creation
155+
// to handle_broadcast_elementwise and making it aware of which op is being executed.
156+
auto add_lambda = [](auto x, auto y, auto alpha_val) {
157+
return y + alpha_val * x;
158+
};
159+
return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
160+
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
161+
} else {
162+
auto add_lambda = [](auto x, auto y, auto alpha_val) {
163+
return x + alpha_val * y;
164+
};
165+
return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
166+
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
167+
}
148168
} else {
149169
ScalarType common_type =
150170
promoteTypes(a_type, b_type, /*half_to_float*/ true);

kernels/optimized/cpu/op_mul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ Tensor& opt_mul_out(
135135
(void)alpha;
136136
return x * y;
137137
};
138-
return torch::executor::handle_broadcast_elementwise(
138+
return torch::executor::handle_broadcast_elementwise<BinaryOpType::kMul>(
139139
ctx, mul_lambda, a, b, out, selected_optimized_path);
140140
} else {
141141
ScalarType common_type =

kernels/test/op_add_test.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ class OpAddOutKernelTest : public OperatorTest {
129129

130130
// Check that it matches the expected output.
131131
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
132-
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
132+
expected = tf_a.make(
133+
{2, 2, 3},
134+
/*data=*/{3.5, 6, 8.5, 8, 10.5, 13, 15.5, 18, 20.5, 20, 22.5, 25});
135+
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.5, out), expected);
133136
}
134137

135138
template <ScalarType DTYPE>

0 commit comments

Comments
 (0)