7
7
*/
8
8
9
9
#include < executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10
+ #include < executorch/kernels/test/ScalarOverflowTestMacros.h>
10
11
#include < executorch/kernels/test/TestUtil.h>
11
12
#include < executorch/kernels/test/supported_features.h>
12
13
#include < executorch/runtime/core/exec_aten/exec_aten.h>
@@ -59,6 +60,17 @@ class OpFullOutTest : public OperatorTest {
59
60
op_full_out (aref, 1.0 , out);
60
61
EXPECT_TENSOR_EQ (out, tf.ones (size_int32_t ));
61
62
}
63
+
64
+ template <ScalarType DTYPE>
65
+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
66
+ TensorFactory<DTYPE> tf;
67
+ std::vector<int32_t > sizes = {2 , 2 };
68
+ std::vector<int64_t > sizes_int64_t (sizes.begin (), sizes.end ());
69
+ auto aref = IntArrayRef (sizes_int64_t .data (), sizes_int64_t .size ());
70
+ Tensor out = tf.zeros (sizes);
71
+
72
+ ET_EXPECT_KERNEL_FAILURE (context_, op_full_out (aref, bad_value, out));
73
+ }
62
74
};
63
75
64
76
#define GENERATE_TEST (_, DTYPE ) \
@@ -72,20 +84,7 @@ class OpFullOutTest : public OperatorTest {
72
84
73
85
ET_FORALL_REALHBF16_TYPES (GENERATE_TEST)
74
86
75
- TEST_F(OpFullOutTest, ValueOverflow) {
76
- if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
77
- GTEST_SKIP () << " ATen kernel doesn't handle overflow" ;
78
- }
79
- TensorFactory<ScalarType::Byte> tf;
80
-
81
- std::vector<int64_t > sizes_int64_t_vec = {2 , 3 };
82
- std::vector<int32_t > sizes_in32_t_vec = {2 , 3 };
83
- auto sizes = IntArrayRef (sizes_int64_t_vec.data (), sizes_int64_t_vec.size ());
84
-
85
- Tensor out = tf.zeros (sizes_in32_t_vec);
86
-
87
- op_full_out (sizes, 1000 , out);
88
- }
87
+ GENERATE_SCALAR_OVERFLOW_TESTS(OpFullOutTest)
89
88
90
89
TEST_F(OpFullOutTest, HalfSupport) {
91
90
TensorFactory<ScalarType::Half> tf;
0 commit comments