|
7 | 7 | */
|
8 | 8 |
|
9 | 9 | #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
|
| 10 | +#include <executorch/kernels/test/ScalarOverflowTestMacros.h> |
10 | 11 | #include <executorch/kernels/test/TestUtil.h>
|
11 | 12 | #include <executorch/runtime/core/exec_aten/exec_aten.h>
|
12 | 13 | #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
|
@@ -51,10 +52,27 @@ class OpHardTanhTest : public OperatorTest {
|
51 | 52 | EXPECT_TENSOR_EQ(out, ret);
|
52 | 53 | EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2}));
|
53 | 54 | }
|
| 55 | + |
| 56 | + template <ScalarType DTYPE> |
| 57 | + void expect_bad_scalar_value_dies(const Scalar& bad_value) { |
| 58 | + TensorFactory<DTYPE> tf; |
| 59 | + Tensor in = tf.ones({2, 2}); |
| 60 | + Tensor out = tf.zeros({2, 2}); |
| 61 | + |
| 62 | + // Test overflow for min parameter (using valid max) |
| 63 | + ET_EXPECT_KERNEL_FAILURE( |
| 64 | + context_, op_hardtanh_out(in, bad_value, 1.0, out)); |
| 65 | + |
| 66 | + // Test overflow for max parameter (using valid min) |
| 67 | + ET_EXPECT_KERNEL_FAILURE( |
| 68 | + context_, op_hardtanh_out(in, -1.0, bad_value, out)); |
| 69 | + } |
54 | 70 | };
|
55 | 71 |
|
56 | 72 | TEST_F(OpHardTanhTest, SanityCheck) {
|
57 | 73 | #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
|
58 | 74 | ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
|
59 | 75 | #undef TEST_ENTRY
|
60 | 76 | }
|
| 77 | + |
| 78 | +GENERATE_SCALAR_OVERFLOW_TESTS(OpHardTanhTest) |
0 commit comments