Skip to content

Commit 1f3a90c

Browse files
[ET][Portable] Test bad alpha values: op_add (#11982)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #11982 * #11980 - Tests for bad alpha values: add.Tensor & add.Scalar - Fixes alpha handling in portable/optimized op_add kernels - Eliminates usage of ET_SWITCH_SCALAR_OBJ_TYPES in optimized op_add kernel Differential Revision: [D77325770](https://our.internmc.facebook.com/intern/diff/D77325770/)
1 parent b27ed23 commit 1f3a90c

File tree

3 files changed

+151
-46
lines changed

3 files changed

+151
-46
lines changed

kernels/optimized/cpu/op_add.cpp

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ Tensor& opt_add_out(
4545
ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
4646
CTYPE alpha_val;
4747
ET_KERNEL_CHECK(
48-
ctx,
49-
torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
50-
InvalidArgument, );
48+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
5149
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
5250
CTYPE b_casted = static_cast<CTYPE>(b_val);
5351

@@ -81,7 +79,6 @@ Tensor& opt_add_scalar_out(
8179
(void)ctx;
8280

8381
ScalarType a_type = a.scalar_type();
84-
ScalarType b_type = utils::get_scalar_dtype(b);
8582
ScalarType common_type =
8683
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
8784
ScalarType out_type = out.scalar_type();
@@ -99,47 +96,43 @@ Tensor& opt_add_scalar_out(
9996
if (a_type == common_type && a_type == out_type &&
10097
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
10198
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
102-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
103-
CTYPE_B b_val;
104-
ET_EXTRACT_SCALAR(b, b_val);
105-
CTYPE b_casted = static_cast<CTYPE>(b_val);
106-
CTYPE alpha_val;
107-
ET_EXTRACT_SCALAR(alpha, alpha_val);
108-
109-
using Vec = at::vec::Vectorized<CTYPE>;
110-
at::vec::map<CTYPE>(
111-
[alpha_val, b_casted](Vec x) {
112-
return x + Vec(alpha_val * b_casted);
113-
},
114-
out.mutable_data_ptr<CTYPE>(),
115-
a.const_data_ptr<CTYPE>(),
116-
out.numel());
117-
});
99+
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
100+
CTYPE alpha_val;
101+
ET_KERNEL_CHECK(
102+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
103+
104+
using Vec = at::vec::Vectorized<CTYPE>;
105+
at::vec::map<CTYPE>(
106+
[alpha_val, b_casted](Vec x) {
107+
return x + Vec(alpha_val * b_casted);
108+
},
109+
out.mutable_data_ptr<CTYPE>(),
110+
a.const_data_ptr<CTYPE>(),
111+
out.numel());
118112
});
119113
} else {
120114
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
121-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
122-
ET_SWITCH_REALB_TYPES(
123-
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
124-
ET_SWITCH_REALHBBF16_TYPES(
125-
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
126-
CTYPE_B b_val;
127-
ET_EXTRACT_SCALAR(b, b_val);
128-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
129-
CTYPE_IN alpha_val;
130-
ET_EXTRACT_SCALAR(alpha, alpha_val);
131-
132-
const size_t n = a.numel();
133-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
134-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
135-
for (auto i = 0; i < n; ++i) {
136-
out_data[i] = static_cast<CTYPE_OUT>(
137-
static_cast<CTYPE_IN>(a_data[i]) +
138-
alpha_val * b_casted);
139-
}
140-
});
141-
});
142-
});
115+
ET_SWITCH_REALB_TYPES(
116+
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
117+
ET_SWITCH_REALHBBF16_TYPES(
118+
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
119+
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
120+
CTYPE_IN alpha_val;
121+
ET_KERNEL_CHECK(
122+
ctx,
123+
utils::extract_scalar(alpha, &alpha_val),
124+
InvalidArgument, );
125+
126+
const size_t n = a.numel();
127+
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
128+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
129+
for (auto i = 0; i < n; ++i) {
130+
out_data[i] = static_cast<CTYPE_OUT>(
131+
static_cast<CTYPE_IN>(a_data[i]) +
132+
alpha_val * b_casted);
133+
}
134+
});
135+
});
143136
});
144137
}
145138

