77 */
88
99#include < executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+ #include < executorch/kernels/test/ScalarOverflowTestMacros.h>
1011#include < executorch/kernels/test/TestUtil.h>
1112#include < executorch/kernels/test/supported_features.h>
1213#include < executorch/runtime/core/exec_aten/exec_aten.h>
@@ -59,6 +60,17 @@ class OpFullOutTest : public OperatorTest {
5960 op_full_out (aref, 1.0 , out);
6061 EXPECT_TENSOR_EQ (out, tf.ones (size_int32_t ));
6162 }
63+
64+ template <ScalarType DTYPE>
65+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
66+ TensorFactory<DTYPE> tf;
67+ std::vector<int32_t > sizes = {2 , 2 };
68+ std::vector<int64_t > sizes_int64_t (sizes.begin (), sizes.end ());
69+ auto aref = IntArrayRef (sizes_int64_t .data (), sizes_int64_t .size ());
70+ Tensor out = tf.zeros (sizes);
71+
72+ ET_EXPECT_KERNEL_FAILURE (context_, op_full_out (aref, bad_value, out));
73+ }
6274};
6375
6476#define GENERATE_TEST (_, DTYPE ) \
@@ -72,20 +84,7 @@ class OpFullOutTest : public OperatorTest {
7284
7385ET_FORALL_REALHBF16_TYPES (GENERATE_TEST)
7486
75- TEST_F(OpFullOutTest, ValueOverflow) {
76- if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
77- GTEST_SKIP () << " ATen kernel doesn't handle overflow" ;
78- }
79- TensorFactory<ScalarType::Byte> tf;
80-
81- std::vector<int64_t > sizes_int64_t_vec = {2 , 3 };
82- std::vector<int32_t > sizes_in32_t_vec = {2 , 3 };
83- auto sizes = IntArrayRef (sizes_int64_t_vec.data (), sizes_int64_t_vec.size ());
84-
85- Tensor out = tf.zeros (sizes_in32_t_vec);
86-
87- op_full_out (sizes, 1000 , out);
88- }
87+ GENERATE_SCALAR_OVERFLOW_TESTS(OpFullOutTest)
8988
9089TEST_F(OpFullOutTest, HalfSupport) {
9190 TensorFactory<ScalarType::Half> tf;
0 commit comments