Skip to content

Commit 5fe53c3

Browse files
[ET][Portable] Check scalar overflow: op_hardtanh (#12085)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12029 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/127/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/127/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/126/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/127/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent c1f4d59 commit 5fe53c3

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

kernels/portable/cpu/op_hardtanh.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,15 @@ Tensor& hardtanh_out(
4545
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4646

4747
ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
48-
const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
49-
const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
48+
auto opt_min_casted =
49+
utils::internal::check_overflow_scalar_cast<CTYPE>(min);
50+
ET_KERNEL_CHECK(ctx, opt_min_casted.has_value(), InvalidArgument, );
51+
auto min_casted = opt_min_casted.value();
52+
53+
auto opt_max_casted =
54+
utils::internal::check_overflow_scalar_cast<CTYPE>(max);
55+
ET_KERNEL_CHECK(ctx, opt_max_casted.has_value(), InvalidArgument, );
56+
auto max_casted = opt_max_casted.value();
5057

5158
apply_unary_map_fn(
5259
[min_casted, max_casted](const CTYPE val_in) {

kernels/test/op_hardtanh_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -51,10 +52,27 @@ class OpHardTanhTest : public OperatorTest {
5152
EXPECT_TENSOR_EQ(out, ret);
5253
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2}));
5354
}
55+
56+
template <ScalarType DTYPE>
57+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
58+
TensorFactory<DTYPE> tf;
59+
Tensor in = tf.ones({2, 2});
60+
Tensor out = tf.zeros({2, 2});
61+
62+
// Test overflow for min parameter (using valid max)
63+
ET_EXPECT_KERNEL_FAILURE(
64+
context_, op_hardtanh_out(in, bad_value, 1.0, out));
65+
66+
// Test overflow for max parameter (using valid min)
67+
ET_EXPECT_KERNEL_FAILURE(
68+
context_, op_hardtanh_out(in, -1.0, bad_value, out));
69+
}
5470
};
5571

5672
TEST_F(OpHardTanhTest, SanityCheck) {
5773
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
5874
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
5975
#undef TEST_ENTRY
6076
}
77+
78+
GENERATE_SCALAR_OVERFLOW_TESTS(OpHardTanhTest)

0 commit comments

Comments
 (0)