kernels/portable/cpu/op_add.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ Tensor& add_out(
5151
static constexpr const char op_name[] = "add.out";
5252

5353
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54-
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
54+
CTYPE_COMPUTE val_alpha;
55+
ET_KERNEL_CHECK(
56+
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
5557
utils::apply_bitensor_elementwise_fn<
5658
CTYPE_COMPUTE,
5759
op_name,
@@ -103,7 +105,9 @@ Tensor& add_scalar_out(
103105

104106
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
105107
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106-
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
108+
CTYPE_COMPUTE val_alpha;
109+
ET_KERNEL_CHECK(
110+
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
107111
auto val_alpha_times_b = val_alpha * val_b;
108112
utils::apply_unitensor_elementwise_fn<
109113
CTYPE_COMPUTE,

kernels/test/op_add_test.cpp

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
#include <gtest/gtest.h>
1717

18-
#include <iostream>
19-
2018
using namespace ::testing;
2119
using executorch::aten::Scalar;
2220
using executorch::aten::ScalarType;
@@ -231,6 +229,16 @@ class OpAddOutKernelTest : public OperatorTest {
231229
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
232230
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
233231
}
232+
233+
template <ScalarType DTYPE>
234+
void expect_bad_alpha_value_dies(Scalar bad_value) {
235+
TensorFactory<DTYPE> tf;
236+
Tensor a = tf.ones({2, 2});
237+
Tensor b = tf.ones({2, 2});
238+
Tensor out = tf.zeros({2, 2});
239+
240+
ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, bad_value, out));
241+
}
234242
};
235243

236244
class OpAddScalarOutKernelTest : public OperatorTest {
@@ -242,6 +250,16 @@ class OpAddScalarOutKernelTest : public OperatorTest {
242250
Tensor& out) {
243251
return torch::executor::aten::add_outf(context_, self, other, alpha, out);
244252
}
253+
254+
template <ScalarType DTYPE>
255+
void expect_bad_alpha_value_dies(Scalar bad_value) {
256+
TensorFactory<DTYPE> tf;
257+
Tensor a = tf.ones({2, 2});
258+
Scalar b = 1;
259+
Tensor out = tf.zeros({2, 2});
260+
261+
ET_EXPECT_KERNEL_FAILURE(context_, op_add_scalar_out(a, b, bad_value, out));
262+
}
245263
};
246264

247265
/**
@@ -794,3 +812,93 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
794812
op_add_scalar_out(self, other, alpha, out);
795813
EXPECT_TENSOR_CLOSE(out, out_expected);
796814
}
815+
816+
TEST_F(OpAddOutKernelTest, ByteTensorTooLargeAlphaDies) {
817+
// Cannot be represented by a uint8_t.
818+
expect_bad_alpha_value_dies<ScalarType::Byte>(256);
819+
}
820+
821+
TEST_F(OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
822+
// Cannot be represented by a uint8_t.
823+
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
824+
}
825+
826+
#ifndef USE_ATEN_LIB
827+
TEST_F(OpAddOutKernelTest, IntTensorTooSmallAlphaDies) {
828+
// Cannot be represented by a int32_t.
829+
expect_bad_alpha_value_dies<ScalarType::Int>(-2147483649);
830+
}
831+
832+
TEST_F(OpAddOutKernelTest, IntTensorTooLargeAlphaDies) {
833+
// Cannot be represented by a int32_t.
834+
expect_bad_alpha_value_dies<ScalarType::Int>(2147483648);
835+
}
836+
#endif
837+
838+
TEST_F(OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
839+
// Cannot be represented by a uint32_t.
840+
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
841+
}
842+
843+
TEST_F(OpAddOutKernelTest, FloatTensorTooSmallAlphaDies) {
844+
// Cannot be represented by a float.
845+
expect_bad_alpha_value_dies<ScalarType::Float>(-3.41e+38);
846+
}
847+
848+
TEST_F(OpAddOutKernelTest, FloatTensorTooLargeAlphaDies) {
849+
// Cannot be represented by a float.
850+
expect_bad_alpha_value_dies<ScalarType::Float>(3.41e+38);
851+
}
852+
853+
TEST_F(OpAddOutKernelTest, HalfTensorTooLargeAlphaDies) {
854+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
855+
GTEST_SKIP() << "Portable kernel does the computation in float";
856+
}
857+
// Cannot be represented by a float.
858+
expect_bad_alpha_value_dies<ScalarType::Half>(65505.0);
859+
}
860+
861+
TEST_F(OpAddScalarOutKernelTest, ByteTensorTooLargeAlphaDies) {
862+
// Cannot be represented by a uint8_t.
863+
expect_bad_alpha_value_dies<ScalarType::Byte>(256);
864+
}
865+
866+
TEST_F(OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
867+
// Cannot be represented by a uint8_t.
868+
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
869+
}
870+
871+
#ifndef USE_ATEN_LIB
872+
TEST_F(OpAddScalarOutKernelTest, IntTensorTooSmallAlphaDies) {
873+
// Cannot be represented by a int32_t.
874+
expect_bad_alpha_value_dies<ScalarType::Int>(-2147483649);
875+
}
876+
877+
TEST_F(OpAddScalarOutKernelTest, IntTensorTooLargeAlphaDies) {
878+
// Cannot be represented by a int32_t.
879+
expect_bad_alpha_value_dies<ScalarType::Int>(2147483648);
880+
}
881+
#endif
882+
883+
TEST_F(OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
884+
// Cannot be represented by a uint32_t.
885+
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
886+
}
887+
888+
TEST_F(OpAddScalarOutKernelTest, FloatTensorTooSmallAlphaDies) {
889+
// Cannot be represented by a float.
890+
expect_bad_alpha_value_dies<ScalarType::Float>(-3.41e+38);
891+
}
892+
893+
TEST_F(OpAddScalarOutKernelTest, FloatTensorTooLargeAlphaDies) {
894+
// Cannot be represented by a float.
895+
expect_bad_alpha_value_dies<ScalarType::Float>(3.41e+38);
896+
}
897+
898+
TEST_F(OpAddScalarOutKernelTest, HalfTensorTooLargeAlphaDies) {
899+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
900+
GTEST_SKIP() << "Portable kernel does the computation in float";
901+
}
902+
// Cannot be represented by a float.
903+
expect_bad_alpha_value_dies<ScalarType::Half>(65505.0);
904+
}

0 commit comments

Comments
 (0)