diff --git a/kernels/portable/cpu/op_leaky_relu.cpp b/kernels/portable/cpu/op_leaky_relu.cpp index 11860c8d129..fa62a75974e 100644 --- a/kernels/portable/cpu/op_leaky_relu.cpp +++ b/kernels/portable/cpu/op_leaky_relu.cpp @@ -44,7 +44,11 @@ Tensor& leaky_relu_out( ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out); ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() { - const CTYPE negative_slope_casted = utils::scalar_to(negative_slope); + auto opt_negative_slope_casted = + utils::internal::check_overflow_scalar_cast(negative_slope); + ET_KERNEL_CHECK( + ctx, opt_negative_slope_casted.has_value(), InvalidArgument, ); + auto negative_slope_casted = opt_negative_slope_casted.value(); apply_unary_map_fn( [negative_slope_casted](const CTYPE val_in) { diff --git a/kernels/test/op_leaky_relu_test.cpp b/kernels/test/op_leaky_relu_test.cpp index 847c00652be..6b2e3083e2e 100644 --- a/kernels/test/op_leaky_relu_test.cpp +++ b/kernels/test/op_leaky_relu_test.cpp @@ -40,6 +40,15 @@ class OpLeakyReluTest : public OperatorTest { EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, tf.ones({2, 2})); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + Tensor in = tf.ones({2, 2}); + Tensor out = tf.zeros({2, 2}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_leaky_relu_out(in, bad_value, out)); + } }; TEST_F(OpLeakyReluTest, SanityCheck) { @@ -47,3 +56,13 @@ TEST_F(OpLeakyReluTest, SanityCheck) { ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } + +TEST_F(OpLeakyReluTest, FloatTensorTooSmallScalarDies) { + /* Cannot be represented by a float. */ + expect_bad_scalar_value_dies(-3.41e+38); +} + +TEST_F(OpLeakyReluTest, FloatTensorTooLargeScalarDies) { + /* Cannot be represented by a float. */ + expect_bad_scalar_value_dies(3.41e+38); +}