Skip to content

Commit 057ba65

Browse files
[ET][Portable] Check scalar overflow: op_hardtanh
Pull Request resolved: #12029 ghstack-source-id: 293086875 @exported-using-ghexport Differential Revision: [D77401090](https://our.internmc.facebook.com/intern/diff/D77401090/)
1 parent 447376d commit 057ba65

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

kernels/portable/cpu/op_hardtanh.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,15 @@ Tensor& hardtanh_out(
4545
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4646

4747
ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
48-
const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
49-
const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
48+
auto opt_min_casted =
49+
utils::internal::check_overflow_scalar_cast<CTYPE>(min);
50+
ET_KERNEL_CHECK(ctx, opt_min_casted.has_value(), InvalidArgument, );
51+
auto min_casted = opt_min_casted.value();
52+
53+
auto opt_max_casted =
54+
utils::internal::check_overflow_scalar_cast<CTYPE>(max);
55+
ET_KERNEL_CHECK(ctx, opt_max_casted.has_value(), InvalidArgument, );
56+
auto max_casted = opt_max_casted.value();
5057

5158
apply_unary_map_fn(
5259
[min_casted, max_casted](const CTYPE val_in) {

kernels/test/op_hardtanh_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -51,10 +52,27 @@ class OpHardTanhTest : public OperatorTest {
5152
EXPECT_TENSOR_EQ(out, ret);
5253
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2}));
5354
}
55+
56+
template <ScalarType DTYPE>
57+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
58+
TensorFactory<DTYPE> tf;
59+
Tensor in = tf.ones({2, 2});
60+
Tensor out = tf.zeros({2, 2});
61+
62+
// Test overflow for min parameter (using valid max)
63+
ET_EXPECT_KERNEL_FAILURE(
64+
context_, op_hardtanh_out(in, bad_value, 1.0, out));
65+
66+
// Test overflow for max parameter (using valid min)
67+
ET_EXPECT_KERNEL_FAILURE(
68+
context_, op_hardtanh_out(in, -1.0, bad_value, out));
69+
}
5470
};
5571

5672
TEST_F(OpHardTanhTest, SanityCheck) {
5773
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
5874
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
5975
#undef TEST_ENTRY
6076
}
77+
78+
GENERATE_SCALAR_OVERFLOW_TESTS(OpHardTanhTest)

0 commit comments

Comments
 (0)