File tree Expand file tree Collapse file tree 3 files changed +29
-29
lines changed Expand file tree Collapse file tree 3 files changed +29
-29
lines changed Original file line number Diff line number Diff line change 8
8
9
9
#pragma once
10
10
11
- #include < algorithm>
12
11
#include < cmath>
13
12
#include < limits>
14
13
@@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
261
260
return false ;
262
261
}
263
262
263
+ /*
264
+ * Convert Scalar to C++ type
265
+ */
266
+
267
+ template <typename T>
268
+ T scalar_to (const Scalar& s) {
269
+ if (s.isBoolean ()) {
270
+ return static_cast <T>(s.to <bool >());
271
+ } else if (s.isFloatingPoint ()) {
272
+ return static_cast <T>(s.to <double >());
273
+ } else {
274
+ return static_cast <T>(s.to <int64_t >());
275
+ }
276
+ }
277
+
278
+ template <>
279
+ inline double scalar_to<double >(const Scalar& s) {
280
+ return s.isFloatingPoint () ? s.to <double >()
281
+ : static_cast <double >(s.to <int64_t >());
282
+ }
283
+
284
+ template <>
285
+ inline int64_t scalar_to<int64_t >(const Scalar& s) {
286
+ return s.isFloatingPoint () ? static_cast <int64_t >(s.to <double >())
287
+ : s.to <int64_t >();
288
+ }
289
+
264
290
} // namespace utils
265
291
} // namespace native
266
292
} // namespace executor
Original file line number Diff line number Diff line change 9
9
#pragma once
10
10
11
11
#include < c10/util/irange.h>
12
+ #include < executorch/kernels/portable/cpu/scalar_utils.h>
12
13
#include < executorch/kernels/portable/cpu/selective_build.h>
13
14
#include < executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
14
15
#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
@@ -28,34 +29,6 @@ namespace torch {
28
29
namespace executor {
29
30
namespace native {
30
31
namespace utils {
31
-
32
- /*
33
- * Convert Scalar to C++ type
34
- */
35
-
36
- template <typename T>
37
- T scalar_to (const Scalar& s) {
38
- if (s.isBoolean ()) {
39
- return static_cast <T>(s.to <bool >());
40
- } else if (s.isFloatingPoint ()) {
41
- return static_cast <T>(s.to <double >());
42
- } else {
43
- return static_cast <T>(s.to <int64_t >());
44
- }
45
- }
46
-
47
- template <>
48
- inline double scalar_to<double >(const Scalar& s) {
49
- return s.isFloatingPoint () ? s.to <double >()
50
- : static_cast <double >(s.to <int64_t >());
51
- }
52
-
53
- template <>
54
- inline int64_t scalar_to<int64_t >(const Scalar& s) {
55
- return s.isFloatingPoint () ? static_cast <int64_t >(s.to <double >())
56
- : s.to <int64_t >();
57
- }
58
-
59
32
namespace internal {
60
33
/* *
61
34
* Causes these utility functions to make sure to respect Tensor
Original file line number Diff line number Diff line change @@ -117,6 +117,7 @@ def define_common_targets():
117
117
"//executorch/runtime/kernel:kernel_runtime_context" ,
118
118
"//executorch/kernels/portable/cpu:scalar_utils" ,
119
119
"//executorch/extension/threadpool:threadpool" ,
120
+ "//executorch/kernels/portable/cpu:scalar_utils" ,
120
121
],
121
122
deps = [
122
123
"//executorch/runtime/kernel:kernel_includes" ,
You can’t perform that action at this time.
0 commit comments