Skip to content

Commit 5da628e

Browse files
committed
[Executorch][Portable] Refactor op_sub to be reused by optimized fallback
If we can use portable for optimized's fallback then we can remove copy pasted stuff and take advantage of size and build reduction efforts in portable. Differential Revision: [D65666033](https://our.internmc.facebook.com/intern/diff/D65666033/) [ghstack-poisoned]
1 parent 33f75f2 commit 5da628e

File tree

5 files changed

+192
-98
lines changed

5 files changed

+192
-98
lines changed

kernels/portable/cpu/op_sub.cpp

Lines changed: 4 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <executorch/runtime/kernel/kernel_includes.h>
1212
#include <executorch/runtime/platform/assert.h>
1313

14+
#include <executorch/kernels/portable/cpu/op_sub_impl.h>
15+
1416
namespace torch {
1517
namespace executor {
1618
namespace native {
@@ -21,55 +23,7 @@ Tensor& sub_out(
2123
const Tensor& b,
2224
const Scalar& alpha,
2325
Tensor& out) {
24-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
25-
26-
// Check alpha type
27-
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
28-
29-
// Common Dtype
30-
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
31-
32-
// Check Common Dtype
33-
ET_KERNEL_CHECK(
34-
ctx,
35-
(canCast(common_type, out.scalar_type()) &&
36-
canCast(alpha_type, common_type)),
37-
InvalidArgument,
38-
out);
39-
40-
// Check Dim Order
41-
ET_KERNEL_CHECK(
42-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
43-
44-
// Resize
45-
ET_KERNEL_CHECK(
46-
ctx,
47-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
48-
InvalidArgument,
49-
out);
50-
51-
// Compute Dtype
52-
ScalarType compute_type = utils::get_compute_type(common_type);
53-
54-
// @lint-ignore CLANGTIDY facebook-hte-CArray
55-
static constexpr const char op_name[] = "sub.out";
56-
57-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
59-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
60-
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
61-
return val_a - val_alpha * val_b;
62-
},
63-
ctx,
64-
a,
65-
utils::SupportedTensorDtypes::REALHBF16,
66-
b,
67-
utils::SupportedTensorDtypes::REALHBF16,
68-
out,
69-
utils::SupportedTensorDtypes::REALHBF16);
70-
});
71-
72-
return out;
26+
return sub_out_impl(ctx, a, b, alpha, out);
7327
}
7428

7529
Tensor& sub_scalar_out(
@@ -78,50 +32,7 @@ Tensor& sub_scalar_out(
7832
const Scalar& b,
7933
const Scalar& alpha,
8034
Tensor& out) {
81-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
82-
83-
// Check alpha type
84-
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
85-
86-
// Common Dtype
87-
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
88-
89-
// Check Common Dtype
90-
ET_KERNEL_CHECK(
91-
ctx,
92-
(common_type == out.scalar_type() && canCast(alpha_type, common_type)),
93-
InvalidArgument,
94-
out);
95-
96-
// Check Dim Order
97-
ET_KERNEL_CHECK(
98-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
99-
100-
// Resize
101-
ET_KERNEL_CHECK(
102-
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
103-
104-
// Compute Dtype
105-
ScalarType compute_type = utils::get_compute_type(common_type);
106-
107-
// @lint-ignore CLANGTIDY facebook-hte-CArray
108-
static constexpr const char op_name[] = "sub.Scalar_out";
109-
110-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
111-
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
112-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
113-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
114-
[val_b, val_alpha](const CTYPE_COMPUTE val_a) {
115-
return val_a - val_alpha * val_b;
116-
},
117-
ctx,
118-
a,
119-
utils::SupportedTensorDtypes::REALHBF16,
120-
out,
121-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
122-
});
123-
124-
return out;
35+
return sub_scalar_out_impl(ctx, a, b, alpha, out);
12536
}
12637

12738
} // namespace native
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/platform/assert.h>
13+
14+
#include <executorch/kernels/portable/cpu/op_sub_impl.h>
15+
16+
namespace torch {
17+
namespace executor {
18+
namespace native {
19+
20+
Tensor& sub_out_impl(
21+
KernelRuntimeContext& ctx,
22+
const Tensor& a,
23+
const Tensor& b,
24+
const Scalar& alpha,
25+
Tensor& out) {
26+
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
27+
28+
// Check alpha type
29+
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
30+
31+
// Common Dtype
32+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
33+
34+
// Check Common Dtype
35+
ET_KERNEL_CHECK(
36+
ctx,
37+
(canCast(common_type, out.scalar_type()) &&
38+
canCast(alpha_type, common_type)),
39+
InvalidArgument,
40+
out);
41+
42+
// Check Dim Order
43+
ET_KERNEL_CHECK(
44+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
45+
46+
// Resize
47+
ET_KERNEL_CHECK(
48+
ctx,
49+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
50+
InvalidArgument,
51+
out);
52+
53+
// Compute Dtype
54+
ScalarType compute_type = utils::get_compute_type(common_type);
55+
56+
// @lint-ignore CLANGTIDY facebook-hte-CArray
57+
static constexpr const char op_name[] = "sub.out";
58+
59+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
60+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
61+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
62+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
63+
return val_a - val_alpha * val_b;
64+
},
65+
ctx,
66+
a,
67+
utils::SupportedTensorDtypes::REALHBF16,
68+
b,
69+
utils::SupportedTensorDtypes::REALHBF16,
70+
out,
71+
utils::SupportedTensorDtypes::REALHBF16);
72+
});
73+
74+
return out;
75+
}
76+
77+
Tensor& sub_scalar_out_impl(
78+
KernelRuntimeContext& ctx,
79+
const Tensor& a,
80+
const Scalar& b,
81+
const Scalar& alpha,
82+
Tensor& out) {
83+
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
84+
85+
// Check alpha type
86+
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
87+
88+
// Common Dtype
89+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
90+
91+
// Check Common Dtype
92+
ET_KERNEL_CHECK(
93+
ctx,
94+
(common_type == out.scalar_type() && canCast(alpha_type, common_type)),
95+
InvalidArgument,
96+
out);
97+
98+
// Check Dim Order
99+
ET_KERNEL_CHECK(
100+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
101+
102+
// Resize
103+
ET_KERNEL_CHECK(
104+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
105+
106+
// Compute Dtype
107+
ScalarType compute_type = utils::get_compute_type(common_type);
108+
109+
// @lint-ignore CLANGTIDY facebook-hte-CArray
110+
static constexpr const char op_name[] = "sub.Scalar_out";
111+
112+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
113+
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
114+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
115+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
116+
[val_b, val_alpha](const CTYPE_COMPUTE val_a) {
117+
return val_a - val_alpha * val_b;
118+
},
119+
ctx,
120+
a,
121+
utils::SupportedTensorDtypes::REALHBF16,
122+
out,
123+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
124+
});
125+
126+
return out;
127+
}
128+
129+
} // namespace native
130+
} // namespace executor
131+
} // namespace torch

