Skip to content

Commit 507f906

Browse files
committed
[Executorch][Portable] Refactor op_add to be reused by optimized fallback
Pull Request resolved: #6734 If we can use portable for optimized's fallback then we can remove copy pasted stuff and take advantage of size and build reduction efforts in portable. ghstack-source-id: 252584477 @exported-using-ghexport Differential Revision: [D65632373](https://our.internmc.facebook.com/intern/diff/D65632373/)
1 parent 5fe76be commit 507f906

File tree

5 files changed

+181
-94
lines changed

5 files changed

+181
-94
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 4 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11-
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
129
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
10+
11+
#include <executorch/kernels/portable/cpu/op_add_impl.h>
1412

1513
namespace torch {
1614
namespace executor {
@@ -22,50 +20,7 @@ Tensor& add_out(
2220
const Tensor& b,
2321
const Scalar& alpha,
2422
Tensor& out) {
25-
// Common Dtype
26-
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27-
28-
// Check Common Dtype
29-
ET_KERNEL_CHECK(
30-
ctx,
31-
(canCast(common_type, out.scalar_type()) &&
32-
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
33-
InvalidArgument,
34-
out);
35-
36-
// Check Dim Order
37-
ET_KERNEL_CHECK(
38-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39-
40-
// Resize
41-
ET_KERNEL_CHECK(
42-
ctx,
43-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44-
InvalidArgument,
45-
out);
46-
47-
// Compute Dtype
48-
ScalarType compute_type = utils::get_compute_type(common_type);
49-
50-
// @lint-ignore CLANGTIDY facebook-hte-CArray
51-
static constexpr const char op_name[] = "add.out";
52-
53-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56-
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
57-
return val_a + val_alpha * val_b;
58-
},
59-
ctx,
60-
a,
61-
utils::SupportedTensorDtypes::REALHBBF16,
62-
b,
63-
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
66-
});
67-
68-
return out;
23+
return add_out_impl(ctx, a, b, alpha, out);
6924
}
7025

7126
Tensor& add_scalar_out(
@@ -74,46 +29,7 @@ Tensor& add_scalar_out(
7429
const Scalar& b,
7530
const Scalar& alpha,
7631
Tensor& out) {
77-
// Common Dtype
78-
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
79-
80-
// Check Common Dtype
81-
ET_KERNEL_CHECK(
82-
ctx,
83-
(common_type == out.scalar_type() &&
84-
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
85-
InvalidArgument,
86-
out);
87-
88-
// Check Dim Order
89-
ET_KERNEL_CHECK(
90-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
91-
92-
// Resize
93-
ET_KERNEL_CHECK(
94-
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
95-
96-
// Compute Dtype
97-
ScalarType compute_type = utils::get_compute_type(common_type);
98-
99-
// @lint-ignore CLANGTIDY facebook-hte-CArray
100-
static constexpr const char op_name[] = "add.Scalar_out";
101-
102-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
103-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
104-
[b, alpha](const CTYPE_COMPUTE val_a) {
105-
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106-
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
107-
return val_a + val_alpha * val_b;
108-
},
109-
ctx,
110-
a,
111-
utils::SupportedTensorDtypes::REALHBBF16,
112-
out,
113-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
114-
});
115-
116-
return out;
32+
return add_scalar_out_impl(ctx, a, b, alpha, out);
11733
}
11834

11935
} // namespace native
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
#include <executorch/runtime/platform/assert.h>
14+
15+
#include <executorch/kernels/portable/cpu/op_add_impl.h>
16+
17+
namespace torch {
18+
namespace executor {
19+
namespace native {
20+
21+
Tensor& add_out_impl(
22+
KernelRuntimeContext& ctx,
23+
const Tensor& a,
24+
const Tensor& b,
25+
const Scalar& alpha,
26+
Tensor& out) {
27+
// Common Dtype
28+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
29+
30+
// Check Common Dtype
31+
ET_KERNEL_CHECK(
32+
ctx,
33+
(canCast(common_type, out.scalar_type()) &&
34+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
35+
InvalidArgument,
36+
out);
37+
38+
// Check Dim Order
39+
ET_KERNEL_CHECK(
40+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
41+
42+
// Resize
43+
ET_KERNEL_CHECK(
44+
ctx,
45+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
46+
InvalidArgument,
47+
out);
48+
49+
// Compute Dtype
50+
ScalarType compute_type = utils::get_compute_type(common_type);
51+
52+
// @lint-ignore CLANGTIDY facebook-hte-CArray
53+
static constexpr const char op_name[] = "add.out";
54+
55+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
57+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
58+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59+
return val_a + val_alpha * val_b;
60+
},
61+
ctx,
62+
a,
63+
utils::SupportedTensorDtypes::REALHBBF16,
64+
b,
65+
utils::SupportedTensorDtypes::REALHBBF16,
66+
out,
67+
utils::SupportedTensorDtypes::REALHBBF16);
68+
});
69+
70+
return out;
71+
}
72+
73+
Tensor& add_scalar_out_impl(
74+
KernelRuntimeContext& ctx,
75+
const Tensor& a,
76+
const Scalar& b,
77+
const Scalar& alpha,
78+
Tensor& out) {
79+
// Common Dtype
80+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
81+
82+
// Check Common Dtype
83+
ET_KERNEL_CHECK(
84+
ctx,
85+
(common_type == out.scalar_type() &&
86+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
87+
InvalidArgument,
88+
out);
89+
90+
// Check Dim Order
91+
ET_KERNEL_CHECK(
92+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
93+
94+
// Resize
95+
ET_KERNEL_CHECK(
96+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
97+
98+
// Compute Dtype
99+
ScalarType compute_type = utils::get_compute_type(common_type);
100+
101+
// @lint-ignore CLANGTIDY facebook-hte-CArray
102+
static constexpr const char op_name[] = "add.Scalar_out";
103+
104+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
105+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
106+
[b, alpha](const CTYPE_COMPUTE val_a) {
107+
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
108+
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
109+
return val_a + val_alpha * val_b;
110+
},
111+
ctx,
112+
a,
113+
utils::SupportedTensorDtypes::REALHBBF16,
114+
out,
115+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
116+
});
117+
118+
return out;
119+
}
120+
121+
} // namespace native
122+
} // namespace executor
123+
} // namespace torch

kernels/portable/cpu/op_add_impl.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/kernel/kernel_includes.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
15+
Tensor& add_out_impl(
16+
KernelRuntimeContext& ctx,
17+
const Tensor& a,
18+
const Tensor& b,
19+
const Scalar& alpha,
20+
Tensor& out);
21+
22+
Tensor& add_scalar_out_impl(
23+
KernelRuntimeContext& ctx,
24+
const Tensor& a,
25+
const Scalar& b,
26+
const Scalar& alpha,
27+
Tensor& out);
28+
29+
} // namespace native
30+
} // namespace executor
31+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,26 @@ def define_common_targets():
131131
],
132132
)
133133

