Skip to content

Commit 19f01e0

Browse files
committed
outline input validation in elementwise_util
This seems to improve code size. Comparing output of test/build_optimized_size_test.sh before/after this change: (I've edited out the unchanged "no ops" test case.) before: ``` ExecuTorch with portable ops binary size, unstripped: -rwxr-xr-x 1 swolchok staff 2004816 Jun 26 14:02 cmake-out/test/size_test_all_ops __TEXT __DATA __OBJC others dec hex 1441792 65536 0 4295524352 4297031680 1001f8000 ExecuTorch with optimized ops binary size, unstripped: -rwxr-xr-x 1 swolchok staff 6746968 Jun 26 14:02 cmake-out/test/size_test_all_optimized_ops __TEXT __DATA __OBJC others dec hex 4947968 65536 0 4296753152 4301766656 10067c000 ``` after: ``` ExecuTorch with portable ops binary size, unstripped: -rwxr-xr-x 1 swolchok staff 1989392 Jun 26 14:37 cmake-out/test/size_test_all_ops __TEXT __DATA __OBJC others dec hex 1425408 65536 0 4295524352 4297015296 1001f4000 ExecuTorch with optimized ops binary size, unstripped: -rwxr-xr-x 1 swolchok staff 6731784 Jun 26 14:37 cmake-out/test/size_test_all_optimized_ops __TEXT __DATA __OBJC others dec hex 4931584 65536 0 4296753152 4301750272 100678000 ``` for test/build_size_test.sh, we see a smaller improvment which is reflected in the file sizes but not the `size` command output, probably because the latter is rounded up to the next page: before: ``` ExecuTorch with portable ops binary size, unstripped: -rwxr-xr-x 1 swolchok staff 2696640 Jun 26 13:44 cmake-out/test/size_test_all_ops __TEXT __DATA __OBJC others dec hex 2162688 65536 0 4295491584 4297719808 1002a0000 ``` after: ``` ExecuTorch with portable ops binary size, unstripped: -rwxr-xr-x 1 swolchok staff 2695344 Jun 26 14:33 cmake-out/test/size_test_all_ops __TEXT __DATA __OBJC others dec hex 2162688 65536 0 4295491584 4297719808 1002a0000 ``` ghstack-source-id: 249444e ghstack-comment-id: 3010230941 Pull-Request-resolved: #12031
1 parent 392ea4d commit 19f01e0

File tree

3 files changed

+114
-25
lines changed

3 files changed

+114
-25
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10+
11+
namespace torch::executor::native::utils::internal {
12+
13+
template <typename... Args>
14+
inline bool validate_elementwise_fn_inputs_impl(
15+
KernelRuntimeContext& ctx,
16+
const Tensor& out,
17+
SupportedTensorDtypes out_dtypes,
18+
ScalarType compute_type,
19+
Args... inputs) {
20+
static_assert(
21+
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
22+
...));
23+
const auto check_input_dtype = [](auto input, auto compute_type) {
24+
return internal::check_tensor_dtype(
25+
*input.first, input.second, compute_type);
26+
};
27+
ET_KERNEL_CHECK(
28+
ctx,
29+
(check_input_dtype(inputs, compute_type) && ...) &&
30+
internal::check_tensor_dtype(out, out_dtypes, compute_type),
31+
InvalidArgument,
32+
false);
33+
34+
return true;
35+
}
36+
37+
bool validate_elementwise_fn_inputs(
38+
KernelRuntimeContext& ctx,
39+
const Tensor& out,
40+
SupportedTensorDtypes out_dtypes,
41+
ScalarType compute_type,
42+
std::pair<const Tensor*, SupportedTensorDtypes> input) {
43+
return validate_elementwise_fn_inputs_impl(
44+
ctx,
45+
out,
46+
out_dtypes,
47+
compute_type,
48+
input);
49+
}
50+
51+
bool validate_elementwise_fn_inputs(
52+
KernelRuntimeContext& ctx,
53+
const Tensor& out,
54+
SupportedTensorDtypes out_dtypes,
55+
ScalarType compute_type,
56+
std::pair<const Tensor*, SupportedTensorDtypes> input0,
57+
std::pair<const Tensor*, SupportedTensorDtypes> input1) {
58+
return validate_elementwise_fn_inputs_impl(
59+
ctx,
60+
out,
61+
out_dtypes,
62+
compute_type,
63+
input0,
64+
input1);
65+
}
66+
67+
bool validate_elementwise_fn_inputs(
68+
KernelRuntimeContext& ctx,
69+
const Tensor& out,
70+
SupportedTensorDtypes out_dtypes,
71+
ScalarType compute_type,
72+
std::pair<const Tensor*, SupportedTensorDtypes> input0,
73+
std::pair<const Tensor*, SupportedTensorDtypes> input1,
74+
std::pair<const Tensor*, SupportedTensorDtypes> input2) {
75+
return validate_elementwise_fn_inputs_impl(
76+
ctx,
77+
out,
78+
out_dtypes,
79+
compute_type,
80+
input0,
81+
input1,
82+
input2);
83+
}
84+
85+
86+
} // namespace torch::executor::native::utils::internal

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -220,30 +220,29 @@ inline void dtype_specialized_elementwise_fn_impl(
220220
});
221221
}
222222

