77 */
88
99#include < executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+ #include < executorch/kernels/test/ScalarOverflowTestMacros.h>
1011#include < executorch/kernels/test/TestUtil.h>
1112#include < executorch/kernels/test/supported_features.h>
1213#include < executorch/runtime/core/exec_aten/exec_aten.h>
1516
1617#include < gtest/gtest.h>
1718
18- #include < iostream>
19-
2019using namespace ::testing;
2120using executorch::aten::Scalar;
2221using executorch::aten::ScalarType;
@@ -231,6 +230,27 @@ class OpAddOutKernelTest : public OperatorTest {
231230 EXPECT_TENSOR_CLOSE (op_add_out (a, b, 1.0 , out), expected);
232231 EXPECT_TENSOR_CLOSE (op_add_out (b, a, 1.0 , out), expected);
233232 }
233+
234+ template <ScalarType DTYPE>
235+ void expect_bad_alpha_value_dies (const Scalar& bad_value) {
236+ TensorFactory<DTYPE> tf;
237+ Tensor a = tf.ones ({2 , 2 });
238+ Tensor b = tf.ones ({2 , 2 });
239+ Tensor out = tf.zeros ({2 , 2 });
240+
241+ ET_EXPECT_KERNEL_FAILURE (context_, op_add_out (a, b, bad_value, out));
242+ }
243+
244+ // The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
245+ // test cases requires a method called expect_bad_scalar_value_dies. However,
246+ // for add operation, these checks only apply to the alpha argument.
247+ // We are being explicit about this by naming the above function
248+ // expect_bad_alpha_value_dies, and creating this wrapper in order to use the
249+ // macro.
250+ template <ScalarType DTYPE>
251+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
252+ expect_bad_alpha_value_dies<DTYPE>(bad_value);
253+ }
234254};
235255
236256class OpAddScalarOutKernelTest : public OperatorTest {
@@ -242,6 +262,27 @@ class OpAddScalarOutKernelTest : public OperatorTest {
242262 Tensor& out) {
243263 return torch::executor::aten::add_outf (context_, self, other, alpha, out);
244264 }
265+
266+ template <ScalarType DTYPE>
267+ void expect_bad_alpha_value_dies (const Scalar& bad_value) {
268+ TensorFactory<DTYPE> tf;
269+ Tensor a = tf.ones ({2 , 2 });
270+ Scalar b = 1 ;
271+ Tensor out = tf.zeros ({2 , 2 });
272+
273+ ET_EXPECT_KERNEL_FAILURE (context_, op_add_scalar_out (a, b, bad_value, out));
274+ }
275+
276+ // The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
277+ // test cases requires a method called expect_bad_scalar_value_dies. However,
278+ // for the add operation, these checks only apply to the alpha argument.
279+ // We are being explicit about this by naming the above function
280+ // expect_bad_alpha_value_dies, and creating this wrapper in order to use the
281+ // macro.
282+ template <ScalarType DTYPE>
283+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
284+ expect_bad_alpha_value_dies<DTYPE>(bad_value);
285+ }
245286};
246287
247288/* *
@@ -794,3 +835,26 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
794835 op_add_scalar_out (self, other, alpha, out);
795836 EXPECT_TENSOR_CLOSE (out, out_expected);
796837}
838+
839+ TEST_F (OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
840+ // Cannot be represented by a uint8_t.
841+ expect_bad_alpha_value_dies<ScalarType::Byte>(2.2 );
842+ }
843+
844+ TEST_F (OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
845+ // Cannot be represented by a uint32_t.
846+ expect_bad_alpha_value_dies<ScalarType::Int>(2.2 );
847+ }
848+
849+ TEST_F (OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
850+ // Cannot be represented by a uint8_t.
851+ expect_bad_alpha_value_dies<ScalarType::Byte>(2.2 );
852+ }
853+
854+ TEST_F (OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
855+ // Cannot be represented by a uint32_t.
856+ expect_bad_alpha_value_dies<ScalarType::Int>(2.2 );
857+ }
858+
859+ GENERATE_SCALAR_OVERFLOW_TESTS (OpAddOutKernelTest)
860+ GENERATE_SCALAR_OVERFLOW_TESTS(OpAddScalarOutKernelTest)
0 commit comments