diff --git a/kernels/portable/cpu/op_constant_pad_nd.cpp b/kernels/portable/cpu/op_constant_pad_nd.cpp index 6e643e1b945..be3962e018c 100644 --- a/kernels/portable/cpu/op_constant_pad_nd.cpp +++ b/kernels/portable/cpu/op_constant_pad_nd.cpp @@ -185,7 +185,10 @@ Tensor& constant_pad_nd_out( ScalarType in_type = in.scalar_type(); ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() { - const CTYPE value_casted = utils::scalar_to(value); + auto opt_value_casted = + utils::internal::check_overflow_scalar_cast(value); + ET_KERNEL_CHECK(ctx, opt_value_casted.has_value(), InvalidArgument, ); + auto value_casted = opt_value_casted.value(); constant_pad_nd_out_impl(in, pad, value_casted, out); }); diff --git a/kernels/test/op_constant_pad_nd_test.cpp b/kernels/test/op_constant_pad_nd_test.cpp index 88bee1d0ad9..7f44068d9cb 100644 --- a/kernels/test/op_constant_pad_nd_test.cpp +++ b/kernels/test/op_constant_pad_nd_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -347,6 +348,21 @@ class OpConstantPadNDOutTest : public OperatorTest { op_constant_pad_nd_out(self, padding_ref, 7, out); EXPECT_TENSOR_CLOSE(out, expected); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + const std::vector sizes = {2, 2}; + const std::vector sizes_out = {2, 4}; + const std::vector padding = {1, 1}; + + IntArrayRef padding_ref = IntArrayRef(padding.data(), padding.size()); + Tensor self = tf.ones(sizes); + Tensor out = tf.zeros(sizes_out); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_constant_pad_nd_out(self, padding_ref, bad_value, out)); + } }; TEST_F(OpConstantPadNDOutTest, TestPadDim2) { @@ -465,3 +481,5 @@ TEST_F(OpConstantPadNDOutTest, IncorrectOutputShapeFail) { ET_EXPECT_KERNEL_FAILURE( context_, op_constant_pad_nd_out(self, padding_ref, 0, out)); } + +GENERATE_SCALAR_OVERFLOW_TESTS(OpConstantPadNDOutTest)