Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion kernels/portable/cpu/op_leaky_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ Tensor& leaky_relu_out(
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);

ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
const CTYPE negative_slope_casted = utils::scalar_to<CTYPE>(negative_slope);
auto opt_negative_slope_casted =
utils::internal::check_overflow_scalar_cast<CTYPE>(negative_slope);
ET_KERNEL_CHECK(
ctx, opt_negative_slope_casted.has_value(), InvalidArgument, );
auto negative_slope_casted = opt_negative_slope_casted.value();

apply_unary_map_fn(
[negative_slope_casted](const CTYPE val_in) {
Expand Down
19 changes: 19 additions & 0 deletions kernels/test/op_leaky_relu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,29 @@ class OpLeakyReluTest : public OperatorTest {
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
}

template <ScalarType DTYPE>
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
TensorFactory<DTYPE> tf;
Tensor in = tf.ones({2, 2});
Tensor out = tf.zeros({2, 2});

ET_EXPECT_KERNEL_FAILURE(context_, op_leaky_relu_out(in, bad_value, out));
}
};

TEST_F(OpLeakyReluTest, SanityCheck) {
#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype<ScalarType::dtype>();
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpLeakyReluTest, FloatTensorTooSmallScalarDies) {
/* Cannot be represented by a float. */
expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38);
}

TEST_F(OpLeakyReluTest, FloatTensorTooLargeScalarDies) {
/* Cannot be represented by a float. */
expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38);
}
Loading