diff --git a/kernels/portable/cpu/op_fill.cpp b/kernels/portable/cpu/op_fill.cpp index b985e2f4f07..8d98aa8bb7f 100644 --- a/kernels/portable/cpu/op_fill.cpp +++ b/kernels/portable/cpu/op_fill.cpp @@ -42,7 +42,9 @@ Tensor& fill_scalar_out( "Failed to resize output tensor."); ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] { - const CTYPE_A b_casted = utils::scalar_to(b); + auto opt_b_casted = utils::internal::check_overflow_scalar_cast(b); + ET_KERNEL_CHECK(ctx, opt_b_casted.has_value(), InvalidArgument, ); + auto b_casted = opt_b_casted.value(); apply_unary_map_fn( [b_casted](const CTYPE_A val_a) { return b_casted; }, diff --git a/kernels/test/op_fill_test.cpp b/kernels/test/op_fill_test.cpp index ac45ae307a5..0de49374477 100644 --- a/kernels/test/op_fill_test.cpp +++ b/kernels/test/op_fill_test.cpp @@ -74,6 +74,15 @@ class OpFillTest : public OperatorTest { // Check `out` matches expected output. EXPECT_TENSOR_EQ(out, exp_out); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + Tensor a = tf.ones({2, 2}); + Tensor out = tf.zeros({2, 2}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(a, bad_value, out)); + } }; // A macro for defining tests for both scalar and tensor variants of @@ -157,3 +166,28 @@ TEST_F(OpFillTest, MismatchedOutputDtypeDies) { // Assert `out` can't be filled due to incompatible dtype. ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(self, 0.0, out)); } + +TEST_F(OpFillTest, ByteTensorTooLargeScalarDies) { + // Cannot be represented by a uint8_t. + expect_bad_scalar_value_dies(256); +} + +TEST_F(OpFillTest, CharTensorTooSmallScalarDies) { + // Cannot be represented by a int8_t. + expect_bad_scalar_value_dies(-129); +} + +TEST_F(OpFillTest, ShortTensorTooLargeScalarDies) { + // Cannot be represented by a int16_t. + expect_bad_scalar_value_dies(32768); +} + +TEST_F(OpFillTest, FloatTensorTooSmallScalarDies) { + // Cannot be represented by a float. + expect_bad_scalar_value_dies(-3.41e+38); +} + +TEST_F(OpFillTest, FloatTensorTooLargeScalarDies) { + // Cannot be represented by a float. + expect_bad_scalar_value_dies(3.41e+38); +}