Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions kernels/portable/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ Tensor& add_out(
static constexpr const char op_name[] = "add.out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
CTYPE_COMPUTE val_alpha;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
Expand Down Expand Up @@ -103,7 +105,9 @@ Tensor& add_scalar_out(

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
CTYPE_COMPUTE val_alpha;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
auto val_alpha_times_b = val_alpha * val_b;
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
Expand Down
68 changes: 66 additions & 2 deletions kernels/test/op_add_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
Expand All @@ -15,8 +16,6 @@

#include <gtest/gtest.h>

#include <iostream>

using namespace ::testing;
using executorch::aten::Scalar;
using executorch::aten::ScalarType;
Expand Down Expand Up @@ -231,6 +230,27 @@ class OpAddOutKernelTest : public OperatorTest {
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
}

template <ScalarType DTYPE>
void expect_bad_alpha_value_dies(const Scalar& bad_value) {
TensorFactory<DTYPE> tf;
Tensor a = tf.ones({2, 2});
Tensor b = tf.ones({2, 2});
Tensor out = tf.zeros({2, 2});

ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, bad_value, out));
}

// The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
// test cases requires a method called expect_bad_scalar_value_dies. However,
// for add operation, these checks only apply to the alpha argument.
// We are being explicit about this by naming the above function
// expect_bad_alpha_value_dies, and creating this wrapper in order to use the
// macro.
template <ScalarType DTYPE>
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
expect_bad_alpha_value_dies<DTYPE>(bad_value);
}
};

class OpAddScalarOutKernelTest : public OperatorTest {
Expand All @@ -242,6 +262,27 @@ class OpAddScalarOutKernelTest : public OperatorTest {
Tensor& out) {
return torch::executor::aten::add_outf(context_, self, other, alpha, out);
}

template <ScalarType DTYPE>
void expect_bad_alpha_value_dies(const Scalar& bad_value) {
TensorFactory<DTYPE> tf;
Tensor a = tf.ones({2, 2});
Scalar b = 1;
Tensor out = tf.zeros({2, 2});

ET_EXPECT_KERNEL_FAILURE(context_, op_add_scalar_out(a, b, bad_value, out));
}

// The GENERATE_SCALAR_OVERFLOW_TESTS macro used to generate scalar overflow
// test cases requires a method called expect_bad_scalar_value_dies. However,
// for the add operation, these checks only apply to the alpha argument.
// We are being explicit about this by naming the above function
// expect_bad_alpha_value_dies, and creating this wrapper in order to use the
// macro.
template <ScalarType DTYPE>
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
expect_bad_alpha_value_dies<DTYPE>(bad_value);
}
};

/**
Expand Down Expand Up @@ -794,3 +835,26 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
op_add_scalar_out(self, other, alpha, out);
EXPECT_TENSOR_CLOSE(out, out_expected);
}

TEST_F(OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
// Cannot be represented by a uint8_t.
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
}

TEST_F(OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
// Cannot be represented by a uint32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
}

TEST_F(OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
// Cannot be represented by a uint8_t.
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
}

TEST_F(OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
// Cannot be represented by a uint32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
}

GENERATE_SCALAR_OVERFLOW_TESTS(OpAddOutKernelTest)
GENERATE_SCALAR_OVERFLOW_TESTS(OpAddScalarOutKernelTest)
Loading