223-
template <typename CTYPE_COMPUTE, typename Op, typename... Args>
224-
inline bool validate_elementwise_fn_inputs(
225-
const Op& compute_fun,
223+
bool validate_elementwise_fn_inputs(
226224
KernelRuntimeContext& ctx,
227225
const Tensor& out,
228226
SupportedTensorDtypes out_dtypes,
229-
Args... inputs) {
230-
static_assert(
231-
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
232-
...));
233-
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
234-
const auto check_input_dtype = [](auto input, auto compute_type) {
235-
return internal::check_tensor_dtype(
236-
*input.first, input.second, compute_type);
237-
};
238-
ET_KERNEL_CHECK(
239-
ctx,
240-
(check_input_dtype(inputs, compute_type) && ...) &&
241-
internal::check_tensor_dtype(out, out_dtypes, compute_type),
242-
InvalidArgument,
243-
false);
227+
ScalarType compute_type,
228+
std::pair<const Tensor*, SupportedTensorDtypes> input);
244229

245-
return true;
246-
}
230+
bool validate_elementwise_fn_inputs(
231+
KernelRuntimeContext& ctx,
232+
const Tensor& out,
233+
SupportedTensorDtypes out_dtypes,
234+
ScalarType compute_type,
235+
std::pair<const Tensor*, SupportedTensorDtypes> input0,
236+
std::pair<const Tensor*, SupportedTensorDtypes> input1);
237+
238+
bool validate_elementwise_fn_inputs(
239+
KernelRuntimeContext& ctx,
240+
const Tensor& out,
241+
SupportedTensorDtypes out_dtypes,
242+
ScalarType compute_type,
243+
std::pair<const Tensor*, SupportedTensorDtypes> input0,
244+
std::pair<const Tensor*, SupportedTensorDtypes> input1,
245+
std::pair<const Tensor*, SupportedTensorDtypes> input2);
247246

248247
template <
249248
typename CTYPE_COMPUTE,
@@ -314,8 +313,9 @@ inline void apply_elementwise_fn_runtime_out_dtypes(
314313
const Tensor& out,
315314
SupportedTensorDtypes out_dtypes,
316315
Args... inputs) {
317-
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
318-
compute_fun, ctx, out, out_dtypes, inputs...);
316+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
317+
const bool inputs_valid = validate_elementwise_fn_inputs(
318+
ctx, out, out_dtypes, compute_type, inputs...);
319319
if (!inputs_valid) {
320320
return;
321321
}
@@ -339,13 +339,13 @@ inline void apply_elementwise_fn(
339339
KernelRuntimeContext& ctx,
340340
const Tensor& out,
341341
Args... inputs) {
342-
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
343-
compute_fun, ctx, out, out_dtypes, inputs...);
342+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
343+
const bool inputs_valid = validate_elementwise_fn_inputs(
344+
ctx, out, out_dtypes, compute_type, inputs...);
344345
if (!inputs_valid) {
345346
return;
346347
}
347348

348-
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
349349
if constexpr (should_include_kernel_dtype(op_name, compute_type)) {
350350
const bool all_inputs_compute_dtype =
351351
((inputs.first->scalar_type() == compute_type) && ...);

kernels/portable/cpu/util/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def define_common_targets():
104104

105105
runtime.cxx_library(
106106
name = "elementwise_util",
107+
srcs = [
108+
"elementwise_util.cpp",
109+
],
107110
exported_headers = [
108111
"elementwise_util.h",
109112
],

0 commit comments

Comments
 (0)