134+
runtime.cxx_library(
135+
name = "op_add_impl",
136+
srcs = ["op_add_impl.cpp"],
137+
exported_headers = ["op_add_impl.h"],
138+
visibility = [
139+
"//executorch/kernels/portable/cpu/...",
140+
"//executorch/kernels/optimized/cpu/...",
141+
"//executorch/kernels/portable/test/...",
142+
"@EXECUTORCH_CLIENTS",
143+
],
144+
exported_deps = [
145+
"//executorch/kernels/portable/cpu/util:broadcast_util",
146+
"//executorch/kernels/portable/cpu/util:dtype_util",
147+
"//executorch/kernels/portable/cpu/util:elementwise_util",
148+
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
149+
"//executorch/kernels/portable/cpu:scalar_utils",
150+
"//executorch/runtime/kernel:kernel_includes",
151+
],
152+
)
153+
134154
# The following will not participate in dtype selective build because
135155
# they are refactored such to be used in optimized op implementations as well
136156
# and we have not enabled selective build for optimized ops.
@@ -142,6 +162,7 @@ def define_common_targets():
142162
runtime.cxx_library(
143163
name = "all_impl_deps",
144164
deps = [
165+
"//executorch/kernels/portable/cpu:op_add_impl",
145166
"//executorch/kernels/portable/cpu:op_div_impl",
146167
"//executorch/kernels/portable/cpu:op_mul_impl",
147168
],

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,7 @@ ATEN_OPS = (
220220
op_target(
221221
name = "op_add",
222222
deps = [
223-
"//executorch/kernels/portable/cpu/util:broadcast_util",
224-
"//executorch/kernels/portable/cpu/util:dtype_util",
225-
"//executorch/kernels/portable/cpu/util:elementwise_util",
226-
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
227-
":scalar_utils",
223+
"//executorch/kernels/portable/cpu:op_add_impl",
228224
],
229225
),
230226
op_target(
@@ -1268,4 +1264,4 @@ def portable_source_list():
12681264

12691265
def portable_header_list():
12701266
"""All the header file names from //executorch/kernels/portable/cpu/"""
1271-
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"]
1267+
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]

0 commit comments

Comments
 (0)