
Commit a0e6284

[ET][Portable] Test bad alpha values: op_add
- Tests for bad alpha values: add.Tensor & add.Scalar
- Fixes alpha handling in portable/optimized op_add kernels
- Eliminates usage of ET_SWITCH_SCALAR_OBJ_TYPES in optimized op_add kernel

Differential Revision: [D77325770](https://our.internmc.facebook.com/intern/diff/D77325770/)

[ghstack-poisoned]
1 parent b27ed23 commit a0e6284
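For readers skimming the diff: alpha is now pulled out of the Scalar with a checked extraction (utils::extract_scalar wrapped in ET_KERNEL_CHECK), so a value that the kernel's compute type cannot represent makes the kernel fail instead of being silently truncated. The sketch below is a minimal, self-contained illustration of that idea only; the helper name extract_checked and its plain double parameter are assumptions for the example, not the ExecuTorch API, which operates on Scalar objects and reports failure through the runtime context as shown in the file diffs that follow.

// Standalone illustration only -- extract_checked is a hypothetical stand-in
// for utils::extract_scalar, and the double parameter is a simplification of
// the executorch Scalar type.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>

// Writes *out and returns true only if `value` fits the target type T;
// returns false (leaving *out untouched) for out-of-range values or, for
// integral targets, fractional values.
template <typename T>
bool extract_checked(double value, T* out) {
  if (std::is_integral<T>::value) {
    if (value != std::trunc(value)) {
      return false; // e.g. alpha = 2.2 for an int32_t tensor
    }
    if (value < static_cast<double>(std::numeric_limits<T>::lowest()) ||
        value > static_cast<double>(std::numeric_limits<T>::max())) {
      return false; // e.g. alpha = 256 for a uint8_t tensor
    }
  } else if (std::fabs(value) >
             static_cast<double>(std::numeric_limits<T>::max())) {
    return false; // e.g. alpha = 3.41e+38 for a float tensor
  }
  *out = static_cast<T>(value);
  return true;
}

int main() {
  uint8_t u8;
  int32_t i32;
  float f32;
  std::cout << extract_checked(256.0, &u8) << '\n';    // 0: overflows uint8_t
  std::cout << extract_checked(2.2, &i32) << '\n';     // 0: not an integer
  std::cout << extract_checked(3.41e38, &f32) << '\n'; // 0: overflows float
  std::cout << extract_checked(3.0, &i32) << '\n';     // 1: accepted
}

Built this way, an alpha of 256 is rejected for a Byte (uint8_t) tensor, 2.2 for an Int (int32_t) tensor, and 3.41e+38 for a Float tensor, which is what the new tests in kernels/test/op_add_test.cpp assert.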

3 files changed: +159 −42 lines

kernels/optimized/cpu/op_add.cpp

Lines changed: 37 additions & 40 deletions
@@ -46,7 +46,7 @@ Tensor& opt_add_out(
         CTYPE alpha_val;
         ET_KERNEL_CHECK(
             ctx,
-            torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+            utils::extract_scalar(alpha, &alpha_val),
             InvalidArgument, );
         CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
         CTYPE b_casted = static_cast<CTYPE>(b_val);
@@ -81,7 +81,6 @@ Tensor& opt_add_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType common_type =
       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
   ScalarType out_type = out.scalar_type();
@@ -99,47 +98,45 @@ Tensor& opt_add_scalar_out(
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
     ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        CTYPE_B b_val;
-        ET_EXTRACT_SCALAR(b, b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-        CTYPE alpha_val;
-        ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-        using Vec = at::vec::Vectorized<CTYPE>;
-        at::vec::map<CTYPE>(
-            [alpha_val, b_casted](Vec x) {
-              return x + Vec(alpha_val * b_casted);
-            },
-            out.mutable_data_ptr<CTYPE>(),
-            a.const_data_ptr<CTYPE>(),
-            out.numel());
-      });
+      CTYPE b_casted = utils::scalar_to<CTYPE>(b);
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK(
+          ctx,
+          utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument, );
+
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map<CTYPE>(
+          [alpha_val, b_casted](Vec x) {
+            return x + Vec(alpha_val * b_casted);
+          },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          out.numel());
     });
   } else {
     ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                    CTYPE_IN alpha_val;
-                    ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) +
-                          alpha_val * b_casted);
-                    }
-                  });
-            });
-      });
+      ET_SWITCH_REALB_TYPES(
+          common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
+            ET_SWITCH_REALHBBF16_TYPES(
+                out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
+                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
+                  CTYPE_IN alpha_val;
+                  ET_KERNEL_CHECK(
+                      ctx,
+                      utils::extract_scalar(alpha, &alpha_val),
+                      InvalidArgument, );
+
+                  const size_t n = a.numel();
+                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
+                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                  for (auto i = 0; i < n; ++i) {
+                    out_data[i] = static_cast<CTYPE_OUT>(
+                        static_cast<CTYPE_IN>(a_data[i]) +
+                        alpha_val * b_casted);
+                  }
+                });
+          });
     });
   }
 
kernels/portable/cpu/op_add.cpp

Lines changed: 10 additions & 2 deletions
@@ -51,7 +51,11 @@ Tensor& add_out(
   static constexpr const char op_name[] = "add.out";
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    CTYPE_COMPUTE val_alpha;
+    ET_KERNEL_CHECK(
+        ctx,
+        utils::extract_scalar(alpha, &val_alpha),
+        InvalidArgument, );
     utils::apply_bitensor_elementwise_fn<
         CTYPE_COMPUTE,
         op_name,
@@ -103,7 +107,11 @@ Tensor& add_scalar_out(
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    CTYPE_COMPUTE val_alpha;
+    ET_KERNEL_CHECK(
+        ctx,
+        utils::extract_scalar(alpha, &val_alpha),
+        InvalidArgument, );
     auto val_alpha_times_b = val_alpha * val_b;
     utils::apply_unitensor_elementwise_fn<
         CTYPE_COMPUTE,

kernels/test/op_add_test.cpp

Lines changed: 112 additions & 0 deletions
@@ -231,6 +231,17 @@ class OpAddOutKernelTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
     EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
   }
+
+  template <ScalarType DTYPE>
+  void expect_bad_alpha_value_dies(Scalar bad_value) {
+    TensorFactory<DTYPE> tf;
+    Tensor a = tf.ones({2, 2});
+    Tensor b = tf.ones({2, 2});
+    Tensor out = tf.zeros({2, 2});
+
+    ET_EXPECT_KERNEL_FAILURE(
+        context_, op_add_out(a, b, bad_value, out));
+  }
 };
 
 class OpAddScalarOutKernelTest : public OperatorTest {
@@ -242,6 +253,17 @@ class OpAddScalarOutKernelTest : public OperatorTest {
       Tensor& out) {
     return torch::executor::aten::add_outf(context_, self, other, alpha, out);
   }
+
+  template <ScalarType DTYPE>
+  void expect_bad_alpha_value_dies(Scalar bad_value) {
+    TensorFactory<DTYPE> tf;
+    Tensor a = tf.ones({2, 2});
+    Scalar b = 1;
+    Tensor out = tf.zeros({2, 2});
+
+    ET_EXPECT_KERNEL_FAILURE(
+        context_, op_add_scalar_out(a, b, bad_value, out));
+  }
 };
 
 /**
@@ -794,3 +816,93 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
   op_add_scalar_out(self, other, alpha, out);
   EXPECT_TENSOR_CLOSE(out, out_expected);
 }
+
+TEST_F(OpAddOutKernelTest, ByteTensorTooLargeAlphaDies) {
+  // Cannot be represented by a uint8_t.
+  expect_bad_alpha_value_dies<ScalarType::Byte>(256);
+}
+
+TEST_F(OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
+  // Cannot be represented by a uint8_t.
+  expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
+}
+
+#ifndef USE_ATEN_LIB
+TEST_F(OpAddOutKernelTest, IntTensorTooSmallAlphaDies) {
+  // Cannot be represented by an int32_t.
+  expect_bad_alpha_value_dies<ScalarType::Int>(-2147483649);
+}
+
+TEST_F(OpAddOutKernelTest, IntTensorTooLargeAlphaDies) {
+  // Cannot be represented by an int32_t.
+  expect_bad_alpha_value_dies<ScalarType::Int>(2147483648);
+}
+#endif
+
+TEST_F(OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
+  // Cannot be represented by an int32_t.
+  expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
+}
+
+TEST_F(OpAddOutKernelTest, FloatTensorTooSmallAlphaDies) {
+  // Cannot be represented by a float.
+  expect_bad_alpha_value_dies<ScalarType::Float>(-3.41e+38);
+}
+
+TEST_F(OpAddOutKernelTest, FloatTensorTooLargeAlphaDies) {
+  // Cannot be represented by a float.
+  expect_bad_alpha_value_dies<ScalarType::Float>(3.41e+38);
+}
+
+TEST_F(OpAddOutKernelTest, HalfTensorTooLargeAlphaDies) {
+  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "Portable kernel does the computation in float";
+  }
+  // Cannot be represented by a half.
+  expect_bad_alpha_value_dies<ScalarType::Half>(65505.0);
+}
+
+TEST_F(OpAddScalarOutKernelTest, ByteTensorTooLargeAlphaDies) {
+  // Cannot be represented by a uint8_t.
+  expect_bad_alpha_value_dies<ScalarType::Byte>(256);
+}
+
+TEST_F(OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
+  // Cannot be represented by a uint8_t.
+  expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
+}
+
+#ifndef USE_ATEN_LIB
+TEST_F(OpAddScalarOutKernelTest, IntTensorTooSmallAlphaDies) {
+  // Cannot be represented by an int32_t.
+  expect_bad_alpha_value_dies<ScalarType::Int>(-2147483649);
+}
+
+TEST_F(OpAddScalarOutKernelTest, IntTensorTooLargeAlphaDies) {
+  // Cannot be represented by an int32_t.
+  expect_bad_alpha_value_dies<ScalarType::Int>(2147483648);
+}
+#endif
+
+TEST_F(OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
+  // Cannot be represented by an int32_t.
+  expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
+}
+
+TEST_F(OpAddScalarOutKernelTest, FloatTensorTooSmallAlphaDies) {
+  // Cannot be represented by a float.
+  expect_bad_alpha_value_dies<ScalarType::Float>(-3.41e+38);
+}
+
+TEST_F(OpAddScalarOutKernelTest, FloatTensorTooLargeAlphaDies) {
+  // Cannot be represented by a float.
+  expect_bad_alpha_value_dies<ScalarType::Float>(3.41e+38);
+}
+
+TEST_F(OpAddScalarOutKernelTest, HalfTensorTooLargeAlphaDies) {
+  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "Portable kernel does the computation in float";
+  }
+  // Cannot be represented by a half.
+  expect_bad_alpha_value_dies<ScalarType::Half>(65505.0);
+}
