Skip to content

Commit 3426504

Browse files
Update on "[ET][Portable] Check scalar overflow: op_scatter"
Differential Revision: [D77401093](https://our.internmc.facebook.com/intern/diff/D77401093/) [ghstack-poisoned]
2 parents 0dd85d9 + 87a1345 commit 3426504

File tree

2 files changed

+1
-24
lines changed

2 files changed

+1
-24
lines changed

kernels/portable/cpu/op_leaky_relu.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,7 @@ Tensor& leaky_relu_out(
4444
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4545

4646
ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
47-
auto opt_negative_slope_casted =
48-
utils::internal::check_overflow_scalar_cast<CTYPE>(negative_slope);
49-
ET_KERNEL_CHECK(
50-
ctx, opt_negative_slope_casted.has_value(), InvalidArgument, );
51-
auto negative_slope_casted = opt_negative_slope_casted.value();
47+
const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
5248

5349
apply_unary_map_fn(
5450
[negative_slope_casted](const CTYPE val_in) {

kernels/test/op_leaky_relu_test.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,10 @@ class OpLeakyReluTest : public OperatorTest {
4040
EXPECT_TENSOR_EQ(out, ret);
4141
EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
4242
}
43-
44-
template <ScalarType DTYPE>
45-
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
46-
TensorFactory<DTYPE> tf;
47-
Tensor in = tf.ones({2, 2});
48-
Tensor out = tf.zeros({2, 2});
49-
50-
ET_EXPECT_KERNEL_FAILURE(context_, op_leaky_relu_out(in, bad_value, out));
51-
}
5243
};
5344

5445
TEST_F(OpLeakyReluTest, SanityCheck) {
5546
#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype<ScalarType::dtype>();
5647
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
5748
#undef TEST_ENTRY
5849
}
59-
60-
TEST_F(OpLeakyReluTest, FloatTensorTooSmallScalarDies) {
61-
/* Cannot be represented by a float. */
62-
expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38);
63-
}
64-
65-
TEST_F(OpLeakyReluTest, FloatTensorTooLargeScalarDies) {
66-
/* Cannot be represented by a float. */
67-
expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38);
68-
}

0 commit comments

Comments
 (0)