Commit c9b96dd

manuelcandales authored and facebook-github-bot committed
Enable Half & Bfloat16: op_eq & op_full
Differential Revision: D63863399
1 parent 13408b9 commit c9b96dd

4 files changed: +64, -41 lines
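
All four files follow one pattern: the older ET_SWITCH_REAL_TYPES_AND(Bool, ...) and ET_SWITCH_REALHB_TYPES dispatch macros are replaced with ET_SWITCH_REALHBBF16_TYPES which, per the commit title, widens the covered dtypes to include Half and BFloat16, and the tests are extended to exercise the new types.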

kernels/portable/cpu/op_eq.cpp

Lines changed: 39 additions & 39 deletions
@@ -37,28 +37,27 @@ Tensor& eq_tensor_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted == b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
+  constexpr auto name = "eq.Tensor_out";
+
+  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+      using CTYPE_IN =
+          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
+      ET_DCHECK(
+          CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
+      ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+        apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+            [](const CTYPE_A val_a, const CTYPE_B val_b) {
+              CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+              CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+              bool value = a_casted == b_casted;
+              return static_cast<CTYPE_OUT>(value);
+            },
+            a,
+            b,
+            out);
+      });
+    });
   });

   return out;
@@ -86,27 +85,28 @@ Tensor& eq_scalar_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
+  constexpr auto name = "eq.Scalar_out";
+
+  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
       using CTYPE_IN =
           typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
       ET_DCHECK(
           CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
-            CTYPE_B val_b = 0;
-            utils::extract_scalar(b, &val_b);
-            apply_unary_map_fn(
-                [val_b](const CTYPE_A val_a) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  bool value = a_casted == b_casted;
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a.const_data_ptr<CTYPE_A>(),
-                out.mutable_data_ptr<CTYPE_OUT>(),
-                out.numel());
-          });
+      ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+        CTYPE_B val_b = 0;
+        utils::extract_scalar(b, &val_b);
+        apply_unary_map_fn(
+            [val_b](const CTYPE_A val_a) {
+              CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+              CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+              bool value = a_casted == b_casted;
+              return static_cast<CTYPE_OUT>(value);
+            },
+            a.const_data_ptr<CTYPE_A>(),
+            out.mutable_data_ptr<CTYPE_OUT>(),
+            out.numel());
+      });
     });
   });
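
Besides widening the dtype coverage, this file hoists the operator name string into a constexpr variable instead of repeating it in every switch; that hoist also fixes the old copy-paste label in eq_tensor_out, which reported itself as "eq.Scalar_out" rather than "eq.Tensor_out".

Below is a standalone sketch of the promote-then-compare pattern both overloads use. It is illustrative only, not ExecuTorch code: std::common_type stands in for torch::executor::promote_types, and built-in float/double stand in for Half/BFloat16.

#include <iostream>
#include <type_traits>

// Promote both operands to a common type, compare there, then cast the
// boolean result to the requested output type.
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
CTYPE_OUT eq_element(CTYPE_A val_a, CTYPE_B val_b) {
  // Stand-in for promote_types<CTYPE_A, CTYPE_B>::type.
  using CTYPE_IN = typename std::common_type<CTYPE_A, CTYPE_B>::type;
  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
  return static_cast<CTYPE_OUT>(a_casted == b_casted);
}

int main() {
  // 0.1f widened to double is not exactly the double 0.1, so the
  // promoted comparison is false; 0.5 is exact at both widths.
  std::cout << eq_element<float, double, bool>(0.1f, 0.1) << '\n';  // 0
  std::cout << eq_element<float, double, bool>(0.5f, 0.5) << '\n';  // 1
}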

kernels/portable/cpu/op_full.cpp

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ Tensor& full_out(
     CTYPE_VAL val;
     utils::extract_scalar(fill_value, &val);

-    ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
       auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
       for (size_t i = 0; i < out.numel(); ++i) {
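
For context, here is a standalone sketch of what the ET_SWITCH_*_TYPES dispatch in full_out amounts to: a runtime-dtype switch into a templated fill loop that casts the scalar once. The names fill_typed and full_out_sketch are illustrative, not ExecuTorch APIs, and the two-entry dtype enum is a deliberate simplification.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class ScalarType { Float, Int };

// Typed inner loop: cast the fill value once, then write every element.
template <typename CTYPE_OUT>
void fill_typed(void* data, std::size_t numel, double val) {
  CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
  auto* data_out = static_cast<CTYPE_OUT*>(data);
  for (std::size_t i = 0; i < numel; ++i) {
    data_out[i] = val_casted;
  }
}

// The switch macro generates a runtime-dtype -> template dispatch much
// like this hand-written one; enabling BFloat16 amounts to the macro
// emitting one more case/type pair.
void full_out_sketch(
    ScalarType out_type, void* data, std::size_t numel, double val) {
  switch (out_type) {
    case ScalarType::Float:
      fill_typed<float>(data, numel, val);
      break;
    case ScalarType::Int:
      fill_typed<std::int32_t>(data, numel, val);
      break;
  }
}

int main() {
  std::vector<float> buf(6);
  full_out_sketch(ScalarType::Float, buf.data(), buf.size(), 3.14);
  std::cout << buf[0] << '\n';  // prints 3.14
}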

kernels/test/op_eq_test.cpp

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ class OpEqScalarOutTest : public OperatorTest {

 TEST_F(OpEqScalarOutTest, AllRealInputBoolOutputSupport) {
 #define TEST_ENTRY(ctype, dtype) test_eq_scalar_out<ScalarType::dtype>();
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }

kernels/test/op_full_test.cpp

Lines changed: 23 additions & 0 deletions
@@ -122,3 +122,26 @@ TEST_F(OpFullOutTest, ZeroDim) {
   op_full_out(sizes, true, out);
   EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
 }
+
+TEST_F(OpFullOutTest, BFloat16Support) {
+  TensorFactory<ScalarType::BFloat16> tf;
+
+  std::vector<int64_t> sizes_int64_t_vec = {2, 3};
+  std::vector<int32_t> sizes_in32_t_vec = {2, 3};
+  auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
+
+  // Boolean Scalar
+  Tensor out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, true, out);
+  EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
+
+  // Integral Scalar
+  out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, 1, out);
+  EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
+
+  // Floating Point Scalar
+  out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, 3.1415926535, out);
+  EXPECT_TENSOR_EQ(out, tf.full(sizes_in32_t_vec, 3.1415926535));
+}
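
A note on the floating-point case above: exact equality is still sound for bfloat16 because op_full_out and tf.full both cast the same double 3.1415926535 through BFloat16, so EXPECT_TENSOR_EQ compares identical rounded values. The standalone sketch below shows how much precision that cast keeps, using one common round-to-nearest-even conversion; the hand-rolled helpers float_to_bfloat16 and bfloat16_to_float are assumptions for illustration, and ExecuTorch's actual conversion may differ in detail.

#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 keeps only the top 16 bits of a float (same exponent range,
// 7 fraction bits), so roughly 3 significant decimal digits survive.
uint16_t float_to_bfloat16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round to nearest even before truncating the low 16 bits.
  uint32_t rounding_bias = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

float bfloat16_to_float(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float pi = 3.1415926535f;
  float pi_bf16 = bfloat16_to_float(float_to_bfloat16(pi));
  std::cout << pi_bf16 << '\n';  // prints 3.14062 (bfloat16 value 3.140625)
}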
