Skip to content

Commit b77597a

Browse files
Update on "[ET][Portable] Check scalar overflow: op_leaky_relu"
Differential Revision: [D77401091](https://our.internmc.facebook.com/intern/diff/D77401091/) [ghstack-poisoned]
2 parents 65879d0 + d6a35bc commit b77597a

File tree

4 files changed

+46
-3
lines changed

4 files changed

+46
-3
lines changed

kernels/portable/cpu/op_hardtanh.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,15 @@ Tensor& hardtanh_out(
4545
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4646

4747
ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
48-
const CTYPE min_casted = utils::scalar_to<CTYPE>(min);
49-
const CTYPE max_casted = utils::scalar_to<CTYPE>(max);
48+
auto opt_min_casted =
49+
utils::internal::check_overflow_scalar_cast<CTYPE>(min);
50+
ET_KERNEL_CHECK(ctx, opt_min_casted.has_value(), InvalidArgument, );
51+
auto min_casted = opt_min_casted.value();
52+
53+
auto opt_max_casted =
54+
utils::internal::check_overflow_scalar_cast<CTYPE>(max);
55+
ET_KERNEL_CHECK(ctx, opt_max_casted.has_value(), InvalidArgument, );
56+
auto max_casted = opt_max_casted.value();
5057

5158
apply_unary_map_fn(
5259
[min_casted, max_casted](const CTYPE val_in) {

kernels/portable/cpu/op_scatter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ Tensor& scatter_value_out(
154154
constexpr auto name = "scatter.value_out";
155155

156156
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
157-
const CTYPE val = utils::scalar_to<CTYPE>(value);
157+
auto opt_val = utils::internal::check_overflow_scalar_cast<CTYPE>(value);
158+
ET_KERNEL_CHECK(ctx, opt_val.has_value(), InvalidArgument, );
159+
auto val = opt_val.value();
158160
scatter_value_helper<CTYPE>(in, dim, index, val, out);
159161
});
160162

kernels/test/op_hardtanh_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -51,10 +52,27 @@ class OpHardTanhTest : public OperatorTest {
5152
EXPECT_TENSOR_EQ(out, ret);
5253
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2}));
5354
}
55+
56+
template <ScalarType DTYPE>
57+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
58+
TensorFactory<DTYPE> tf;
59+
Tensor in = tf.ones({2, 2});
60+
Tensor out = tf.zeros({2, 2});
61+
62+
// Test overflow for min parameter (using valid max)
63+
ET_EXPECT_KERNEL_FAILURE(
64+
context_, op_hardtanh_out(in, bad_value, 1.0, out));
65+
66+
// Test overflow for max parameter (using valid min)
67+
ET_EXPECT_KERNEL_FAILURE(
68+
context_, op_hardtanh_out(in, -1.0, bad_value, out));
69+
}
5470
};
5571

5672
TEST_F(OpHardTanhTest, SanityCheck) {
5773
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
5874
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
5975
#undef TEST_ENTRY
6076
}
77+
78+
GENERATE_SCALAR_OVERFLOW_TESTS(OpHardTanhTest)

kernels/test/op_scatter_test.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -364,6 +365,19 @@ class OpScatterValueOutTest : public OperatorTest {
364365
op_scatter_value_out(input, 2, index, value, out);
365366
EXPECT_TENSOR_EQ(out, expected);
366367
}
368+
369+
template <ScalarType DTYPE>
370+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
371+
TensorFactory<DTYPE> tf;
372+
TensorFactory<ScalarType::Long> tf_index;
373+
374+
Tensor self = tf.ones({2, 2});
375+
Tensor index = tf_index.zeros({2, 2});
376+
Tensor out = tf.zeros({2, 2});
377+
378+
ET_EXPECT_KERNEL_FAILURE(
379+
context_, op_scatter_value_out(self, 0, index, bad_value, out));
380+
}
367381
};
368382

369383
TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) {
@@ -652,3 +666,5 @@ TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) {
652666
ET_EXPECT_KERNEL_FAILURE(
653667
context_, op_scatter_src_out(self, 0, index, src, out));
654668
}
669+
670+
GENERATE_SCALAR_OVERFLOW_TESTS(OpScatterValueOutTest)

0 commit comments

Comments
 (0)