From 11443faf2bb7ffaad21add4e2eac6aa863d4a15f Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Jun 2025 11:03:30 -0700
Subject: [PATCH 1/4] [EE/BE][ET][Portable] Move scalar_to utils to
 scalar_utils.h

Differential Revision: [D75962656](https://our.internmc.facebook.com/intern/diff/D75962656/)

ghstack-source-id: 292676104
Pull Request resolved: https://github.com/pytorch/executorch/pull/12009
---
 kernels/portable/cpu/scalar_utils.h          | 28 ++++++++++++++++++-
 kernels/portable/cpu/util/elementwise_util.h | 29 +-------------------
 kernels/portable/cpu/util/targets.bzl        |  2 +-
 3 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h
index 02700804819..162f96ba85d 100644
--- a/kernels/portable/cpu/scalar_utils.h
+++ b/kernels/portable/cpu/scalar_utils.h
@@ -8,7 +8,6 @@
 
 #pragma once
 
-#include 
 #include 
 #include 
 
@@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
   return false;
 }
 
+/*
+ * Convert Scalar to C++ type
+ */
+
+template <typename T>
+T scalar_to(const Scalar& s) {
+  if (s.isBoolean()) {
+    return static_cast<T>(s.to<bool>());
+  } else if (s.isFloatingPoint()) {
+    return static_cast<T>(s.to<double>());
+  } else {
+    return static_cast<T>(s.to<int64_t>());
+  }
+}
+
+template <>
+inline double scalar_to<double>(const Scalar& s) {
+  return s.isFloatingPoint() ? s.to<double>()
+                             : static_cast<double>(s.to<int64_t>());
+}
+
+template <>
+inline int64_t scalar_to<int64_t>(const Scalar& s) {
+  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
+                             : s.to<int64_t>();
+}
+
 } // namespace utils
 } // namespace native
 } // namespace executor
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 948da50fdd4..6adf81f70e3 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include 
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include 
 #include 
 #include 
@@ -27,34 +28,6 @@ namespace torch {
 namespace executor {
 namespace native {
 namespace utils {
-
-/*
- * Convert Scalar to C++ type
- */
-
-template <typename T>
-T scalar_to(const Scalar& s) {
-  if (s.isBoolean()) {
-    return static_cast<T>(s.to<bool>());
-  } else if (s.isFloatingPoint()) {
-    return static_cast<T>(s.to<double>());
-  } else {
-    return static_cast<T>(s.to<int64_t>());
-  }
-}
-
-template <>
-inline double scalar_to<double>(const Scalar& s) {
-  return s.isFloatingPoint() ? s.to<double>()
-                             : static_cast<double>(s.to<int64_t>());
-}
-
-template <>
-inline int64_t scalar_to<int64_t>(const Scalar& s) {
-  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
-                             : s.to<int64_t>();
-}
-
 namespace internal {
 /**
  * Causes these utility functions to make sure to respect Tensor
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 1523fcfe706..ef3a878fd70 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -116,9 +116,9 @@ def define_common_targets():
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/extension/threadpool:threadpool",
+            "//executorch/kernels/portable/cpu:scalar_utils",
         ],
         deps = [
-            "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],
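Note on the moved helpers: scalar_to<T> reads whichever payload the Scalar actually holds (bool, double, or int64_t) and casts to T exactly once at the end; the double and int64_t specializations collapse that dispatch to a two-way branch for the two widest payload types. Below is a minimal, self-contained sketch of the dispatch; MiniScalar is a hypothetical stand-in for the ExecuTorch Scalar, kept just rich enough to run, and is not the real type.

// Illustrative only: MiniScalar mimics the three-way payload tag that the
// real Scalar exposes via isBoolean()/isFloatingPoint().
#include <cstdint>
#include <cstdio>

struct MiniScalar {
  enum class Tag { Bool, Double, Int };
  Tag tag;
  bool b;
  double d;
  int64_t i;
  bool isBoolean() const { return tag == Tag::Bool; }
  bool isFloatingPoint() const { return tag == Tag::Double; }
};

template <typename T>
T scalar_to(const MiniScalar& s) {
  if (s.isBoolean()) {
    return static_cast<T>(s.b);
  } else if (s.isFloatingPoint()) {
    return static_cast<T>(s.d);  // widest float payload, one final cast
  } else {
    return static_cast<T>(s.i);  // widest integral payload, one final cast
  }
}

int main() {
  MiniScalar pi{};
  pi.tag = MiniScalar::Tag::Double;
  pi.d = 3.75;
  // Requesting an integer from a floating-point Scalar truncates exactly
  // once, at the final cast: 3.75 -> 3.
  std::printf("%lld\n", static_cast<long long>(scalar_to<int64_t>(pi)));
  // Requesting a double from the same Scalar is exact.
  std::printf("%g\n", scalar_to<double>(pi));
  return 0;
}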
From 4a3561e634e8f4b61cf3b4e1c1c891001ce04471 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Jun 2025 11:03:32 -0700
Subject: [PATCH 2/4] [EE/BE][ET][Portable] Eliminate usage of
 ET_SWITCH_SCALAR_OBJ_TYPES in portable kernels

Differential Revision: [D75981642](https://our.internmc.facebook.com/intern/diff/D75981642/)

ghstack-source-id: 292779980
Pull Request resolved: https://github.com/pytorch/executorch/pull/12010
---
 kernels/portable/cpu/op_clamp.cpp           | 43 ++++++++++-----------
 kernels/portable/cpu/op_constant_pad_nd.cpp | 11 +-----
 kernels/portable/cpu/op_fill.cpp            |  8 +---
 kernels/portable/cpu/op_full.cpp            | 19 +++------
 kernels/portable/cpu/op_hardtanh.cpp        | 17 +-------
 kernels/portable/cpu/op_leaky_relu.cpp      |  9 +----
 kernels/portable/cpu/op_scalar_tensor.cpp   | 20 +++++-----
 kernels/portable/cpu/op_scatter.cpp         | 12 ++----
 kernels/portable/cpu/op_var.cpp             |  7 +---
 9 files changed, 47 insertions(+), 99 deletions(-)

diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index af082a18e78..c2b9c73f2ea 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor;
 
 namespace {
 
-template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
+template <typename CTYPE_CAST, typename CTYPE_OUT>
 /** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
-bool is_out_of_bounds(CTYPE_VAL val) {
-  const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
+bool is_out_of_bounds(CTYPE_CAST val_cast) {
   return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
       val_cast > std::numeric_limits<CTYPE_OUT>::max();
 }
@@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds(
     const char* val_name) {
   auto is_valid = true;
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() {
-    CTYPE_VAL val = 0;
-    utils::extract_scalar(val_scalar, &val);
-    if (isIntegralType(out_type, /*includeBool=*/false)) {
-      ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
-        if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
-          ET_LOG(Error, "%s value out of bounds", val_name);
-          is_valid = false;
-        }
-      });
-    } else if (isFloatingType(out_type)) {
-      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
-        if (std::isfinite(val) &&
-            is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
-          ET_LOG(Error, "%s value out of bounds", val_name);
-          is_valid = false;
-        }
-      });
-    }
-  });
+  if (isIntegralType(out_type, /*includeBool=*/false)) {
+    const long val_long = utils::scalar_to<long>(val_scalar);
+    ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+      if (is_out_of_bounds<long, CTYPE_OUT>(val_long)) {
+        ET_LOG(Error, "%s value out of bounds", val_name);
+        is_valid = false;
+      }
+    });
+  } else if (isFloatingType(out_type)) {
+    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+      const double val_double = utils::scalar_to<double>(val_scalar);
+      if (std::isfinite(val_double) &&
+          is_out_of_bounds<double, CTYPE_OUT>(val_double)) {
+        ET_LOG(Error, "%s value out of bounds", val_name);
+        is_valid = false;
+      }
+    });
+  }
 
   return is_valid;
 }
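The rewritten check_bounds hoists the scalar extraction out of the nested dtype switch: the Scalar is widened once to long or double, and is_out_of_bounds compares in that carrier type before any narrowing can happen. A self-contained sketch of the comparison follows, under the assumption that std::numeric_limits fully describes the output dtype; note how the float path also explains the std::isfinite filter in the kernel, since an infinite clamp bound compares out of range but is legal.

// Compare in the wide carrier type *before* narrowing, so the check itself
// cannot overflow.
#include <cstdint>
#include <cstdio>
#include <limits>

template <typename Cast, typename Out>
bool is_out_of_bounds(Cast val_cast) {
  return val_cast < std::numeric_limits<Out>::lowest() ||
      val_cast > std::numeric_limits<Out>::max();
}

int main() {
  // 300 fits in long but not in int8_t; the check runs entirely in long.
  std::printf("%d\n", is_out_of_bounds<long, int8_t>(300L));  // 1
  std::printf("%d\n", is_out_of_bounds<long, int8_t>(-5L));   // 0
  // 1e40 exceeds float's range; the comparison happens in double.
  std::printf("%d\n", is_out_of_bounds<double, float>(1e40)); // 1
  return 0;
}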
diff --git a/kernels/portable/cpu/op_constant_pad_nd.cpp b/kernels/portable/cpu/op_constant_pad_nd.cpp
index e1e37e5b4d2..6e643e1b945 100644
--- a/kernels/portable/cpu/op_constant_pad_nd.cpp
+++ b/kernels/portable/cpu/op_constant_pad_nd.cpp
@@ -183,17 +183,10 @@ Tensor& constant_pad_nd_out(
       "Failed to resize output tensor.");
 
   ScalarType in_type = in.scalar_type();
-  ScalarType value_type = utils::get_scalar_dtype(value);
 
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
-    CTYPE value_v;
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
-          CTYPE_VALUE val = 0;
-          utils::extract_scalar(value, &val);
-          value_v = static_cast<CTYPE>(val);
-        });
-    constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
+    const CTYPE value_casted = utils::scalar_to<CTYPE>(value);
+    constant_pad_nd_out_impl<CTYPE>(in, pad, value_casted, out);
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_fill.cpp b/kernels/portable/cpu/op_fill.cpp
index 3ed8557c29e..b985e2f4f07 100644
--- a/kernels/portable/cpu/op_fill.cpp
+++ b/kernels/portable/cpu/op_fill.cpp
@@ -26,7 +26,6 @@ Tensor& fill_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, a_type == out_type, InvalidArgument, out);
@@ -43,12 +42,7 @@ Tensor& fill_scalar_out(
       "Failed to resize output tensor.");
 
   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
-    CTYPE_A b_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] {
-      CTYPE_B b_val = 0;
-      utils::extract_scalar(b, &b_val);
-      b_casted = static_cast<CTYPE_A>(b_val);
-    });
+    const CTYPE_A b_casted = utils::scalar_to<CTYPE_A>(b);
 
     apply_unary_map_fn(
         [b_casted](const CTYPE_A val_a) { return b_casted; },
diff --git a/kernels/portable/cpu/op_full.cpp b/kernels/portable/cpu/op_full.cpp
index 69b4c8fd150..83ffcad45a6 100644
--- a/kernels/portable/cpu/op_full.cpp
+++ b/kernels/portable/cpu/op_full.cpp
@@ -24,7 +24,6 @@ Tensor& full_out(
     Tensor& out) {
   (void)ctx;
 
-  ScalarType val_type = utils::get_scalar_dtype(fill_value);
   ScalarType out_type = out.scalar_type();
 
   // Resize for dynamic shape
@@ -37,18 +36,12 @@ Tensor& full_out(
 
   constexpr auto name = "full.out";
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
-    CTYPE_VAL val;
-    ET_KERNEL_CHECK(
-        ctx, utils::extract_scalar(fill_value, &val), InvalidArgument, );
-
-    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
-      CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
-      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
-      for (const auto i : c10::irange(out.numel())) {
-        data_out[i] = val_casted;
-      }
-    });
+  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+    CTYPE_OUT val_casted = utils::scalar_to<CTYPE_OUT>(fill_value);
+    auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = val_casted;
+    }
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_hardtanh.cpp b/kernels/portable/cpu/op_hardtanh.cpp
index 97355b0f17c..8ec73b07856 100644
--- a/kernels/portable/cpu/op_hardtanh.cpp
+++ b/kernels/portable/cpu/op_hardtanh.cpp
@@ -40,26 +40,13 @@ Tensor& hardtanh_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ScalarType in_type = in.scalar_type();
-  ScalarType min_type = utils::get_scalar_dtype(min);
-  ScalarType max_type = utils::get_scalar_dtype(max);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
   ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
-    CTYPE min_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
-      CTYPE_MIN min_val = 0;
-      utils::extract_scalar(min, &min_val);
-      min_casted = static_cast<CTYPE>(min_val);
-    });
-
-    CTYPE max_casted;
-    ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() {
-      CTYPE_MAX max_val = 0;
-      utils::extract_scalar(max, &max_val);
-      max_casted = static_cast<CTYPE>(max_val);
-    });
+    const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
+    const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
 
     apply_unary_map_fn(
         [min_casted, max_casted](const CTYPE val_in) {
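The op_constant_pad_nd, op_fill, op_full, and op_hardtanh changes all follow one shape: a single dtype switch on the tensor, scalar operands converted up front with scalar_to, then a flat elementwise loop. A runnable sketch of the hardtanh case; apply_unary_map_fn is modeled here by a plain loop, and literal bounds stand in for the utils::scalar_to<CTYPE>(min)/(max) calls.

// Sketch of the rewritten hardtanh inner loop.
#include <algorithm>
#include <cstdio>

template <typename CTYPE, typename Fn>
void apply_unary_map(Fn fn, const CTYPE* in, CTYPE* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = fn(in[i]);
  }
}

int main() {
  const float min_casted = -1.0f;  // would come from utils::scalar_to<float>(min)
  const float max_casted = 1.0f;   // would come from utils::scalar_to<float>(max)
  const float in[4] = {-3.f, -0.5f, 0.25f, 7.f};
  float out[4];
  apply_unary_map<float>(
      [min_casted, max_casted](float v) {
        return std::min(std::max(v, min_casted), max_casted);
      },
      in, out, 4);
  for (float v : out) std::printf("%g ", v);  // -1 -0.5 0.25 1
  std::printf("\n");
  return 0;
}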
diff --git a/kernels/portable/cpu/op_leaky_relu.cpp b/kernels/portable/cpu/op_leaky_relu.cpp
index a04365814c7..11860c8d129 100644
--- a/kernels/portable/cpu/op_leaky_relu.cpp
+++ b/kernels/portable/cpu/op_leaky_relu.cpp
@@ -39,19 +39,12 @@ Tensor& leaky_relu_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ScalarType in_type = in.scalar_type();
-  ScalarType sc_type = utils::get_scalar_dtype(negative_slope);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
   ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
-    CTYPE negative_slope_casted = 0;
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {
-          CTYPE_MIN negative_slope_val = 0;
-          utils::extract_scalar(negative_slope, &negative_slope_val);
-          negative_slope_casted = static_cast<CTYPE>(negative_slope_val);
-        });
+    const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
 
     apply_unary_map_fn(
         [negative_slope_casted](const CTYPE val_in) {
diff --git a/kernels/portable/cpu/op_scalar_tensor.cpp b/kernels/portable/cpu/op_scalar_tensor.cpp
index 5be65a2e060..e111a9ac869 100644
--- a/kernels/portable/cpu/op_scalar_tensor.cpp
+++ b/kernels/portable/cpu/op_scalar_tensor.cpp
@@ -20,19 +20,21 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
 
-  ScalarType s_type = utils::get_scalar_dtype(s);
   ScalarType out_type = out.scalar_type();
 
   constexpr auto name = "scalar_tensor.out";
 
-  ET_SWITCH_REAL_TYPES_AND3(
-      Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
-          CTYPE_S val_s = 0;
-          utils::extract_scalar(s, &val_s);
-          out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
-        });
-      });
+  if (s.isFloatingPoint() &&
+      executorch::runtime::isIntegralType(out_type, false)) {
+    ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() {
+      out.mutable_data_ptr<CTYPE>()[0] =
+          static_cast<CTYPE>(utils::scalar_to<int64_t>(s));
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() {
+      out.mutable_data_ptr<CTYPE>()[0] = utils::scalar_to<CTYPE>(s);
+    });
+  }
 
   return out;
 }
diff --git a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp
index f8f4b21264e..7de0ec4d5f9 100644
--- a/kernels/portable/cpu/op_scatter.cpp
+++ b/kernels/portable/cpu/op_scatter.cpp
@@ -151,17 +151,11 @@ Tensor& scatter_value_out(
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
-  ScalarType val_type = utils::get_scalar_dtype(value);
-
   constexpr auto name = "scatter.value_out";
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
-    CTYPE_VAL val;
-    ET_KERNEL_CHECK(ctx, utils::extract_scalar(value, &val), InvalidArgument, );
-
-    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
-      scatter_value_helper<CTYPE>(in, dim, index, val, out);
-    });
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const CTYPE val = utils::scalar_to<CTYPE>(value);
+    scatter_value_helper<CTYPE>(in, dim, index, val, out);
   });
 
   return out;
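One plausible reading of the new branch in op_scalar_tensor: a floating-point Scalar destined for an integral tensor is first truncated to int64_t, so every integer output width shares a single well-defined float-to-integer truncation point instead of casting double directly to each narrow dtype. A tiny sketch of that conversion chain, with values chosen purely for illustration:

// Path taken for integral outputs: double -> int64_t -> CTYPE.
#include <cstdint>
#include <cstdio>

int main() {
  double s = -2.9;
  // Truncation toward zero happens once, in the double -> int64_t step.
  int32_t via_int64 = static_cast<int32_t>(static_cast<int64_t>(s));
  std::printf("%d\n", via_int64);  // -2
  return 0;
}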
diff --git a/kernels/portable/cpu/op_var.cpp b/kernels/portable/cpu/op_var.cpp
index f09f1d92bc9..a95b3a9f167 100644
--- a/kernels/portable/cpu/op_var.cpp
+++ b/kernels/portable/cpu/op_var.cpp
@@ -127,12 +127,7 @@ Tensor& var_correction_out(
 
   double correction_val = 1;
   if (correction.has_value()) {
-    ScalarType corr_type = utils::get_scalar_dtype(correction.value());
-    ET_SWITCH_SCALAR_OBJ_TYPES(corr_type, ctx, name, CTYPE_CORR, [&]() {
-      CTYPE_CORR corr_val = 0;
-      utils::extract_scalar(correction.value(), &corr_val);
-      correction_val = static_cast<double>(corr_val);
-    });
+    correction_val = utils::scalar_to<double>(correction.value());
   }
 
   const size_t num = get_reduced_dim_product(in, dim_list);
From 54c2057ff40264fcda12130f4b37e190f7d0ad2c Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Jun 2025 12:51:38 -0700
Subject: [PATCH 3/4] [ET][Portable] Add util to check for overflow when
 casting scalar

Pull Request resolved: https://github.com/pytorch/executorch/pull/12011

ghstack-source-id: 292926438
@exported-using-ghexport

Differential Revision: [D77382647](https://our.internmc.facebook.com/intern/diff/D77382647/)
---
 kernels/portable/cpu/scalar_utils.h               |  24 +++++
 .../core/portable_type/c10/c10/targets.bzl        |   1 +
 .../portable_type/c10/c10/util/overflows.h        | 100 ++++++++++++++++++
 3 files changed, 125 insertions(+)
 create mode 100644 runtime/core/portable_type/c10/c10/util/overflows.h

diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h
index 162f96ba85d..312a663c0e1 100644
--- a/kernels/portable/cpu/scalar_utils.h
+++ b/kernels/portable/cpu/scalar_utils.h
@@ -11,6 +11,7 @@
 
 #include 
 #include 
+#include <c10/util/overflows.h>
 #include 
 #include 
 #include 
@@ -287,6 +288,29 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
       : s.to<int64_t>();
 }
 
+namespace internal {
+
+template <typename To, typename From>
+std::optional<To> check_overflow_cast(From in) {
+  // Converting to bool can't overflow so we exclude this case from checking.
+  if (!std::is_same_v<To, bool> && c10::overflows<To>(in)) {
+    return std::nullopt;
+  }
+  return static_cast<To>(in);
+}
+
+template <typename To>
+std::optional<To> check_overflow_scalar_cast(const Scalar& in) {
+  if (in.isBoolean()) {
+    return check_overflow_cast<To, bool>(in.to<bool>());
+  } else if (in.isFloatingPoint()) {
+    return check_overflow_cast<To, double>(in.to<double>());
+  } else {
+    return check_overflow_cast<To, int64_t>(in.to<int64_t>());
+  }
+}
+
+} // namespace internal
 } // namespace utils
 } // namespace native
 } // namespace executor
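The new internal API reports an overflowing cast through an empty std::optional rather than silently wrapping. The sketch below mirrors check_overflow_cast's shape, but a plain range comparison stands in for c10::overflows, so unlike the real helper it ignores NaN and unsigned-wrap subtleties:

// Overflow-checked cast: failure is an empty optional, not a wrapped value.
#include <cstdint>
#include <cstdio>
#include <limits>
#include <optional>
#include <type_traits>

template <typename To, typename From>
std::optional<To> check_overflow_cast(From in) {
  // Converting to bool can't overflow, so it is excluded from the check.
  if (!std::is_same_v<To, bool> &&
      (in < std::numeric_limits<To>::lowest() ||
       in > std::numeric_limits<To>::max())) {
    return std::nullopt;
  }
  return static_cast<To>(in);
}

int main() {
  auto ok = check_overflow_cast<int8_t, int64_t>(42);
  auto bad = check_overflow_cast<int8_t, int64_t>(256);
  std::printf("%d %d\n", ok.has_value(), bad.has_value());  // 1 0
  if (ok) std::printf("%d\n", static_cast<int>(*ok));       // 42
  return 0;
}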
diff --git a/runtime/core/portable_type/c10/c10/targets.bzl b/runtime/core/portable_type/c10/c10/targets.bzl
index d64098a85fe..2311fe0216d 100644
--- a/runtime/core/portable_type/c10/c10/targets.bzl
+++ b/runtime/core/portable_type/c10/c10/targets.bzl
@@ -112,6 +112,7 @@ def define_common_targets():
             "util/complex_utils.h",
             "util/floating_point_utils.h",
             "util/irange.h",
+            "util/overflows.h",
         ],
         exported_preprocessor_flags = [
             "-DC10_USING_CUSTOM_GENERATED_MACROS",
diff --git a/runtime/core/portable_type/c10/c10/util/overflows.h b/runtime/core/portable_type/c10/c10/util/overflows.h
new file mode 100644
index 00000000000..183a2f62a32
--- /dev/null
+++ b/runtime/core/portable_type/c10/c10/util/overflows.h
@@ -0,0 +1,100 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/TypeSafeSignMath.h>
+#include <c10/util/complex.h>
+
+#include <cmath>
+#include <limits>
+#include <type_traits>
+
+namespace c10 {
+// In some versions of MSVC, there will be a compiler error when building.
+// C4146: unary minus operator applied to unsigned type, result still unsigned
+// C4804: unsafe use of type 'bool' in operation
+// It can be addressed by disabling the following warning.
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4146)
+#pragma warning(disable : 4804)
+#pragma warning(disable : 4018)
+#endif
+
+// The overflow checks may involve float to int conversion which may
+// trigger precision loss warning. Re-enable the warning once the code
+// is fixed. See T58053069.
+C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
+#endif
+
+// bool can be converted to any type.
+// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build:
+// `error: comparison of constant '255' with boolean expression is always false`
+// for `f > limit::max()` below
+template <typename To, typename From>
+std::enable_if_t<std::is_same_v<From, bool>, bool> overflows(
+    From /*f*/,
+    bool strict_unsigned [[maybe_unused]] = false) {
+  return false;
+}
+
+// skip isnan and isinf check for integral types
+template <typename To, typename From>
+std::enable_if_t<std::is_integral_v<From> && !std::is_same_v<From, bool>, bool>
+overflows(From f, bool strict_unsigned = false) {
+  using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
+  if constexpr (!limit::is_signed && std::numeric_limits<From>::is_signed) {
+    // allow for negative numbers to wrap using two's complement arithmetic.
+    // For example, with uint8, this allows for `a - b` to be treated as
+    // `a + 255 * b`.
+    if (!strict_unsigned) {
+      return greater_than_max<To>(f) ||
+          (c10::is_negative(f) &&
+           -static_cast<uint64_t>(f) > static_cast<uint64_t>(limit::max()));
+    }
+  }
+  return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
+}
+
+template <typename To, typename From>
+std::enable_if_t<std::is_floating_point_v<From>, bool> overflows(
+    From f,
+    bool strict_unsigned [[maybe_unused]] = false) {
+  using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
+  if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
+    return false;
+  }
+  if (!limit::has_quiet_NaN && (f != f)) {
+    return true;
+  }
+  return f < limit::lowest() || f > limit::max();
+}
+
+C10_CLANG_DIAGNOSTIC_POP()
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+template <typename To, typename From>
+std::enable_if_t<is_complex<From>::value, bool> overflows(
+    From f,
+    bool strict_unsigned = false) {
+  // casts from complex to real are considered to overflow if the
+  // imaginary component is non-zero
+  if (!is_complex<To>::value && f.imag() != 0) {
+    return true;
+  }
+  // Check for overflow componentwise
+  // (Technically, the imag overflow check is guaranteed to be false
+  // when !is_complex<To>, but any optimizer worth its salt will be
+  // able to figure it out.)
+  return overflows<
+             typename scalar_value_type<To>::type,
+             typename From::value_type>(f.real(), strict_unsigned) ||
+      overflows<
+             typename scalar_value_type<To>::type,
+             typename From::value_type>(f.imag(), strict_unsigned);
+}
+} // namespace c10
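Worth calling out in the integral overload above: by default, a negative source value may "wrap" into an unsigned target as long as its magnitude fits, mirroring two's-complement arithmetic, and strict_unsigned turns that allowance off. A self-contained sketch of just that rule, simplified from the header: the scalar_value_type indirection and the greater_than_max/less_than_lowest helpers are replaced by direct numeric_limits comparisons.

// Integral overflow check with the unsigned-wrap allowance.
#include <cstdint>
#include <cstdio>
#include <limits>

template <typename To, typename From>
bool overflows_integral(From f, bool strict_unsigned = false) {
  using limit = std::numeric_limits<To>;
  if constexpr (!limit::is_signed && std::numeric_limits<From>::is_signed) {
    if (!strict_unsigned) {
      // Negative f is allowed if |f| still fits, i.e. it wraps cleanly.
      return f > static_cast<From>(limit::max()) ||
          (f < 0 &&
           -static_cast<uint64_t>(f) > static_cast<uint64_t>(limit::max()));
    }
  }
  return f < static_cast<From>(limit::lowest()) ||
      f > static_cast<From>(limit::max());
}

int main() {
  // -1 wraps to 255 in uint8_t: allowed by default, rejected when strict.
  std::printf("%d\n", overflows_integral<uint8_t>(int64_t{-1}));        // 0
  std::printf("%d\n", overflows_integral<uint8_t>(int64_t{-1}, true));  // 1
  std::printf("%d\n", overflows_integral<uint8_t>(int64_t{300}));       // 1
  return 0;
}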
From 5d4060b6aa4567f5fc8e9bab64f1ca3e9c4b6f95 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Jun 2025 12:51:39 -0700
Subject: [PATCH 4/4] [ET][Portable] Check scalar overflow: op_fill

Pull Request resolved: https://github.com/pytorch/executorch/pull/12012

ghstack-source-id: 292926439
@exported-using-ghexport

Differential Revision: [D77382645](https://our.internmc.facebook.com/intern/diff/D77382645/)
---
 kernels/portable/cpu/op_fill.cpp |  4 +++-
 kernels/test/op_fill_test.cpp    | 34 ++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/kernels/portable/cpu/op_fill.cpp b/kernels/portable/cpu/op_fill.cpp
index b985e2f4f07..8d98aa8bb7f 100644
--- a/kernels/portable/cpu/op_fill.cpp
+++ b/kernels/portable/cpu/op_fill.cpp
@@ -42,7 +42,9 @@ Tensor& fill_scalar_out(
       "Failed to resize output tensor.");
 
   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
-    const CTYPE_A b_casted = utils::scalar_to<CTYPE_A>(b);
+    auto opt_b_casted = utils::internal::check_overflow_scalar_cast<CTYPE_A>(b);
+    ET_KERNEL_CHECK(ctx, opt_b_casted.has_value(), InvalidArgument, );
+    auto b_casted = opt_b_casted.value();
 
     apply_unary_map_fn(
         [b_casted](const CTYPE_A val_a) { return b_casted; },
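The kernel-side pattern is: attempt the checked cast, fail the kernel with InvalidArgument when it comes back empty, and only then fill. A sketch with plain control flow standing in for ET_KERNEL_CHECK, and a fixed int8 range standing in for the dtype switch:

// Checked-cast-then-fill, modeled without the ExecuTorch runtime.
#include <cstdint>
#include <cstdio>
#include <optional>

std::optional<int8_t> checked_cast(int64_t v) {
  if (v < -128 || v > 127) return std::nullopt;
  return static_cast<int8_t>(v);
}

bool fill_like(int64_t fill_value, int8_t* data, int n) {
  auto opt = checked_cast(fill_value);
  if (!opt.has_value()) {
    std::fprintf(stderr, "InvalidArgument: fill value overflows dtype\n");
    return false;  // ET_KERNEL_CHECK would log and bail out here
  }
  for (int i = 0; i < n; ++i) data[i] = opt.value();
  return true;
}

int main() {
  int8_t buf[4];
  std::printf("%d\n", fill_like(7, buf, 4));    // 1
  std::printf("%d\n", fill_like(256, buf, 4));  // 0
  return 0;
}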
diff --git a/kernels/test/op_fill_test.cpp b/kernels/test/op_fill_test.cpp
index ac45ae307a5..0de49374477 100644
--- a/kernels/test/op_fill_test.cpp
+++ b/kernels/test/op_fill_test.cpp
@@ -74,6 +74,15 @@ class OpFillTest : public OperatorTest {
     // Check `out` matches expected output.
     EXPECT_TENSOR_EQ(out, exp_out);
   }
+
+  template <ScalarType DTYPE>
+  void expect_bad_scalar_value_dies(const Scalar& bad_value) {
+    TensorFactory<DTYPE> tf;
+    Tensor a = tf.ones({2, 2});
+    Tensor out = tf.zeros({2, 2});
+
+    ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(a, bad_value, out));
+  }
 };
 
 // A macro for defining tests for both scalar and tensor variants of
@@ -157,3 +166,28 @@ TEST_F(OpFillTest, MismatchedOutputDtypeDies) {
   // Assert `out` can't be filled due to incompatible dtype.
   ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(self, 0.0, out));
 }
+
+TEST_F(OpFillTest, ByteTensorTooLargeScalarDies) {
+  // Cannot be represented by a uint8_t.
+  expect_bad_scalar_value_dies<ScalarType::Byte>(256);
+}
+
+TEST_F(OpFillTest, CharTensorTooSmallScalarDies) {
+  // Cannot be represented by an int8_t.
+  expect_bad_scalar_value_dies<ScalarType::Char>(-129);
+}
+
+TEST_F(OpFillTest, ShortTensorTooLargeScalarDies) {
+  // Cannot be represented by an int16_t.
+  expect_bad_scalar_value_dies<ScalarType::Short>(32768);
+}
+
+TEST_F(OpFillTest, FloatTensorTooSmallScalarDies) {
+  // Cannot be represented by a float.
+  expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38);
+}
+
+TEST_F(OpFillTest, FloatTensorTooLargeScalarDies) {
+  // Cannot be represented by a float.
+  expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38);
+}
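For reference, each test value sits one step outside the representable range of its dtype; a quick standalone check of those boundaries:

// Boundary values exercised by the new tests.
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  std::printf("uint8 max: %d\n", std::numeric_limits<uint8_t>::max());   // 255 < 256
  std::printf("int8 min:  %d\n", std::numeric_limits<int8_t>::lowest()); // -128 > -129
  std::printf("int16 max: %d\n", std::numeric_limits<int16_t>::max());   // 32767 < 32768
  std::printf("float max: %g\n", std::numeric_limits<float>::max());     // ~3.4028e+38 < 3.41e+38
  return 0;
}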