From 64edb65f2fcceee9a9d59180f37c3fff4a975b30 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 26 Jun 2025 11:03:27 -0700 Subject: [PATCH] [EE/BE][ET][Portable] Move scalar_to utils to scalar_utils.h Differential Revision: [D75962656](https://our.internmc.facebook.com/intern/diff/D75962656/) [ghstack-poisoned] --- kernels/portable/cpu/scalar_utils.h | 28 ++++++++++++++++++- kernels/portable/cpu/util/elementwise_util.h | 29 +------------------- kernels/portable/cpu/util/targets.bzl | 2 +- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h index 02700804819..162f96ba85d 100644 --- a/kernels/portable/cpu/scalar_utils.h +++ b/kernels/portable/cpu/scalar_utils.h @@ -8,7 +8,6 @@ #pragma once -#include #include #include @@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) { return false; } +/* + * Convert Scalar to C++ type + */ + +template +T scalar_to(const Scalar& s) { + if (s.isBoolean()) { + return static_cast(s.to()); + } else if (s.isFloatingPoint()) { + return static_cast(s.to()); + } else { + return static_cast(s.to()); + } +} + +template <> +inline double scalar_to(const Scalar& s) { + return s.isFloatingPoint() ? s.to() + : static_cast(s.to()); +} + +template <> +inline int64_t scalar_to(const Scalar& s) { + return s.isFloatingPoint() ? static_cast(s.to()) + : s.to(); +} + } // namespace utils } // namespace native } // namespace executor diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 948da50fdd4..6adf81f70e3 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -27,34 +28,6 @@ namespace torch { namespace executor { namespace native { namespace utils { - -/* - * Convert Scalar to C++ type - */ - -template -T scalar_to(const Scalar& s) { - if (s.isBoolean()) { - return static_cast(s.to()); - } else if (s.isFloatingPoint()) { - return static_cast(s.to()); - } else { - return static_cast(s.to()); - } -} - -template <> -inline double scalar_to(const Scalar& s) { - return s.isFloatingPoint() ? s.to() - : static_cast(s.to()); -} - -template <> -inline int64_t scalar_to(const Scalar& s) { - return s.isFloatingPoint() ? static_cast(s.to()) - : s.to(); -} - namespace internal { /** * Causes these utility functions to make sure to respect Tensor diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 1523fcfe706..ef3a878fd70 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -116,9 +116,9 @@ def define_common_targets(): "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", "//executorch/runtime/kernel:kernel_runtime_context", "//executorch/extension/threadpool:threadpool", + "//executorch/kernels/portable/cpu:scalar_utils", ], deps = [ - "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],