Skip to content

Commit 43d4564

Browse files
committed
[Executorch][Portable] Refactor op_add to be reused by optimized fallback
If we can use the portable implementation as optimized's fallback, then we can remove copy-pasted code and take advantage of the size and build reduction efforts in portable. Differential Revision: [D65632373](https://our.internmc.facebook.com/intern/diff/D65632373/) ghstack-source-id: 252472693 Pull Request resolved: #6734
1 parent 2f2c739 commit 43d4564

File tree

5 files changed

+181
-94
lines changed

5 files changed

+181
-94
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 4 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11-
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
129
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
10+
11+
#include <executorch/kernels/portable/cpu/op_add_impl.h>
1412

1513
namespace torch {
1614
namespace executor {
@@ -22,50 +20,7 @@ Tensor& add_out(
2220
const Tensor& b,
2321
const Scalar& alpha,
2422
Tensor& out) {
25-
// Common Dtype
26-
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27-
28-
// Check Common Dtype
29-
ET_KERNEL_CHECK(
30-
ctx,
31-
(canCast(common_type, out.scalar_type()) &&
32-
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
33-
InvalidArgument,
34-
out);
35-
36-
// Check Dim Order
37-
ET_KERNEL_CHECK(
38-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39-
40-
// Resize
41-
ET_KERNEL_CHECK(
42-
ctx,
43-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44-
InvalidArgument,
45-
out);
46-
47-
// Compute Dtype
48-
ScalarType compute_type = utils::get_compute_type(common_type);
49-
50-
// @lint-ignore CLANGTIDY facebook-hte-CArray
51-
static constexpr const char op_name[] = "add.out";
52-
53-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56-
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
57-
return val_a + val_alpha * val_b;
58-
},
59-
ctx,
60-
a,
61-
utils::SupportedTensorDtypes::REALHBBF16,
62-
b,
63-
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
66-
});
67-
68-
return out;
23+
return add_out_impl(ctx, a, b, alpha, out);
6924
}
7025

7126
Tensor& add_scalar_out(
@@ -74,46 +29,7 @@ Tensor& add_scalar_out(
7429
const Scalar& b,
7530
const Scalar& alpha,
7631
Tensor& out) {
77-
// Common Dtype
78-
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
79-
80-
// Check Common Dtype
81-
ET_KERNEL_CHECK(
82-
ctx,
83-
(common_type == out.scalar_type() &&
84-
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
85-
InvalidArgument,
86-
out);
87-
88-
// Check Dim Order
89-
ET_KERNEL_CHECK(
90-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
91-
92-
// Resize
93-
ET_KERNEL_CHECK(
94-
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
95-
96-
// Compute Dtype
97-
ScalarType compute_type = utils::get_compute_type(common_type);
98-
99-
// @lint-ignore CLANGTIDY facebook-hte-CArray
100-
static constexpr const char op_name[] = "add.Scalar_out";
101-
102-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
103-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
104-
[b, alpha](const CTYPE_COMPUTE val_a) {
105-
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106-
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
107-
return val_a + val_alpha * val_b;
108-
},
109-
ctx,
110-
a,
111-
utils::SupportedTensorDtypes::REALHBBF16,
112-
out,
113-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
114-
});
115-
116-
return out;
32+
return add_scalar_out_impl(ctx, a, b, alpha, out);
11733
}
11834

11935
} // namespace native
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
#include <executorch/runtime/platform/assert.h>
14+
15+
#include <executorch/kernels/portable/cpu/op_add_impl.h>
16+
17+
namespace torch {
18+
namespace executor {
19+
namespace native {
20+
21+
Tensor& add_out_impl(
22+
KernelRuntimeContext& ctx,
23+
const Tensor& a,
24+
const Tensor& b,
25+
const Scalar& alpha,
26+
Tensor& out) {
27+
// Common Dtype
28+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
29+
30+
// Check Common Dtype
31+
ET_KERNEL_CHECK(
32+
ctx,
33+
(canCast(common_type, out.scalar_type()) &&
34+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
35+
InvalidArgument,
36+
out);
37+
38+
// Check Dim Order
39+
ET_KERNEL_CHECK(
40+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
41+
42+
// Resize
43+
ET_KERNEL_CHECK(
44+
ctx,
45+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
46+
InvalidArgument,
47+
out);
48+
49+
// Compute Dtype
50+
ScalarType compute_type = utils::get_compute_type(common_type);
51+
52+
// @lint-ignore CLANGTIDY facebook-hte-CArray
53+
static constexpr const char op_name[] = "add.out";
54+
55+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
57+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
58+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59+
return val_a + val_alpha * val_b;
60+
},
61+
ctx,
62+
a,
63+
utils::SupportedTensorDtypes::REALHBBF16,
64+
b,
65+
utils::SupportedTensorDtypes::REALHBBF16,
66+
out,
67+
utils::SupportedTensorDtypes::REALHBBF16);
68+
});
69+
70+
return out;
71+
}
72+
73+
Tensor& add_scalar_out_impl(
74+
KernelRuntimeContext& ctx,
75+
const Tensor& a,
76+
const Scalar& b,
77+
const Scalar& alpha,
78+
Tensor& out) {
79+
// Common Dtype
80+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
81+
82+
// Check Common Dtype
83+
ET_KERNEL_CHECK(
84+
ctx,
85+
(common_type == out.scalar_type() &&
86+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
87+
InvalidArgument,
88+
out);
89+
90+
// Check Dim Order
91+
ET_KERNEL_CHECK(
92+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
93+
94+
// Resize
95+
ET_KERNEL_CHECK(
96+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
97+
98+
// Compute Dtype
99+
ScalarType compute_type = utils::get_compute_type(common_type);
100+
101+
// @lint-ignore CLANGTIDY facebook-hte-CArray
102+
static constexpr const char op_name[] = "add.Scalar_out";
103+
104+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
105+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
106+
[b, alpha](const CTYPE_COMPUTE val_a) {
107+
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
108+
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
109+
return val_a + val_alpha * val_b;
110+
},
111+
ctx,
112+
a,
113+
utils::SupportedTensorDtypes::REALHBBF16,
114+
out,
115+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
116+
});
117+
118+
return out;
119+
}
120+
121+
} // namespace native
122+
} // namespace executor
123+
} // namespace torch

kernels/portable/cpu/op_add_impl.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/kernel/kernel_includes.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
15+
Tensor& add_out_impl(
16+
KernelRuntimeContext& ctx,
17+
const Tensor& a,
18+
const Tensor& b,
19+
const Scalar& alpha,
20+
Tensor& out);
21+
22+
Tensor& add_scalar_out_impl(
23+
KernelRuntimeContext& ctx,
24+
const Tensor& a,
25+
const Scalar& b,
26+
const Scalar& alpha,
27+
Tensor& out);
28+
29+
} // namespace native
30+
} // namespace executor
31+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,26 @@ def define_common_targets():
131131
],
132132
)
133133

