Skip to content

Commit 11443fa

Browse files
[EE/BE][ET][Portable] Move scalar_to utils to scalar_utils.h
Differential Revision: [D75962656](https://our.internmc.facebook.com/intern/diff/D75962656/) ghstack-source-id: 292676104 Pull Request resolved: #12009
1 parent 85cf6ce commit 11443fa

File tree

3 files changed

+29
-30
lines changed

3 files changed

+29
-30
lines changed

kernels/portable/cpu/scalar_utils.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#pragma once
1010

11-
#include <algorithm>
1211
#include <cmath>
1312
#include <limits>
1413

@@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
261260
return false;
262261
}
263262

263+
/*
264+
* Convert Scalar to C++ type
265+
*/
266+
267+
template <typename T>
268+
T scalar_to(const Scalar& s) {
269+
if (s.isBoolean()) {
270+
return static_cast<T>(s.to<bool>());
271+
} else if (s.isFloatingPoint()) {
272+
return static_cast<T>(s.to<double>());
273+
} else {
274+
return static_cast<T>(s.to<int64_t>());
275+
}
276+
}
277+
278+
template <>
279+
inline double scalar_to<double>(const Scalar& s) {
280+
return s.isFloatingPoint() ? s.to<double>()
281+
: static_cast<double>(s.to<int64_t>());
282+
}
283+
284+
template <>
285+
inline int64_t scalar_to<int64_t>(const Scalar& s) {
286+
return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
287+
: s.to<int64_t>();
288+
}
289+
264290
} // namespace utils
265291
} // namespace native
266292
} // namespace executor

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <c10/util/irange.h>
12+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1213
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
1314
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1415
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
@@ -27,34 +28,6 @@ namespace torch {
2728
namespace executor {
2829
namespace native {
2930
namespace utils {
30-
31-
/*
32-
* Convert Scalar to C++ type
33-
*/
34-
35-
template <typename T>
36-
T scalar_to(const Scalar& s) {
37-
if (s.isBoolean()) {
38-
return static_cast<T>(s.to<bool>());
39-
} else if (s.isFloatingPoint()) {
40-
return static_cast<T>(s.to<double>());
41-
} else {
42-
return static_cast<T>(s.to<int64_t>());
43-
}
44-
}
45-
46-
template <>
47-
inline double scalar_to<double>(const Scalar& s) {
48-
return s.isFloatingPoint() ? s.to<double>()
49-
: static_cast<double>(s.to<int64_t>());
50-
}
51-
52-
template <>
53-
inline int64_t scalar_to<int64_t>(const Scalar& s) {
54-
return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
55-
: s.to<int64_t>();
56-
}
57-
5831
namespace internal {
5932
/**
6033
* Causes these utility functions to make sure to respect Tensor

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def define_common_targets():
116116
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
117117
"//executorch/runtime/kernel:kernel_runtime_context",
118118
"//executorch/extension/threadpool:threadpool",
119+
"//executorch/kernels/portable/cpu:scalar_utils",
119120
],
120121
deps = [
121-
"//executorch/kernels/portable/cpu:scalar_utils",
122122
"//executorch/runtime/kernel:kernel_includes",
123123
],
124124
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],

0 commit comments

Comments
 (0)