kernels/portable/cpu/op_sub_impl.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/platform/assert.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
18+
Tensor& sub_out_impl(
19+
KernelRuntimeContext& ctx,
20+
const Tensor& a,
21+
const Tensor& b,
22+
const Scalar& alpha,
23+
Tensor& out);
24+
25+
Tensor& sub_scalar_out_impl(
26+
KernelRuntimeContext& ctx,
27+
const Tensor& a,
28+
const Scalar& b,
29+
const Scalar& alpha,
30+
Tensor& out);
31+
32+
} // namespace native
33+
} // namespace executor
34+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,26 @@ def define_common_targets():
151151
],
152152
)
153153

154+
runtime.cxx_library(
155+
name = "op_sub_impl",
156+
srcs = ["op_sub_impl.cpp"],
157+
exported_headers = ["op_sub_impl.h"],
158+
visibility = [
159+
"//executorch/kernels/portable/cpu/...",
160+
"//executorch/kernels/optimized/cpu/...",
161+
"//executorch/kernels/portable/test/...",
162+
"@EXECUTORCH_CLIENTS",
163+
],
164+
exported_deps = [
165+
"//executorch/kernels/portable/cpu/util:broadcast_util",
166+
"//executorch/kernels/portable/cpu/util:dtype_util",
167+
"//executorch/kernels/portable/cpu/util:elementwise_util",
168+
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
169+
"//executorch/kernels/portable/cpu:scalar_utils",
170+
"//executorch/runtime/kernel:kernel_includes",
171+
],
172+
)
173+
154174
# The following will not participate in dtype selective build because
155175
# they are refactored such to be used in optimized op implementations as well
156176
# and we have not enabled selective build for optimized ops.
@@ -165,6 +185,7 @@ def define_common_targets():
165185
"//executorch/kernels/portable/cpu:op_add_impl",
166186
"//executorch/kernels/portable/cpu:op_div_impl",
167187
"//executorch/kernels/portable/cpu:op_mul_impl",
188+
"//executorch/kernels/portable/cpu:op_sub_impl",
168189
],
169190
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
170191
)

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,10 +1138,7 @@ ATEN_OPS = (
11381138
op_target(
11391139
name = "op_sub",
11401140
deps = [
1141-
":scalar_utils",
1142-
"//executorch/kernels/portable/cpu/util:broadcast_util",
1143-
"//executorch/kernels/portable/cpu/util:dtype_util",
1144-
"//executorch/kernels/portable/cpu/util:elementwise_util",
1141+
"//executorch/kernels/portable/cpu:op_sub_impl",
11451142
],
11461143
),
11471144
op_target(
@@ -1264,4 +1261,4 @@ def portable_source_list():
12641261

12651262
def portable_header_list():
12661263
"""All the header file names from //executorch/kernels/portable/cpu/"""
1267-
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]
1264+
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h", "op_sub_impl.h"]

0 commit comments

Comments
 (0)