Skip to content

Commit 0bd1ef6

Browse files
[ET][Portable] Check scalar overflow: op_leaky_relu (#12086)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12027 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/125/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/125/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/127/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/125/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent 5fe53c3 commit 0bd1ef6

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

kernels/portable/cpu/op_leaky_relu.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ Tensor& leaky_relu_out(
4444
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4545

4646
ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
47-
const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
47+
auto opt_negative_slope_casted =
48+
utils::internal::check_overflow_scalar_cast<CTYPE>(negative_slope);
49+
ET_KERNEL_CHECK(
50+
ctx, opt_negative_slope_casted.has_value(), InvalidArgument, );
51+
auto negative_slope_casted = opt_negative_slope_casted.value();
4852

4953
apply_unary_map_fn(
5054
[negative_slope_casted](const CTYPE val_in) {

kernels/test/op_leaky_relu_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,29 @@ class OpLeakyReluTest : public OperatorTest {
4040
EXPECT_TENSOR_EQ(out, ret);
4141
EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
4242
}
43+
44+
template <ScalarType DTYPE>
45+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
46+
TensorFactory<DTYPE> tf;
47+
Tensor in = tf.ones({2, 2});
48+
Tensor out = tf.zeros({2, 2});
49+
50+
ET_EXPECT_KERNEL_FAILURE(context_, op_leaky_relu_out(in, bad_value, out));
51+
}
4352
};
4453

4554
TEST_F(OpLeakyReluTest, SanityCheck) {
4655
#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype<ScalarType::dtype>();
4756
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
4857
#undef TEST_ENTRY
4958
}
59+
60+
TEST_F(OpLeakyReluTest, FloatTensorTooSmallScalarDies) {
61+
/* Cannot be represented by a float. */
62+
expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38);
63+
}
64+
65+
TEST_F(OpLeakyReluTest, FloatTensorTooLargeScalarDies) {
66+
/* Cannot be represented by a float. */
67+
expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38);
68+
}

0 commit comments

Comments
 (0)