diff --git a/kernels/portable/cpu/op_scalar_tensor.cpp b/kernels/portable/cpu/op_scalar_tensor.cpp index e111a9ac869..bff4ecc318c 100644 --- a/kernels/portable/cpu/op_scalar_tensor.cpp +++ b/kernels/portable/cpu/op_scalar_tensor.cpp @@ -24,17 +24,11 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) { constexpr auto name = "scalar_tensor.out"; - if (s.isFloatingPoint() && - executorch::runtime::isIntegralType(out_type, false)) { - ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() { - out.mutable_data_ptr()[0] = - static_cast(utils::scalar_to(s)); - }); - } else { - ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() { - out.mutable_data_ptr()[0] = utils::scalar_to(s); - }); - } + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() { + auto opt_val_casted = utils::internal::check_overflow_scalar_cast(s); + ET_KERNEL_CHECK(ctx, opt_val_casted.has_value(), InvalidArgument, ); + out.mutable_data_ptr()[0] = opt_val_casted.value(); + }); return out; } diff --git a/kernels/test/op_scalar_tensor_test.cpp b/kernels/test/op_scalar_tensor_test.cpp index db4816e8847..0be6f395eb0 100644 --- a/kernels/test/op_scalar_tensor_test.cpp +++ b/kernels/test/op_scalar_tensor_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -71,6 +72,14 @@ class OpScalarTensorOutTest : public OperatorTest { ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out)); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + Tensor out = tf.zeros({}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(bad_value, out)); + } }; #define GENERATE_TEST_0D(ctype, dtype) \ @@ -131,3 +140,5 @@ TEST_F(OpScalarTensorOutTest, HalfSupport) { op_scalar_tensor_out(INFINITY, out); EXPECT_TENSOR_CLOSE(out, tf.make({}, {INFINITY})); } + +GENERATE_SCALAR_OVERFLOW_TESTS(OpScalarTensorOutTest)