diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index 83642c4864d..e10534cd233 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -51,7 +51,9 @@ Tensor& add_out( static constexpr const char op_name[] = "add.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + CTYPE_COMPUTE val_alpha; + ET_KERNEL_CHECK( + ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, ); utils::apply_bitensor_elementwise_fn< CTYPE_COMPUTE, op_name, @@ -103,7 +105,9 @@ Tensor& add_scalar_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { CTYPE_COMPUTE val_b = utils::scalar_to(b); - CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + CTYPE_COMPUTE val_alpha; + ET_KERNEL_CHECK( + ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, ); auto val_alpha_times_b = val_alpha * val_b; utils::apply_unitensor_elementwise_fn< CTYPE_COMPUTE, diff --git a/kernels/test/op_add_test.cpp b/kernels/test/op_add_test.cpp index c84341aa9b1..8af693e1b3e 100644 --- a/kernels/test/op_add_test.cpp +++ b/kernels/test/op_add_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -15,8 +16,6 @@ #include -#include - using namespace ::testing; using executorch::aten::Scalar; using executorch::aten::ScalarType; @@ -231,6 +230,27 @@ class OpAddOutKernelTest : public OperatorTest { EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected); EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected); } + + template + void expect_bad_alpha_value_dies(const Scalar& bad_value) { + TensorFactory tf; + Tensor a = tf.ones({2, 2}); + Tensor b = tf.ones({2, 2}); + Tensor out = tf.zeros({2, 2}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, bad_value, out)); + } + + // The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow + // test cases requires a method called expect_bad_scalar_value_dies. However, + // for add operation, these checks only apply to the alpha argument. + // We are being explicit about this by naming the above function + // expect_bad_alpha_value_dies, and creating this wrapper in order to use the + // macro. + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + expect_bad_alpha_value_dies(bad_value); + } }; class OpAddScalarOutKernelTest : public OperatorTest { @@ -242,6 +262,27 @@ class OpAddScalarOutKernelTest : public OperatorTest { Tensor& out) { return torch::executor::aten::add_outf(context_, self, other, alpha, out); } + + template + void expect_bad_alpha_value_dies(const Scalar& bad_value) { + TensorFactory tf; + Tensor a = tf.ones({2, 2}); + Scalar b = 1; + Tensor out = tf.zeros({2, 2}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_add_scalar_out(a, b, bad_value, out)); + } + + // The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow + // test cases requires a method called expect_bad_scalar_value_dies. However, + // for the add operation, these checks only apply to the alpha argument. + // We are being explicit about this by naming the above function + // expect_bad_alpha_value_dies, and creating this wrapper in order to use the + // macro. + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + expect_bad_alpha_value_dies(bad_value); + } }; /** @@ -794,3 +835,26 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) { op_add_scalar_out(self, other, alpha, out); EXPECT_TENSOR_CLOSE(out, out_expected); } + +TEST_F(OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) { + // Cannot be represented by a uint8_t. + expect_bad_alpha_value_dies(2.2); +} + +TEST_F(OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) { + // Cannot be represented by a uint32_t. + expect_bad_alpha_value_dies(2.2); +} + +TEST_F(OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) { + // Cannot be represented by a uint8_t. + expect_bad_alpha_value_dies(2.2); +} + +TEST_F(OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) { + // Cannot be represented by a uint32_t. + expect_bad_alpha_value_dies(2.2); +} + +GENERATE_SCALAR_OVERFLOW_TESTS(OpAddOutKernelTest) +GENERATE_SCALAR_OVERFLOW_TESTS(OpAddScalarOutKernelTest)