
Commit baa8660

Update on "[Executorch] Re-enable operator optimization flags"

The previous attempt at this was reverted due to an app-size increase, much of which came from op_div exploding in size. The two diffs underneath this one solve that issue.

Differential Revision: [D65606666](https://our.internmc.facebook.com/intern/diff/D65606666/)

[ghstack-poisoned]
2 parents cceac32 + 820d440 commit baa8660
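
The size win comes from the pattern visible in the diffs below: the optimized kernel keeps its fast path but delegates its fallback to a single shared implementation, which the portable kernel also wraps, so the heavily templated slow-path code is compiled once instead of twice. A minimal sketch of that structure, with scalar floats standing in for the real Tensor/Scalar types:

```cpp
#include <cstdio>

// Shared fallback: the heavily templated slow path lives here, compiled once.
float sub_impl(float a, float b, float alpha) {
  return a - alpha * b;
}

// "Optimized" kernel: fast path when possible, shared fallback otherwise.
float opt_sub(float a, float b, float alpha, bool fast_path_ok) {
  if (fast_path_ok) {
    return a - alpha * b;  // stand-in for the vectorized path
  }
  return sub_impl(a, b, alpha);  // delegate instead of duplicating the slow path
}

// "Portable" kernel: a thin wrapper over the same shared implementation.
float portable_sub(float a, float b, float alpha) {
  return sub_impl(a, b, alpha);
}

int main() {
  std::printf("%g\n", opt_sub(5.0f, 2.0f, 1.5f, false));  // prints 2
  std::printf("%g\n", portable_sub(5.0f, 2.0f, 1.5f));    // prints 2
}
```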

File tree: 7 files changed (+196, −152)


kernels/optimized/cpu/op_sub.cpp (3 additions, 54 deletions)
```diff
@@ -9,6 +9,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/op_sub_impl.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
@@ -210,35 +211,7 @@ Tensor& opt_sub_out(
       }
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          SubInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
+    sub_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
@@ -290,31 +263,7 @@ Tensor& opt_sub_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
-          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
-            ET_SWITCH_REAL_TYPES(
-                common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
-                  ET_SWITCH_REALH_TYPES(
-                      out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
-                        CTYPE_B b_val;
-                        ET_EXTRACT_SCALAR(b, b_val);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                        CTYPE_IN alpha_val;
-                        ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                        const size_t n = a.numel();
-                        const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                        CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                        for (auto i = 0; i < n; ++i) {
-                          out_data[i] = static_cast<CTYPE_OUT>(
-                              static_cast<CTYPE_IN>(a_data[i]) -
-                              alpha_val * b_casted);
-                        }
-                      });
-                });
-          });
-    });
+    sub_scalar_out_impl(ctx, a, b, alpha, out);
   }
 
   return out;
```
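
The deleted fallback selected an implementation at compile time via can_cast<CTYPE_IN, CTYPE_OUT>::value, a common trick for keeping the full ET_SWITCH dtype cross-product compiling even for combinations that are rejected at runtime. A simplified sketch of that idiom, with hypothetical scalar types in place of the real tensor machinery:

```cpp
#include <cstdio>
#include <type_traits>

// Real work: chosen when the In -> Out conversion is allowed.
template <bool kCanCast, typename In, typename Out>
struct SubInner {
  static void run(In a, In b, In alpha, Out* out) {
    *out = static_cast<Out>(a - alpha * b);
  }
};

// Stub for invalid In -> Out combinations: keeps every instantiation in the
// type cross-product compiling while staying unreachable at runtime.
template <typename In, typename Out>
struct SubInner<false, In, Out> {
  static void run(In, In, In, Out*) {}
};

int main() {
  int out = 0;
  SubInner<std::is_convertible<float, int>::value, float, int>::run(
      5.0f, 2.0f, 1.5f, &out);
  std::printf("%d\n", out);  // prints 2
}
```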

kernels/optimized/cpu/targets.bzl (1 addition, 0 deletions)
```diff
@@ -88,6 +88,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_sub",
         deps = [
             ":binary_ops",
+            "//executorch/kernels/portable/cpu:op_sub_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
```
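
For this new dependency to resolve, the portable cpu targets.bzl presumably also gains an op_sub_impl library target that compiles op_sub_impl.cpp and exports op_sub_impl.h; that part of the change is not shown in this capture.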

kernels/portable/cpu/op_sub.cpp (4 additions, 93 deletions)
```diff
@@ -11,6 +11,8 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/portable/cpu/op_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -21,55 +23,7 @@ Tensor& sub_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-
-  // Check alpha type
-  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       canCast(alpha_type, common_type)),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "sub.out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a - val_alpha * val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  return out;
+  return sub_out_impl(ctx, a, b, alpha, out);
 }
 
 Tensor& sub_scalar_out(
@@ -78,50 +32,7 @@ Tensor& sub_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-
-  // Check alpha type
-  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "sub.Scalar_out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
-          return val_a - val_alpha * val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
-  });
-
-  return out;
+  return sub_scalar_out_impl(ctx, a, b, alpha, out);
 }
 
 } // namespace native
```
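
Both entry points are now one-line wrappers; the math they compute is unchanged: out[i] = a[i] - alpha * b[i], with broadcasting in the tensor variant. A tiny illustrative reference loop, not the actual kernel:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
  // Reference semantics of sub.out: out[i] = a[i] - alpha * b[i].
  const std::array<float, 4> a{1.0f, 2.0f, 3.0f, 4.0f};
  const std::array<float, 4> b{0.5f, 0.5f, 0.5f, 0.5f};
  const float alpha = 2.0f;
  std::array<float, 4> out{};
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] - alpha * b[i];
  }
  for (float v : out) {
    std::printf("%g ", v);  // prints: 0 1 2 3
  }
  std::printf("\n");
}
```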
kernels/portable/cpu/op_sub_impl.cpp (new file, 131 additions; file name not shown in the capture, inferred from op_sub_impl.h and the new build dependency)
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

#include <executorch/kernels/portable/cpu/op_sub_impl.h>

namespace torch {
namespace executor {
namespace native {

Tensor& sub_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  ScalarType alpha_type = utils::get_scalar_dtype(alpha);

  // Check alpha type
  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);

  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx,
      (canCast(common_type, out.scalar_type()) &&
       canCast(alpha_type, common_type)),
      InvalidArgument,
      out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "sub.out";

  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          return val_a - val_alpha * val_b;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBF16,
        b,
        utils::SupportedTensorDtypes::REALHBF16,
        out,
        utils::SupportedTensorDtypes::REALHBF16);
  });

  return out;
}

Tensor& sub_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  ScalarType alpha_type = utils::get_scalar_dtype(alpha);

  // Check alpha type
  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);

  // Common Dtype
  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx,
      (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
      InvalidArgument,
      out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "sub.Scalar_out";

  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
          return val_a - val_alpha * val_b;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBF16,
        out,
        utils::SupportedTensorDtypes::SAME_AS_COMMON);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
```
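
The shared implementation routes all arithmetic through a per-element lambda handed to utils::apply_bitensor_elementwise_fn, which owns the dtype-conversion and broadcasting logic. A stripped-down sketch of that style of helper over plain arrays, with a hypothetical signature rather than the real utility's:

```cpp
#include <cstddef>
#include <cstdio>

// Minimal sketch: apply a binary op elementwise over two equal-length arrays.
// The real utility additionally handles broadcasting and mixed dtypes.
template <typename T, typename Op>
void apply_bitensor_elementwise(
    Op op, const T* a, const T* b, T* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = op(a[i], b[i]);
  }
}

int main() {
  const float a[] = {1.0f, 2.0f, 3.0f};
  const float b[] = {4.0f, 5.0f, 6.0f};
  float out[3];
  const float alpha = 0.5f;
  // Same lambda shape as in sub_out_impl: a - alpha * b.
  apply_bitensor_elementwise(
      [alpha](float va, float vb) { return va - alpha * vb; }, a, b, out, 3);
  for (float v : out) {
    std::printf("%g ", v);  // prints: -1 -0.5 0
  }
  std::printf("\n");
}
```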

kernels/portable/cpu/op_sub_impl.h (new file, 34 additions)
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

Tensor& sub_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out);

Tensor& sub_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch
```
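
Both kernel libraries now compile against these two declarations. Inside the shared implementation, ET_SWITCH_REAL_TYPES performs the runtime dtype dispatch; conceptually it expands to a switch over ScalarType that instantiates the lambda once per concrete C type. A simplified model, not the actual macro:

```cpp
#include <cstdio>

enum class ScalarType { Float, Double };

// Simplified model of ET_SWITCH_REAL_TYPES: dispatch a generic lambda on a
// runtime dtype tag. The real macro covers many more types and error paths.
template <typename Fn>
void switch_real_types(ScalarType t, Fn fn) {
  switch (t) {
    case ScalarType::Float:
      fn(float{});
      break;
    case ScalarType::Double:
      fn(double{});
      break;
  }
}

int main() {
  switch_real_types(ScalarType::Float, [](auto tag) {
    using CTYPE_COMPUTE = decltype(tag);
    const CTYPE_COMPUTE out =
        CTYPE_COMPUTE(5) - CTYPE_COMPUTE(1.5) * CTYPE_COMPUTE(2);
    std::printf("%g\n", static_cast<double>(out));  // prints 2
  });
}
```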
