Skip to content

Commit 77156b9

Browse files
manuelcandales authored and facebook-github-bot committed
Enable Half/BF16: abs, full, gelu, neg (#5856)
Summary: Pull Request resolved: #5856 Differential Revision: D63863399
1 parent d094b09 commit 77156b9

File tree

8 files changed

+103
-4
lines changed

8 files changed

+103
-4
lines changed

kernels/portable/cpu/op_abs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3131
ET_KERNEL_CHECK(
3232
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3333

34-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
34+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
3535
apply_unary_map_fn(
3636
[](const CTYPE val_in) {
3737
if (val_in < 0) {

kernels/portable/cpu/op_full.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Tensor& full_out(
4040
CTYPE_VAL val;
4141
utils::extract_scalar(fill_value, &val);
4242

43-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
43+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
4444
CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
4545
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
4646
for (size_t i = 0; i < out.numel(); ++i) {

kernels/portable/cpu/op_gelu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Tensor& gelu_out(
3737
ET_KERNEL_CHECK(
3838
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3939

40-
ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
40+
ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
4141
if (approximate == "tanh") {
4242
apply_unary_map_fn(
4343
[](const CTYPE x) {

kernels/portable/cpu/op_neg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Tensor& neg_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3333
ET_KERNEL_CHECK(
3434
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3535

36-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
36+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
3737
apply_unary_map_fn(
3838
[](const CTYPE val_in) { return static_cast<CTYPE>(-val_in); },
3939
in.const_data_ptr<CTYPE>(),

kernels/test/op_abs_test.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,44 @@ class OpAbsTest : public OperatorTest {
2424
Tensor& op_abs_out(const Tensor& self, Tensor& out) {
2525
return torch::executor::aten::abs_outf(context_, self, out);
2626
}
27+
28+
template <ScalarType DTYPE>
29+
void test_dtype() {
30+
TensorFactory<DTYPE> tf;
31+
32+
Tensor in = tf.make({2, 3}, {-3, -2, -1, 0, 1, 2});
33+
Tensor out = tf.zeros({2, 3});
34+
Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, 1, 2});
35+
36+
Tensor ret = op_abs_out(in, out);
37+
38+
EXPECT_TENSOR_EQ(out, ret);
39+
EXPECT_TENSOR_EQ(out, expected);
40+
}
41+
42+
template <>
43+
void test_dtype<ScalarType::Byte>() {
44+
TensorFactory<ScalarType::Byte> tf;
45+
46+
Tensor in = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
47+
Tensor out = tf.zeros({2, 3});
48+
Tensor expected = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
49+
50+
Tensor ret = op_abs_out(in, out);
51+
52+
EXPECT_TENSOR_EQ(out, ret);
53+
EXPECT_TENSOR_EQ(out, expected);
54+
}
2755
};
2856

57+
TEST_F(OpAbsTest, AllRealHBF16Input) {
58+
#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
59+
test_dtype<ScalarType::INPUT_DTYPE>();
60+
61+
ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
62+
#undef TEST_KERNEL
63+
}
64+
2965
TEST_F(OpAbsTest, SanityCheck) {
3066
TensorFactory<ScalarType::Float> tf;
3167

kernels/test/op_full_test.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,26 @@ TEST_F(OpFullOutTest, ZeroDim) {
122122
op_full_out(sizes, true, out);
123123
EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
124124
}
125+
126+
TEST_F(OpFullOutTest, BFloat16Support) {
127+
TensorFactory<ScalarType::BFloat16> tf;
128+
129+
std::vector<int64_t> sizes_int64_t_vec = {2, 3};
130+
std::vector<int32_t> sizes_in32_t_vec = {2, 3};
131+
auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
132+
133+
// Boolean Scalar
134+
Tensor out = tf.zeros(sizes_in32_t_vec);
135+
op_full_out(sizes, true, out);
136+
EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
137+
138+
// Integral Scalar
139+
out = tf.zeros(sizes_in32_t_vec);
140+
op_full_out(sizes, 1, out);
141+
EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
142+
143+
// Floating Point Scalar
144+
out = tf.zeros(sizes_in32_t_vec);
145+
op_full_out(sizes, 3.1415926535, out);
146+
EXPECT_TENSOR_EQ(out, tf.full(sizes_in32_t_vec, 3.1415926535));
147+
}

kernels/test/op_gelu_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class OpGeluTest : public OperatorTest {
6666
}
6767
};
6868

69+
TEST_F(OpGeluTest, HalfTensors) {
70+
test_gelu_execution<ScalarType::Half>();
71+
}
72+
6973
TEST_F(OpGeluTest, FloatTensors) {
7074
test_gelu_execution<ScalarType::Float>();
7175
}

kernels/test/op_neg_test.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,44 @@ class OpNegTest : public OperatorTest {
2424
Tensor& op_neg_out(const Tensor& self, Tensor& out) {
2525
return torch::executor::aten::neg_outf(context_, self, out);
2626
}
27+
28+
template <ScalarType DTYPE>
29+
void test_dtype() {
30+
TensorFactory<DTYPE> tf;
31+
32+
Tensor in = tf.make({2, 3}, {-3, -2, -1, 0, 1, 2});
33+
Tensor out = tf.zeros({2, 3});
34+
Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, -1, -2});
35+
36+
Tensor ret = op_neg_out(in, out);
37+
38+
EXPECT_TENSOR_EQ(out, ret);
39+
EXPECT_TENSOR_EQ(out, expected);
40+
}
41+
42+
template <>
43+
void test_dtype<ScalarType::Byte>() {
44+
TensorFactory<ScalarType::Byte> tf;
45+
46+
Tensor in = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
47+
Tensor out = tf.zeros({2, 3});
48+
Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, 255, 254});
49+
50+
Tensor ret = op_neg_out(in, out);
51+
52+
EXPECT_TENSOR_EQ(out, ret);
53+
EXPECT_TENSOR_EQ(out, expected);
54+
}
2755
};
2856

57+
TEST_F(OpNegTest, AllRealHBF16Input) {
58+
#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
59+
test_dtype<ScalarType::INPUT_DTYPE>();
60+
61+
ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
62+
#undef TEST_KERNEL
63+
}
64+
2965
TEST_F(OpNegTest, SanityCheck) {
3066
TensorFactory<ScalarType::Float> tf;
3167

0 commit comments

Comments (0)