
Commit 94b7146

manuelcandales authored and facebook-github-bot committed
Enable Half/BF16: abs, full, gelu, hardtanh, logit, neg, sign (#5856)
Summary: Pull Request resolved: #5856 Differential Revision: D63863399
1 parent d094b09 commit 94b7146

File tree: 15 files changed, +176 −13 lines changed
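Every kernel change below follows the same pattern: the dtype dispatch moves from a REAL or FLOAT switch macro to a variant that also covers Half and/or BFloat16, so the existing lambda body gets instantiated for the 16-bit float types too. A minimal sketch of how an ET_SWITCH-style dispatch behaves, using tag dispatch in place of the real macro expansion (the actual macros live in runtime/core/exec_aten/util/scalar_type_util.h and report InvalidArgument through the kernel context; every name below is illustrative):

// Illustrative stand-ins; the real types are exec_aten::Half and
// exec_aten::BFloat16, both convertible to and from float.
struct HalfTag {};
struct BF16Tag {};

enum class ScalarType { Float, Half, BFloat16 /* ... */ };

template <typename Fn>
void switch_realhbf16_sketch(ScalarType t, Fn&& fn) {
  switch (t) {
    case ScalarType::Float:
      fn(float{});    // lambda runs with CTYPE = float
      break;
    case ScalarType::Half:
      fn(HalfTag{});  // newly reachable: CTYPE = exec_aten::Half
      break;
    case ScalarType::BFloat16:
      fn(BF16Tag{});  // newly reachable: CTYPE = exec_aten::BFloat16
      break;
  }
}

// Usage mirrors the kernels, except the real macros bind CTYPE directly
// instead of passing a tag:
//   switch_realhbf16_sketch(in.scalar_type(), [&](auto tag) {
//     using CTYPE = decltype(tag);
//     ...
//   });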

kernels/portable/cpu/op_abs.cpp

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
     apply_unary_map_fn(
         [](const CTYPE val_in) {
           if (val_in < 0) {
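Nothing else in abs_out changes: apply_unary_map_fn drives the lambda over every element, so widening the switch is enough. A hedged sketch of that helper's contract, inferred from the call sites in these diffs (the real implementation is in kernels/portable/cpu/util/functional_util.h and its exact signature may differ):

#include <cstddef>

// Sketch: apply fn to each input element, writing the result to out.
// The kernels pass in.const_data_ptr<CTYPE>(), out.mutable_data_ptr<CTYPE>()
// and in.numel() for the last three arguments.
template <typename Fn, typename In, typename Out>
void apply_unary_map_fn_sketch(Fn fn, const In* in, Out* out, size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out[i] = fn(in[i]);  // e.g. the abs lambda: val_in < 0 ? -val_in : val_in
  }
}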

kernels/portable/cpu/op_full.cpp

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ Tensor& full_out(
   CTYPE_VAL val;
   utils::extract_scalar(fill_value, &val);

-  ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
     CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
     auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
     for (size_t i = 0; i < out.numel(); ++i) {
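The macro suffixes encode which dtype bucket a switch covers; full_out goes from REALHB to REALHBBF16 rather than REALHBF16 because it already handled Bool. The decoding below is an assumption inferred from the macro names and these diffs; the authoritative lists are in scalar_type_util.h:

// Assumed suffix convention (H = Half, B = Bool, BF16 = BFloat16):
//   REAL        all integral types + float + double
//   REALHBF16   REAL + Half + BFloat16            (abs, hardtanh, neg, sign)
//   REALHB      REAL + Half + Bool                (full, before this change)
//   REALHBBF16  REAL + Half + Bool + BFloat16     (full, logit input)
//   FLOATH      float + double + Half             (gelu)
//   FLOATHBF16  float + double + Half + BFloat16  (logit output)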

kernels/portable/cpu/op_gelu.cpp

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ Tensor& gelu_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

-  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
+  ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
     if (approximate == "tanh") {
       apply_unary_map_fn(
           [](const CTYPE x) {
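For gelu the dispatch widens from FLOAT to FLOATH, which by the suffix reading above adds Half. A hedged sketch of the per-element math the "tanh" branch computes, using the standard tanh-approximation constants; gelu_tanh_sketch is a name invented for illustration and the kernel's exact arithmetic may differ:

#include <cmath>

// tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Computing through double and casting back stands in for whatever precision
// the real kernel uses once CTYPE is bound to a 16-bit type.
template <typename CTYPE>
CTYPE gelu_tanh_sketch(CTYPE x) {
  const double kBeta = 0.7978845608028654;  // sqrt(2 / pi)
  const double xd = static_cast<double>(x);
  const double inner = kBeta * (xd + 0.044715 * xd * xd * xd);
  return static_cast<CTYPE>(0.5 * xd * (1.0 + std::tanh(inner)));
}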

kernels/portable/cpu/op_hardtanh.cpp

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ Tensor& hardtanh_out(

   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
+  ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
     CTYPE min_casted;
     ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
       CTYPE_MIN min_val;

kernels/portable/cpu/op_logit.cpp

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ Tensor& logit_out(

   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
-  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "logit.out", CTYPE_IN, [&] {
-    ET_SWITCH_FLOAT_TYPES(out_type, ctx, "logit.out", CTYPE_OUT, [&] {
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "logit.out", CTYPE_IN, [&] {
+    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "logit.out", CTYPE_OUT, [&] {
       apply_unary_map_fn(
           [eps](const CTYPE_IN val_in) {
             CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
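logit is the one kernel here with two nested switches, because its input and output dtypes can differ: the input switch widens to REALHBBF16 and the output switch to FLOATHBF16, with the cast to CTYPE_OUT visible in the context line. A sketch of the per-element computation assuming ATen semantics, where a supplied eps clamps the input to [eps, 1 - eps] before taking log(x / (1 - x)); logit_sketch and its "negative eps means no clamp" convention are assumptions, not the kernel's exact code:

#include <cmath>

template <typename CTYPE_OUT>
CTYPE_OUT logit_sketch(CTYPE_OUT xi, double eps) {
  if (eps >= 0) {  // assumed sentinel: a negative eps disables clamping
    const CTYPE_OUT lo = static_cast<CTYPE_OUT>(eps);
    const CTYPE_OUT hi = static_cast<CTYPE_OUT>(1.0 - eps);
    xi = xi < lo ? lo : (xi > hi ? hi : xi);
  }
  // logit(x) = log(x / (1 - x))
  return static_cast<CTYPE_OUT>(
      std::log(xi / (static_cast<CTYPE_OUT>(1) - xi)));
}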

kernels/portable/cpu/op_neg.cpp

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ Tensor& neg_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
     apply_unary_map_fn(
         [](const CTYPE val_in) { return static_cast<CTYPE>(-val_in); },
         in.const_data_ptr<CTYPE>(),

kernels/portable/cpu/op_sign.cpp

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ Tensor& sign_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   if (in.scalar_type() == exec_aten::ScalarType::Bool) {
     memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
   } else {
-    ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "sign.out", CTYPE, [&] {
+    ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "sign.out", CTYPE, [&] {
       apply_unary_map_fn(
           [](const CTYPE val_in) {
             if (std::isnan(val_in)) {

kernels/portable/cpu/util/math_util.h

Lines changed: 8 additions & 4 deletions
@@ -96,8 +96,10 @@ INT_T max_override(INT_T a, INT_T b) {

 template <
     typename T,
-    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
-        type = true>
+    typename std::enable_if<
+        std::is_same<T, exec_aten::Half>::value ||
+            std::is_same<T, exec_aten::BFloat16>::value,
+        bool>::type = true>
 T min_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {
@@ -116,8 +118,10 @@ T min_override(T a, T b) {

 template <
     typename T,
-    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
-        type = true>
+    typename std::enable_if<
+        std::is_same<T, exec_aten::Half>::value ||
+            std::is_same<T, exec_aten::BFloat16>::value,
+        bool>::type = true>
 T max_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {
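These are the only non-kernel changes: the enable_if constraints on min_override/max_override now accept BFloat16 alongside Half, so ops such as hardtanh can clamp 16-bit values. Both formats convert losslessly to float, which is why the bodies compare and isnan-check through float. A sketch of the overload's shape with one plausible NaN-propagating body (the diff shows only the isnan check, so the returns below are assumptions):

#include <cmath>

template <typename T>  // imagine this constrained to Half/BFloat16 as above
T min_override_sketch(T a, T b) {
  const float float_a = static_cast<float>(a);
  const float float_b = static_cast<float>(b);
  if (std::isnan(float_a)) {
    return a;  // assumed NaN propagation
  }
  if (std::isnan(float_b)) {
    return b;  // assumed NaN propagation
  }
  return float_a < float_b ? a : b;
}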

kernels/test/op_abs_test.cpp

Lines changed: 36 additions & 0 deletions
@@ -24,8 +24,44 @@ class OpAbsTest : public OperatorTest {
   Tensor& op_abs_out(const Tensor& self, Tensor& out) {
     return torch::executor::aten::abs_outf(context_, self, out);
   }
+
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor in = tf.make({2, 3}, {-3, -2, -1, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, 1, 2});
+
+    Tensor ret = op_abs_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
+
+  template <>
+  void test_dtype<ScalarType::Byte>() {
+    TensorFactory<ScalarType::Byte> tf;
+
+    Tensor in = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
+
+    Tensor ret = op_abs_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
 };

+TEST_F(OpAbsTest, AllRealHBF16Input) {
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
+  test_dtype<ScalarType::INPUT_DTYPE>();
+
+  ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
+#undef TEST_KERNEL
+}
+
 TEST_F(OpAbsTest, SanityCheck) {
   TensorFactory<ScalarType::Float> tf;

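ET_FORALL_REALHBF16_TYPES is an X-macro: it expands the caller-supplied TEST_KERNEL once per (C type, ScalarType) pair, so the TEST_F above stamps out one test_dtype<>() call per dtype, with the Byte specialization supplying inputs that stay in range for an unsigned type. A hedged sketch of the pattern (this pair list is an assumption; the authoritative macro is in scalar_type_util.h):

// X-macro sketch: each _(ctype, Dtype) entry is expanded with the macro the
// caller passes in, e.g. TEST_KERNEL above.
#define FORALL_REALHBF16_SKETCH(_) \
  _(uint8_t, Byte)                 \
  _(int8_t, Char)                  \
  _(int16_t, Short)                \
  _(int32_t, Int)                  \
  _(int64_t, Long)                 \
  _(float, Float)                  \
  _(double, Double)                \
  _(exec_aten::Half, Half)         \
  _(exec_aten::BFloat16, BFloat16)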
kernels/test/op_full_test.cpp

Lines changed: 23 additions & 0 deletions
@@ -122,3 +122,26 @@ TEST_F(OpFullOutTest, ZeroDim) {
   op_full_out(sizes, true, out);
   EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
 }
+
+TEST_F(OpFullOutTest, BFloat16Support) {
+  TensorFactory<ScalarType::BFloat16> tf;
+
+  std::vector<int64_t> sizes_int64_t_vec = {2, 3};
+  std::vector<int32_t> sizes_in32_t_vec = {2, 3};
+  auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
+
+  // Boolean Scalar
+  Tensor out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, true, out);
+  EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
+
+  // Integral Scalar
+  out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, 1, out);
+  EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
+
+  // Floating Point Scalar
+  out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, 3.1415926535, out);
+  EXPECT_TENSOR_EQ(out, tf.full(sizes_in32_t_vec, 3.1415926535));
+}
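One detail worth noting in the final expectation: both sides round through BFloat16, which keeps only 8 significand bits, so 3.1415926535 stores as 0x4049 = 3.140625 on the kernel path and the TensorFactory path alike, and EXPECT_TENSOR_EQ passes even though neither tensor holds π exactly.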
