Skip to content

Commit 8bf5ffd

Browse files
[EE/BE][ET][Portable] Move scalar_to utils to scalar_utils.h (#12035)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12009 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/116/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/116/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/116/orig @diff-train-skip-merge Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent 72303a6 commit 8bf5ffd

File tree

3 files changed

+29
-29
lines changed

3 files changed

+29
-29
lines changed

kernels/portable/cpu/scalar_utils.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#pragma once
1010

11-
#include <algorithm>
1211
#include <cmath>
1312
#include <limits>
1413

@@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
261260
return false;
262261
}
263262

263+
/*
264+
* Convert Scalar to C++ type
265+
*/
266+
267+
template <typename T>
268+
T scalar_to(const Scalar& s) {
269+
if (s.isBoolean()) {
270+
return static_cast<T>(s.to<bool>());
271+
} else if (s.isFloatingPoint()) {
272+
return static_cast<T>(s.to<double>());
273+
} else {
274+
return static_cast<T>(s.to<int64_t>());
275+
}
276+
}
277+
278+
template <>
279+
inline double scalar_to<double>(const Scalar& s) {
280+
return s.isFloatingPoint() ? s.to<double>()
281+
: static_cast<double>(s.to<int64_t>());
282+
}
283+
284+
template <>
285+
inline int64_t scalar_to<int64_t>(const Scalar& s) {
286+
return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
287+
: s.to<int64_t>();
288+
}
289+
264290
} // namespace utils
265291
} // namespace native
266292
} // namespace executor

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <c10/util/irange.h>
12+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1213
#include <executorch/kernels/portable/cpu/selective_build.h>
1314
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
1415
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
@@ -28,34 +29,6 @@ namespace torch {
2829
namespace executor {
2930
namespace native {
3031
namespace utils {
31-
32-
/*
33-
* Convert Scalar to C++ type
34-
*/
35-
36-
template <typename T>
37-
T scalar_to(const Scalar& s) {
38-
if (s.isBoolean()) {
39-
return static_cast<T>(s.to<bool>());
40-
} else if (s.isFloatingPoint()) {
41-
return static_cast<T>(s.to<double>());
42-
} else {
43-
return static_cast<T>(s.to<int64_t>());
44-
}
45-
}
46-
47-
template <>
48-
inline double scalar_to<double>(const Scalar& s) {
49-
return s.isFloatingPoint() ? s.to<double>()
50-
: static_cast<double>(s.to<int64_t>());
51-
}
52-
53-
template <>
54-
inline int64_t scalar_to<int64_t>(const Scalar& s) {
55-
return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
56-
: s.to<int64_t>();
57-
}
58-
5932
namespace internal {
6033
/**
6134
* Causes these utility functions to make sure to respect Tensor

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def define_common_targets():
117117
"//executorch/runtime/kernel:kernel_runtime_context",
118118
"//executorch/kernels/portable/cpu:scalar_utils",
119119
"//executorch/extension/threadpool:threadpool",
120+
"//executorch/kernels/portable/cpu:scalar_utils",
120121
],
121122
deps = [
122123
"//executorch/runtime/kernel:kernel_includes",

0 commit comments

Comments
 (0)