Skip to content

Commit 85ec465

Browse files
committed
[Executorch][Portable] Refactor op_sub to be reused by optimized fallback
Pull Request resolved: pytorch/executorch#6736 If we can use portable for optimized's fallback then we can remove copy pasted stuff and take advantage of size and build reduction efforts in portable. ghstack-source-id: 252584482 @exported-using-ghexport Differential Revision: [D65666033](https://our.internmc.facebook.com/intern/diff/D65666033/)
1 parent c361400 commit 85ec465

File tree

5 files changed

+192
-98
lines changed

5 files changed

+192
-98
lines changed

kernels/portable/cpu/op_sub.cpp

Lines changed: 4 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <executorch/runtime/kernel/kernel_includes.h>
1212
#include <executorch/runtime/platform/assert.h>
1313

14+
#include <executorch/kernels/portable/cpu/op_sub_impl.h>
15+
1416
namespace torch {
1517
namespace executor {
1618
namespace native {
@@ -21,55 +23,7 @@ Tensor& sub_out(
2123
const Tensor& b,
2224
const Scalar& alpha,
2325
Tensor& out) {
24-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
25-
26-
// Check alpha type
27-
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
28-
29-
// Common Dtype
30-
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
31-
32-
// Check Common Dtype
33-
ET_KERNEL_CHECK(
34-
ctx,
35-
(canCast(common_type, out.scalar_type()) &&
36-
canCast(alpha_type, common_type)),
37-
InvalidArgument,
38-
out);
39-
40-
// Check Dim Order
41-
ET_KERNEL_CHECK(
42-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
43-
44-
// Resize
45-
ET_KERNEL_CHECK(
46-
ctx,
47-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
48-
InvalidArgument,
49-
out);
50-
51-
// Compute Dtype
52-
ScalarType compute_type = utils::get_compute_type(common_type);
53-
54-
// @lint-ignore CLANGTIDY facebook-hte-CArray
55-
static constexpr const char op_name[] = "sub.out";
56-
57-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
59-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
60-
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
61-
return val_a - val_alpha * val_b;
62-
},
63-
ctx,
64-
a,
65-
utils::SupportedTensorDtypes::REALHBF16,
66-
b,
67-
utils::SupportedTensorDtypes::REALHBF16,
68-
out,
69-
utils::SupportedTensorDtypes::REALHBF16);
70-
});
71-
72-
return out;
26+
return sub_out_impl(ctx, a, b, alpha, out);
7327
}
7428

7529
Tensor& sub_scalar_out(
@@ -78,50 +32,7 @@ Tensor& sub_scalar_out(
7832
const Scalar& b,
7933
const Scalar& alpha,
8034
Tensor& out) {
81-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
82-
83-
// Check alpha type
84-
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
85-
86-
// Common Dtype
87-
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
88-
89-
// Check Common Dtype
90-
ET_KERNEL_CHECK(
91-
ctx,
92-
(common_type == out.scalar_type() && canCast(alpha_type, common_type)),
93-
InvalidArgument,
94-
out);
95-
96-
// Check Dim Order
97-
ET_KERNEL_CHECK(
98-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
99-
100-
// Resize
101-
ET_KERNEL_CHECK(
102-
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
103-
104-
// Compute Dtype
105-
ScalarType compute_type = utils::get_compute_type(common_type);
106-
107-
// @lint-ignore CLANGTIDY facebook-hte-CArray
108-
static constexpr const char op_name[] = "sub.Scalar_out";
109-
110-
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
111-
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
112-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
113-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
114-
[val_b, val_alpha](const CTYPE_COMPUTE val_a) {
115-
return val_a - val_alpha * val_b;
116-
},
117-
ctx,
118-
a,
119-
utils::SupportedTensorDtypes::REALHBF16,
120-
out,
121-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
122-
});
123-
124-
return out;
35+
return sub_scalar_out_impl(ctx, a, b, alpha, out);
12536
}
12637

12738
} // namespace native
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/platform/assert.h>
13+
14+
#include <executorch/kernels/portable/cpu/op_sub_impl.h>
15+
16+
namespace torch {
17+
namespace executor {
18+
namespace native {
19+
20+
Tensor& sub_out_impl(
21+
KernelRuntimeContext& ctx,
22+
const Tensor& a,
23+
const Tensor& b,
24+
const Scalar& alpha,
25+
Tensor& out) {
26+
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
27+
28+
// Check alpha type
29+
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
30+
31+
// Common Dtype
32+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
33+
34+
// Check Common Dtype
35+
ET_KERNEL_CHECK(
36+
ctx,
37+
(canCast(common_type, out.scalar_type()) &&
38+
canCast(alpha_type, common_type)),
39+
InvalidArgument,
40+
out);
41+
42+
// Check Dim Order
43+
ET_KERNEL_CHECK(
44+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
45+
46+
// Resize
47+
ET_KERNEL_CHECK(
48+
ctx,
49+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
50+
InvalidArgument,
51+
out);
52+
53+
// Compute Dtype
54+
ScalarType compute_type = utils::get_compute_type(common_type);
55+
56+
// @lint-ignore CLANGTIDY facebook-hte-CArray
57+
static constexpr const char op_name[] = "sub.out";
58+
59+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
60+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
61+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
62+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
63+
return val_a - val_alpha * val_b;
64+
},
65+
ctx,
66+
a,
67+
utils::SupportedTensorDtypes::REALHBF16,
68+
b,
69+
utils::SupportedTensorDtypes::REALHBF16,
70+
out,
71+
utils::SupportedTensorDtypes::REALHBF16);
72+
});
73+
74+
return out;
75+
}
76+
77+
Tensor& sub_scalar_out_impl(
78+
KernelRuntimeContext& ctx,
79+
const Tensor& a,
80+
const Scalar& b,
81+
const Scalar& alpha,
82+
Tensor& out) {
83+
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
84+
85+
// Check alpha type
86+
ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
87+
88+
// Common Dtype
89+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
90+
91+
// Check Common Dtype
92+
ET_KERNEL_CHECK(
93+
ctx,
94+
(common_type == out.scalar_type() && canCast(alpha_type, common_type)),
95+
InvalidArgument,
96+
out);
97+
98+
// Check Dim Order
99+
ET_KERNEL_CHECK(
100+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
101+
102+
// Resize
103+
ET_KERNEL_CHECK(
104+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
105+
106+
// Compute Dtype
107+
ScalarType compute_type = utils::get_compute_type(common_type);
108+
109+
// @lint-ignore CLANGTIDY facebook-hte-CArray
110+
static constexpr const char op_name[] = "sub.Scalar_out";
111+
112+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
113+
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
114+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
115+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
116+
[val_b, val_alpha](const CTYPE_COMPUTE val_a) {
117+
return val_a - val_alpha * val_b;
118+
},
119+
ctx,
120+
a,
121+
utils::SupportedTensorDtypes::REALHBF16,
122+
out,
123+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
124+
});
125+
126+
return out;
127+
}
128+
129+
} // namespace native
130+
} // namespace executor
131+
} // namespace torch