134+
runtime.cxx_library(
135+
name = "op_add_impl",
136+
srcs = ["op_add_impl.cpp"],
137+
exported_headers = ["op_add_impl.h"],
138+
visibility = [
139+
"//executorch/kernels/portable/cpu/...",
140+
"//executorch/kernels/optimized/cpu/...",
141+
"//executorch/kernels/portable/test/...",
142+
"@EXECUTORCH_CLIENTS",
143+
],
144+
exported_deps = [
145+
"//executorch/kernels/portable/cpu/util:broadcast_util",
146+
"//executorch/kernels/portable/cpu/util:dtype_util",
147+
"//executorch/kernels/portable/cpu/util:elementwise_util",
148+
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
149+
"//executorch/kernels/portable/cpu:scalar_utils",
150+
"//executorch/runtime/kernel:kernel_includes",
151+
],
152+
)
153+
134154
# The following will not participate in dtype selective build because
135155
# they are refactored such to be used in optimized op implementations as well
136156
# and we have not enabled selective build for optimized ops.
@@ -142,6 +162,7 @@ def define_common_targets():
142162
runtime.cxx_library(
143163
name = "all_impl_deps",
144164
deps = [
165+
"//executorch/kernels/portable/cpu:op_add_impl",
145166
"//executorch/kernels/portable/cpu:op_div_impl",
146167
"//executorch/kernels/portable/cpu:op_mul_impl",
147168
],

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,7 @@ ATEN_OPS = (
220220
op_target(
221221
name = "op_add",
222222
deps = [
223-
"//executorch/kernels/portable/cpu/util:broadcast_util",
224-
"//executorch/kernels/portable/cpu/util:dtype_util",
225-
"//executorch/kernels/portable/cpu/util:elementwise_util",
226-
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
227-
":scalar_utils",
223+
"//executorch/kernels/portable/cpu:op_add_impl",
228224
],
229225
),
230226
op_target(
@@ -1268,4 +1264,4 @@ def portable_source_list():
12681264

12691265
def portable_header_list():
12701266
"""All the header file names from //executorch/kernels/portable/cpu/"""
1271-
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"]
1267+
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]

0 commit comments

Comments
 (0)