diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h index 02700804819..162f96ba85d 100644 --- a/kernels/portable/cpu/scalar_utils.h +++ b/kernels/portable/cpu/scalar_utils.h @@ -8,7 +8,6 @@ #pragma once -#include #include #include @@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) { return false; } +/* + * Convert Scalar to C++ type + */ + +template +T scalar_to(const Scalar& s) { + if (s.isBoolean()) { + return static_cast(s.to()); + } else if (s.isFloatingPoint()) { + return static_cast(s.to()); + } else { + return static_cast(s.to()); + } +} + +template <> +inline double scalar_to(const Scalar& s) { + return s.isFloatingPoint() ? s.to() + : static_cast(s.to()); +} + +template <> +inline int64_t scalar_to(const Scalar& s) { + return s.isFloatingPoint() ? static_cast(s.to()) + : s.to(); +} + } // namespace utils } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index a376a89747b..0a2d4bfc89a 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -28,34 +29,6 @@ namespace torch { namespace executor { namespace native { namespace utils { - -/* - * Convert Scalar to C++ type - */ - -template -T scalar_to(const Scalar& s) { - if (s.isBoolean()) { - return static_cast(s.to()); - } else if (s.isFloatingPoint()) { - return static_cast(s.to()); - } else { - return static_cast(s.to()); - } -} - -template <> -inline double scalar_to(const Scalar& s) { - return s.isFloatingPoint() ? s.to() - : static_cast(s.to()); -} - -template <> -inline int64_t scalar_to(const Scalar& s) { - return s.isFloatingPoint() ? static_cast(s.to()) - : s.to(); -} - namespace internal { /** * Causes these utility functions to make sure to respect Tensor diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 44b95aa55c4..41633be2183 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -117,6 +117,7 @@ def define_common_targets(): "//executorch/runtime/kernel:kernel_runtime_context", "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/extension/threadpool:threadpool", + "//executorch/kernels/portable/cpu:scalar_utils", ], deps = [ "//executorch/runtime/kernel:kernel_includes",