diff --git a/kernels/portable/cpu/op_hardtanh.cpp b/kernels/portable/cpu/op_hardtanh.cpp index 8ec73b07856..65411d5f6b0 100644 --- a/kernels/portable/cpu/op_hardtanh.cpp +++ b/kernels/portable/cpu/op_hardtanh.cpp @@ -45,8 +45,15 @@ Tensor& hardtanh_out( ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out); ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() { - const CTYPE min_casted = utils::scalar_to(min); - const CTYPE max_casted = utils::scalar_to(max); + auto opt_min_casted = + utils::internal::check_overflow_scalar_cast(min); + ET_KERNEL_CHECK(ctx, opt_min_casted.has_value(), InvalidArgument, ); + auto min_casted = opt_min_casted.value(); + + auto opt_max_casted = + utils::internal::check_overflow_scalar_cast(max); + ET_KERNEL_CHECK(ctx, opt_max_casted.has_value(), InvalidArgument, ); + auto max_casted = opt_max_casted.value(); apply_unary_map_fn( [min_casted, max_casted](const CTYPE val_in) { diff --git a/kernels/test/op_hardtanh_test.cpp b/kernels/test/op_hardtanh_test.cpp index 72d09063d3e..38b0eeea40f 100644 --- a/kernels/test/op_hardtanh_test.cpp +++ b/kernels/test/op_hardtanh_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -51,6 +52,21 @@ class OpHardTanhTest : public OperatorTest { EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2})); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + Tensor in = tf.ones({2, 2}); + Tensor out = tf.zeros({2, 2}); + + // Test overflow for min parameter (using valid max) + ET_EXPECT_KERNEL_FAILURE( + context_, op_hardtanh_out(in, bad_value, 1.0, out)); + + // Test overflow for max parameter (using valid min) + ET_EXPECT_KERNEL_FAILURE( + context_, op_hardtanh_out(in, -1.0, bad_value, out)); + } }; TEST_F(OpHardTanhTest, SanityCheck) { @@ -58,3 +74,5 @@ TEST_F(OpHardTanhTest, SanityCheck) { ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } + +GENERATE_SCALAR_OVERFLOW_TESTS(OpHardTanhTest)