diff --git a/kernels/portable/cpu/op_abs.cpp b/kernels/portable/cpu/op_abs.cpp
index 5d0fcbaaa45..df530bcd1fa 100644
--- a/kernels/portable/cpu/op_abs.cpp
+++ b/kernels/portable/cpu/op_abs.cpp
@@ -31,7 +31,7 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
     apply_unary_map_fn(
         [](const CTYPE val_in) {
           if (val_in < 0) {
diff --git a/kernels/portable/cpu/op_full.cpp b/kernels/portable/cpu/op_full.cpp
index 74b9657204f..30ca7b825fb 100644
--- a/kernels/portable/cpu/op_full.cpp
+++ b/kernels/portable/cpu/op_full.cpp
@@ -40,7 +40,7 @@ Tensor& full_out(
     CTYPE_VAL val;
     utils::extract_scalar(fill_value, &val);
 
-    ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
       auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
       for (size_t i = 0; i < out.numel(); ++i) {
diff --git a/kernels/portable/cpu/op_gelu.cpp b/kernels/portable/cpu/op_gelu.cpp
index db5d9cbfe71..37840af31ce 100644
--- a/kernels/portable/cpu/op_gelu.cpp
+++ b/kernels/portable/cpu/op_gelu.cpp
@@ -37,7 +37,7 @@ Tensor& gelu_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
+  ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() {
     if (approximate == "tanh") {
       apply_unary_map_fn(
           [](const CTYPE x) {
diff --git a/kernels/portable/cpu/op_hardtanh.cpp b/kernels/portable/cpu/op_hardtanh.cpp
index e86edab76b4..56ac77b37fb 100644
--- a/kernels/portable/cpu/op_hardtanh.cpp
+++ b/kernels/portable/cpu/op_hardtanh.cpp
@@ -46,7 +46,7 @@ Tensor& hardtanh_out(
 
   ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
+  ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
     CTYPE min_casted;
     ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
       CTYPE_MIN min_val;
diff --git a/kernels/portable/cpu/op_logit.cpp b/kernels/portable/cpu/op_logit.cpp
index faba847e844..7985b2aa400 100644
--- a/kernels/portable/cpu/op_logit.cpp
+++ b/kernels/portable/cpu/op_logit.cpp
@@ -35,8 +35,8 @@ Tensor& logit_out(
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
 
-  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "logit.out", CTYPE_IN, [&] {
-    ET_SWITCH_FLOAT_TYPES(out_type, ctx, "logit.out", CTYPE_OUT, [&] {
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "logit.out", CTYPE_IN, [&] {
+    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "logit.out", CTYPE_OUT, [&] {
       apply_unary_map_fn(
           [eps](const CTYPE_IN val_in) {
             CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
diff --git a/kernels/portable/cpu/op_neg.cpp b/kernels/portable/cpu/op_neg.cpp
index a4e6a8ad256..339bfd8a445 100644
--- a/kernels/portable/cpu/op_neg.cpp
+++ b/kernels/portable/cpu/op_neg.cpp
@@ -33,7 +33,7 @@ Tensor& neg_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "neg.out", CTYPE, [&] {
     apply_unary_map_fn(
         [](const CTYPE val_in) { return static_cast<CTYPE>(-val_in); },
         in.const_data_ptr<CTYPE>(),
diff --git a/kernels/portable/cpu/op_sign.cpp b/kernels/portable/cpu/op_sign.cpp
index af3225d4779..a0038114613 100644
--- a/kernels/portable/cpu/op_sign.cpp
+++ b/kernels/portable/cpu/op_sign.cpp
@@ -39,7 +39,7 @@ Tensor& sign_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   if (in.scalar_type() == exec_aten::ScalarType::Bool) {
     memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
   } else {
-    ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "sign.out", CTYPE, [&] {
+    ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "sign.out", CTYPE, [&] {
       apply_unary_map_fn(
           [](const CTYPE val_in) {
             if (std::isnan(val_in)) {
diff --git a/kernels/portable/cpu/util/math_util.h b/kernels/portable/cpu/util/math_util.h
index 05935fff389..16ded0631b1 100644
--- a/kernels/portable/cpu/util/math_util.h
+++ b/kernels/portable/cpu/util/math_util.h
@@ -96,8 +96,10 @@ INT_T max_override(INT_T a, INT_T b) {
 
 template <
     typename T,
-    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
-        type = true>
+    typename std::enable_if<
+        std::is_same<T, exec_aten::Half>::value ||
+            std::is_same<T, exec_aten::BFloat16>::value,
+        bool>::type = true>
 T min_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {
@@ -116,8 +118,10 @@
 
 template <
     typename T,
-    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
-        type = true>
+    typename std::enable_if<
+        std::is_same<T, exec_aten::Half>::value ||
+            std::is_same<T, exec_aten::BFloat16>::value,
+        bool>::type = true>
 T max_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {
diff --git a/kernels/test/op_abs_test.cpp b/kernels/test/op_abs_test.cpp
index f596d586d90..925967e953f 100644
--- a/kernels/test/op_abs_test.cpp
+++ b/kernels/test/op_abs_test.cpp
@@ -24,8 +24,44 @@ class OpAbsTest : public OperatorTest {
   Tensor& op_abs_out(const Tensor& self, Tensor& out) {
     return torch::executor::aten::abs_outf(context_, self, out);
   }
+
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor in = tf.make({2, 3}, {-3, -2, -1, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, 1, 2});
+
+    Tensor ret = op_abs_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
+
+  template <>
+  void test_dtype<ScalarType::Byte>() {
+    TensorFactory<ScalarType::Byte> tf;
+
+    Tensor in = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
+
+    Tensor ret = op_abs_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
 };
 
+TEST_F(OpAbsTest, AllRealHBF16Input) {
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
+  test_dtype<ScalarType::INPUT_DTYPE>();
+
+  ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
+#undef TEST_KERNEL
+}
+
 TEST_F(OpAbsTest, SanityCheck) {
   TensorFactory<ScalarType::Float> tf;
 
diff --git a/kernels/test/op_full_test.cpp b/kernels/test/op_full_test.cpp
index 09885ddd991..72bb60638e9 100644
--- a/kernels/test/op_full_test.cpp
+++ b/kernels/test/op_full_test.cpp
@@ -122,3 +122,26 @@ TEST_F(OpFullOutTest, ZeroDim) {
   op_full_out(sizes, true, out);
   EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
 }
+
+TEST_F(OpFullOutTest, BFloat16Support) {
+  TensorFactory<ScalarType::BFloat16> tf;
+
+  std::vector<int64_t> sizes_int64_t_vec = {2, 3};
+  std::vector<int32_t> sizes_in32_t_vec = {2, 3};
+  auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
+
+  // Boolean Scalar
+  Tensor out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, true, out);
+  EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
+
+  // Integral Scalar
+  out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, 1, out);
+  EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
+
+  // Floating Point Scalar
+  out = tf.zeros(sizes_in32_t_vec);
+  op_full_out(sizes, 3.1415926535, out);
+  EXPECT_TENSOR_EQ(out, tf.full(sizes_in32_t_vec, 3.1415926535));
+}
diff --git a/kernels/test/op_gelu_test.cpp b/kernels/test/op_gelu_test.cpp
index 7155bfb1b7b..b57d1e4b46d 100644
--- a/kernels/test/op_gelu_test.cpp
+++ b/kernels/test/op_gelu_test.cpp
@@ -66,6 +66,10 @@ class OpGeluTest : public OperatorTest {
   }
 };
 
+TEST_F(OpGeluTest, HalfTensors) {
+  test_gelu_execution<ScalarType::Half>();
+}
+
 TEST_F(OpGeluTest, FloatTensors) {
   test_gelu_execution<ScalarType::Float>();
 }
diff --git a/kernels/test/op_hardtanh_test.cpp b/kernels/test/op_hardtanh_test.cpp
index bf790e432f9..d906f28b19e 100644
--- a/kernels/test/op_hardtanh_test.cpp
+++ b/kernels/test/op_hardtanh_test.cpp
@@ -30,8 +30,32 @@ class OpHardTanhTest : public OperatorTest {
     return torch::executor::aten::hardtanh_outf(
         context_, self, min_val, max_val, out);
   }
+
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor in = tf.make({2, 3}, {0, 1, 2, 3, 4, 5});
+    Scalar min_val = 1;
+    Scalar max_val = 4;
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {1, 1, 2, 3, 4, 4});
+
+    Tensor ret = op_hardtanh_out(in, min_val, max_val, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
 };
 
+TEST_F(OpHardTanhTest, AllRealHBF16Input) {
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
+  test_dtype<ScalarType::INPUT_DTYPE>();
+
+  ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
+#undef TEST_KERNEL
+}
+
 TEST_F(OpHardTanhTest, SanityCheck) {
   TensorFactory<ScalarType::Float> tf;
   Tensor in = tf.ones({2, 2});
diff --git a/kernels/test/op_logit_test.cpp b/kernels/test/op_logit_test.cpp
index fda3fae4e88..4adb5645d75 100644
--- a/kernels/test/op_logit_test.cpp
+++ b/kernels/test/op_logit_test.cpp
@@ -100,7 +100,7 @@ void OpLogitOutTest::
 TEST_F(OpLogitOutTest, AllRealInputFloatOutputSupport) {
 #define TEST_ENTRY(ctype, dtype) \
   test_integer_logit_out<ScalarType::dtype, ScalarType::Float>();
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
diff --git a/kernels/test/op_neg_test.cpp b/kernels/test/op_neg_test.cpp
index 09bbb8b6af1..5d39e2fece6 100644
--- a/kernels/test/op_neg_test.cpp
+++ b/kernels/test/op_neg_test.cpp
@@ -24,8 +24,44 @@ class OpNegTest : public OperatorTest {
   Tensor& op_neg_out(const Tensor& self, Tensor& out) {
     return torch::executor::aten::neg_outf(context_, self, out);
   }
+
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor in = tf.make({2, 3}, {-3, -2, -1, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, -1, -2});
+
+    Tensor ret = op_neg_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
+
+  template <>
+  void test_dtype<ScalarType::Byte>() {
+    TensorFactory<ScalarType::Byte> tf;
+
+    Tensor in = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {3, 2, 1, 0, 255, 254});
+
+    Tensor ret = op_neg_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
 };
 
+TEST_F(OpNegTest, AllRealHBF16Input) {
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
+  test_dtype<ScalarType::INPUT_DTYPE>();
+
+  ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
+#undef TEST_KERNEL
+}
+
 TEST_F(OpNegTest, SanityCheck) {
   TensorFactory<ScalarType::Float> tf;
 
diff --git a/kernels/test/op_sign_test.cpp b/kernels/test/op_sign_test.cpp
index e411675b19b..de3dc1002bc 100644
--- a/kernels/test/op_sign_test.cpp
+++ b/kernels/test/op_sign_test.cpp
@@ -25,8 +25,44 @@ class OpSignTest : public OperatorTest {
   Tensor& op_sign_out(const Tensor& self, Tensor& out) {
     return torch::executor::aten::sign_outf(context_, self, out);
   }
+
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor in = tf.make({2, 3}, {-3, -2, -1, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {-1, -1, -1, 0, 1, 1});
+
+    Tensor ret = op_sign_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
+
+  template <>
+  void test_dtype<ScalarType::Byte>() {
+    TensorFactory<ScalarType::Byte> tf;
+
+    Tensor in = tf.make({2, 3}, {253, 254, 255, 0, 1, 2});
+    Tensor out = tf.zeros({2, 3});
+    Tensor expected = tf.make({2, 3}, {1, 1, 1, 0, 1, 1});
+
+    Tensor ret = op_sign_out(in, out);
+
+    EXPECT_TENSOR_EQ(out, ret);
+    EXPECT_TENSOR_EQ(out, expected);
+  }
 };
 
+TEST_F(OpSignTest, AllRealHBF16Input) {
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE) \
+  test_dtype<ScalarType::INPUT_DTYPE>();
+
+  ET_FORALL_REALHBF16_TYPES(TEST_KERNEL);
+#undef TEST_KERNEL
+}
+
 TEST_F(OpSignTest, ETSanityCheckFloat) {
   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen returns 0 on NAN input";