kernels/portable/cpu/op_sub_impl.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/platform/assert.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
18+
Tensor& sub_out_impl(
19+
KernelRuntimeContext& ctx,
20+
const Tensor& a,
21+
const Tensor& b,
22+
const Scalar& alpha,
23+
Tensor& out);
24+
25+
Tensor& sub_scalar_out_impl(
26+
KernelRuntimeContext& ctx,
27+
const Tensor& a,
28+
const Scalar& b,
29+
const Scalar& alpha,
30+
Tensor& out);
31+
32+
} // namespace native
33+
} // namespace executor
34+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,26 @@ def define_common_targets():
151151
],
152152
)
153153

154+
runtime.cxx_library(
155+
name = "op_sub_impl",
156+
srcs = ["op_sub_impl.cpp"],
157+
exported_headers = ["op_sub_impl.h"],
158+
visibility = [
159+
"//executorch/kernels/portable/cpu/...",
160+
"//executorch/kernels/optimized/cpu/...",
161+
"//executorch/kernels/portable/test/...",
162+
"@EXECUTORCH_CLIENTS",
163+
],
164+
exported_deps = [
165+
"//executorch/kernels/portable/cpu/util:broadcast_util",
166+
"//executorch/kernels/portable/cpu/util:dtype_util",
167+
"//executorch/kernels/portable/cpu/util:elementwise_util",
168+
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
169+
"//executorch/kernels/portable/cpu:scalar_utils",
170+
"//executorch/runtime/kernel:kernel_includes",
171+
],
172+
)
173+
154174
# The following will not participate in dtype selective build because
155175
# they are refactored such to be used in optimized op implementations as well
156176
# and we have not enabled selective build for optimized ops.
@@ -165,6 +185,7 @@ def define_common_targets():
165185
"//executorch/kernels/portable/cpu:op_add_impl",
166186
"//executorch/kernels/portable/cpu:op_div_impl",
167187
"//executorch/kernels/portable/cpu:op_mul_impl",
188+
"//executorch/kernels/portable/cpu:op_sub_impl",
168189
],
169190
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
170191
)

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,10 +1138,7 @@ ATEN_OPS = (
11381138
op_target(
11391139
name = "op_sub",
11401140
deps = [
1141-
":scalar_utils",
1142-
"//executorch/kernels/portable/cpu/util:broadcast_util",
1143-
"//executorch/kernels/portable/cpu/util:dtype_util",
1144-
"//executorch/kernels/portable/cpu/util:elementwise_util",
1141+
"//executorch/kernels/portable/cpu:op_sub_impl",
11451142
],
11461143
),
11471144
op_target(
@@ -1264,4 +1261,4 @@ def portable_source_list():
12641261

12651262
def portable_header_list():
12661263
"""All the header file names from //executorch/kernels/portable/cpu/"""
1267-
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h"]
1264+
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h", "op_add_impl.h", "op_sub_impl.h"]

0 commit comments

Comments
 (0)