diff --git a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp index 7de0ec4d5f9..965afbb4b66 100644 --- a/kernels/portable/cpu/op_scatter.cpp +++ b/kernels/portable/cpu/op_scatter.cpp @@ -154,7 +154,9 @@ Tensor& scatter_value_out( constexpr auto name = "scatter.value_out"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { - const CTYPE val = utils::scalar_to(value); + auto opt_val = utils::internal::check_overflow_scalar_cast(value); + ET_KERNEL_CHECK(ctx, opt_val.has_value(), InvalidArgument, ); + auto val = opt_val.value(); scatter_value_helper(in, dim, index, val, out); }); diff --git a/kernels/test/op_scatter_test.cpp b/kernels/test/op_scatter_test.cpp index 0e55aadaeda..dac9017d188 100644 --- a/kernels/test/op_scatter_test.cpp +++ b/kernels/test/op_scatter_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -364,6 +365,19 @@ class OpScatterValueOutTest : public OperatorTest { op_scatter_value_out(input, 2, index, value, out); EXPECT_TENSOR_EQ(out, expected); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + TensorFactory tf_index; + + Tensor self = tf.ones({2, 2}); + Tensor index = tf_index.zeros({2, 2}); + Tensor out = tf.zeros({2, 2}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_value_out(self, 0, index, bad_value, out)); + } }; TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) { @@ -652,3 +666,5 @@ TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) { ET_EXPECT_KERNEL_FAILURE( context_, op_scatter_src_out(self, 0, index, src, out)); } + +GENERATE_SCALAR_OVERFLOW_TESTS(OpScatterValueOutTest)