Skip to content

Commit ee32848

Browse files
swolchokfacebook-github-bot
authored andcommitted
Support bf16 for isinf/isnan (pytorch#5690)
Summary: Pull Request resolved: pytorch#5690 These two used the same pattern function. ghstack-source-id: 245578277 exported-using-ghexport Reviewed By: manuelcandales Differential Revision: D63474070 fbshipit-source-id: eee8a3844757efdea58788abe07cb3694358441e
1 parent 6a27589 commit ee32848

File tree

3 files changed

+35
-109
lines changed

3 files changed

+35
-109
lines changed

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Tensor& unary_ufunc_realhb_to_bool(
4343

4444
const auto in_type = in.scalar_type();
4545

46-
ET_SWITCH_REALHB_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
46+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
4747
apply_unary_map_fn(
4848
[fn](const CTYPE_IN val_in) { return fn(val_in); },
4949
in.const_data_ptr<CTYPE_IN>(),

kernels/test/op_isinf_test.cpp

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,66 +25,29 @@ class OpIsInfTest : public OperatorTest {
2525
Tensor& op_isinf_out(const Tensor& self, Tensor& out) {
2626
return torch::executor::aten::isinf_outf(context_, self, out);
2727
}
28-
};
29-
30-
TEST_F(OpIsInfTest, SanityCheckFloat) {
31-
TensorFactory<ScalarType::Float> tf;
32-
TensorFactory<ScalarType::Bool> tfb;
3328

34-
Tensor in = tf.make(
35-
{1, 5}, {-1.0, 0.0, 1.0, NAN, std::numeric_limits<float>::infinity()});
36-
Tensor out = tfb.zeros({1, 5});
37-
Tensor expected = tfb.make({1, 5}, {false, false, false, false, true});
29+
template <ScalarType DTYPE>
30+
void test_sanity_check() {
31+
TensorFactory<DTYPE> tf;
32+
TensorFactory<ScalarType::Bool> tfb;
3833

39-
Tensor ret = op_isinf_out(in, out);
34+
using CTYPE = typename TensorFactory<DTYPE>::ctype;
35+
Tensor in = tf.make(
36+
{1, 5}, {-1, 0, 1, NAN, std::numeric_limits<CTYPE>::infinity()});
37+
Tensor out = tfb.zeros({1, 5});
38+
Tensor expected = tfb.make({1, 5}, {false, false, false, false, true});
4039

41-
EXPECT_TENSOR_EQ(out, ret);
42-
EXPECT_TENSOR_EQ(out, expected);
43-
}
40+
Tensor ret = op_isinf_out(in, out);
4441

45-
TEST_F(OpIsInfTest, SanityCheckHalf) {
46-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
47-
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
42+
EXPECT_TENSOR_EQ(out, ret);
43+
EXPECT_TENSOR_EQ(out, expected);
4844
}
49-
TensorFactory<ScalarType::Half> tf;
50-
TensorFactory<ScalarType::Bool> tfb;
51-
52-
Tensor in = tf.make(
53-
{1, 5}, {-1.0, 0.0, 1.0, NAN, std::numeric_limits<float>::infinity()});
54-
Tensor out = tfb.zeros({1, 5});
55-
Tensor expected = tfb.make({1, 5}, {false, false, false, false, true});
56-
57-
Tensor ret = op_isinf_out(in, out);
58-
59-
EXPECT_TENSOR_EQ(out, ret);
60-
EXPECT_TENSOR_EQ(out, expected);
61-
}
62-
63-
TEST_F(OpIsInfTest, SanityCheckByte) {
64-
TensorFactory<ScalarType::Byte> tf;
65-
TensorFactory<ScalarType::Bool> tfb;
66-
67-
Tensor in = tf.make({1, 5}, {1, 2, 3, 4, 5});
68-
Tensor out = tfb.zeros({1, 5});
69-
Tensor expected = tfb.make({1, 5}, {false, false, false, false, false});
70-
71-
Tensor ret = op_isinf_out(in, out);
72-
73-
EXPECT_TENSOR_EQ(out, ret);
74-
EXPECT_TENSOR_EQ(out, expected);
75-
}
76-
77-
TEST_F(OpIsInfTest, SanityCheckBool) {
78-
TensorFactory<ScalarType::Bool> tfb;
79-
80-
Tensor in = tfb.make({1, 5}, {true, false, true, true, false});
81-
Tensor out = tfb.zeros({1, 5});
82-
Tensor expected = tfb.make({1, 5}, {false, false, false, false, false});
83-
84-
Tensor ret = op_isinf_out(in, out);
45+
};
8546

86-
EXPECT_TENSOR_EQ(out, ret);
87-
EXPECT_TENSOR_EQ(out, expected);
47+
TEST_F(OpIsInfTest, SanityCheck) {
48+
#define TEST_ENTRY(ctype, dtype) test_sanity_check<ScalarType::dtype>();
49+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
50+
#undef TEST_ENTRY
8851
}
8952

9053
TEST_F(OpIsInfTest, SanityCheckOutDtype) {

kernels/test/op_isnan_test.cpp

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,66 +25,29 @@ class OpIsNanTest : public OperatorTest {
2525
Tensor& op_isnan_out(const Tensor& self, Tensor& out) {
2626
return torch::executor::aten::isnan_outf(context_, self, out);
2727
}
28-
};
29-
30-
TEST_F(OpIsNanTest, SanityCheckFloat) {
31-
TensorFactory<ScalarType::Float> tf;
32-
TensorFactory<ScalarType::Bool> tfb;
3328

34-
Tensor in = tf.make(
35-
{1, 5}, {-1.0, 0.0, 1.0, NAN, std::numeric_limits<float>::infinity()});
36-
Tensor out = tfb.zeros({1, 5});
37-
Tensor expected = tfb.make({1, 5}, {false, false, false, true, false});
29+
template <ScalarType DTYPE>
30+
void test_sanity_check() {
31+
TensorFactory<DTYPE> tf;
32+
TensorFactory<ScalarType::Bool> tfb;
3833

39-
Tensor ret = op_isnan_out(in, out);
34+
using CTYPE = typename TensorFactory<DTYPE>::ctype;
35+
Tensor in = tf.make(
36+
{1, 5}, {-1, 0, 1, NAN, std::numeric_limits<CTYPE>::infinity()});
37+
Tensor out = tfb.zeros({1, 5});
38+
Tensor expected = tfb.make({1, 5}, {false, false, false, true, false});
4039

41-
EXPECT_TENSOR_EQ(out, ret);
42-
EXPECT_TENSOR_EQ(out, expected);
43-
}
40+
Tensor ret = op_isnan_out(in, out);
4441

45-
TEST_F(OpIsNanTest, SanityCheckHalf) {
46-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
47-
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
42+
EXPECT_TENSOR_EQ(out, ret);
43+
EXPECT_TENSOR_EQ(out, expected);
4844
}
49-
TensorFactory<ScalarType::Float> tf;
50-
TensorFactory<ScalarType::Bool> tfb;
51-
52-
Tensor in = tf.make(
53-
{1, 5}, {-1.0, 0.0, 1.0, NAN, std::numeric_limits<float>::infinity()});
54-
Tensor out = tfb.zeros({1, 5});
55-
Tensor expected = tfb.make({1, 5}, {false, false, false, true, false});
56-
57-
Tensor ret = op_isnan_out(in, out);
58-
59-
EXPECT_TENSOR_EQ(out, ret);
60-
EXPECT_TENSOR_EQ(out, expected);
61-
}
62-
63-
TEST_F(OpIsNanTest, SanityCheckByte) {
64-
TensorFactory<ScalarType::Byte> tf;
65-
TensorFactory<ScalarType::Bool> tfb;
66-
67-
Tensor in = tf.make({1, 5}, {1, 2, 3, 4, 5});
68-
Tensor out = tfb.zeros({1, 5});
69-
Tensor expected = tfb.make({1, 5}, {false, false, false, false, false});
70-
71-
Tensor ret = op_isnan_out(in, out);
72-
73-
EXPECT_TENSOR_EQ(out, ret);
74-
EXPECT_TENSOR_EQ(out, expected);
75-
}
76-
77-
TEST_F(OpIsNanTest, SanityCheckBool) {
78-
TensorFactory<ScalarType::Bool> tfb;
79-
80-
Tensor in = tfb.make({1, 5}, {true, false, true, true, false});
81-
Tensor out = tfb.zeros({1, 5});
82-
Tensor expected = tfb.make({1, 5}, {false, false, false, false, false});
83-
84-
Tensor ret = op_isnan_out(in, out);
45+
};
8546

86-
EXPECT_TENSOR_EQ(out, ret);
87-
EXPECT_TENSOR_EQ(out, expected);
47+
TEST_F(OpIsNanTest, SanityCheck) {
48+
#define TEST_ENTRY(ctype, dtype) test_sanity_check<ScalarType::dtype>();
49+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
50+
#undef TEST_ENTRY
8851
}
8952

9053
TEST_F(OpIsNanTest, SanityCheckOutDtype) {

0 commit comments

Comments
 (0)