Commit 33f75f2

[Executorch][Optimized] Use portable's impl for optimized op_add's fallback
This way we can take advantage of the size and build-optimization efforts in the portable kernels.

Differential Revision: [D65632375](https://our.internmc.facebook.com/intern/diff/D65632375/)

[ghstack-poisoned]
1 parent c2183fb commit 33f75f2
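
For context, a minimal self-contained C++ sketch of the fallback pattern this commit adopts. All names here (portable_add_impl, optimized_add) are plain-C++ stand-ins, not the real ExecuTorch types or APIs, and the vectorized fast path is simplified to a scalar loop:

#include <cstddef>
#include <vector>

// Stand-in for the portable reference kernel: simple and fully general.
void portable_add_impl(
    const std::vector<float>& a,
    const std::vector<float>& b,
    float alpha,
    std::vector<float>& out) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] + alpha * b[i];
  }
}

// Stand-in for the optimized kernel: take the fast path when its
// preconditions hold, otherwise delegate to the portable implementation
// instead of carrying a duplicate general-case loop.
void optimized_add(
    const std::vector<float>& a,
    const std::vector<float>& b,
    float alpha,
    std::vector<float>& out) {
  const bool fast_path_ok = a.size() == b.size() && alpha == 1.0f;
  if (fast_path_ok) {
    // A vectorized loop would live here; scalar code keeps the sketch short.
    for (std::size_t i = 0; i < a.size(); ++i) {
      out[i] = a[i] + b[i];
    }
  } else {
    portable_add_impl(a, b, alpha, out);  // single source of truth
  }
}

Delegating the slow path this way means the size and build work done on the portable kernels is inherited by the optimized library, at the cost of one extra dependency edge (see the targets.bzl change below).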

File tree

2 files changed: 4 additions & 53 deletions


kernels/optimized/cpu/op_add.cpp

Lines changed: 3 additions & 53 deletions
@@ -9,6 +9,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/op_add_impl.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -176,35 +177,7 @@ Tensor& opt_add_out(
           lhs->sizes()[lhs->dim() - 1]);
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
+    add_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
@@ -255,30 +228,7 @@ Tensor& opt_add_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                    CTYPE_IN alpha_val;
-                    ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) +
-                          alpha_val * b_casted);
-                    }
-                  });
-            });
-      });
-    });
+    add_scalar_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
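
Judging from the new call sites above, the shared header presumably declares helpers with signatures along these lines. This is an assumption inferred from this diff, not copied from op_add_impl.h; the exact context type and namespace are not confirmed by the source:

// Presumed shape of executorch/kernels/portable/cpu/op_add_impl.h, inferred
// from the calls add_out_impl(ctx, a, b, alpha, out) and
// add_scalar_out_impl(ctx, a, b, alpha, out) in the diff (assumption):
Tensor& add_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out);

Tensor& add_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out);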

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_add",
         deps = [
             ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_add_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
