diff --git a/kernels/portable/cpu/op_hardtanh.cpp b/kernels/portable/cpu/op_hardtanh.cpp index e86edab76b4..56ac77b37fb 100644 --- a/kernels/portable/cpu/op_hardtanh.cpp +++ b/kernels/portable/cpu/op_hardtanh.cpp @@ -46,7 +46,7 @@ Tensor& hardtanh_out( ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out); - ET_SWITCH_REAL_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() { + ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() { CTYPE min_casted; ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() { CTYPE_MIN min_val; diff --git a/kernels/portable/cpu/util/math_util.h b/kernels/portable/cpu/util/math_util.h index 05935fff389..e6cee5eec77 100644 --- a/kernels/portable/cpu/util/math_util.h +++ b/kernels/portable/cpu/util/math_util.h @@ -96,8 +96,10 @@ INT_T max_override(INT_T a, INT_T b) { template < typename T, - typename std::enable_if::value, bool>:: - type = true> + typename std::enable_if_t< + std::is_same_v || + std::is_same_v, + bool> = true> T min_override(T a, T b) { const auto float_a = static_cast(a); if (std::isnan(float_a)) { @@ -116,8 +118,10 @@ T min_override(T a, T b) { template < typename T, - typename std::enable_if::value, bool>:: - type = true> + typename std::enable_if_t< + std::is_same_v || + std::is_same_v, + bool> = true> T max_override(T a, T b) { const auto float_a = static_cast(a); if (std::isnan(float_a)) { diff --git a/kernels/test/op_hardtanh_test.cpp b/kernels/test/op_hardtanh_test.cpp index bf790e432f9..ba60a3e39f6 100644 --- a/kernels/test/op_hardtanh_test.cpp +++ b/kernels/test/op_hardtanh_test.cpp @@ -30,15 +30,31 @@ class OpHardTanhTest : public OperatorTest { return torch::executor::aten::hardtanh_outf( context_, self, min_val, max_val, out); } -}; -TEST_F(OpHardTanhTest, SanityCheck) { - TensorFactory tf; - Tensor in = tf.ones({2, 2}); - Tensor out = tf.zeros({2, 2}); + template + void test_dtype() { + TensorFactory tf; + CTYPE lowest_test_element; + CTYPE lower_bound; + if constexpr (std::numeric_limits::is_signed) { + lowest_test_element = -3; + lower_bound = -2; + } else { + lowest_test_element = 0; + lower_bound = 0; + } + Tensor in = tf.make({2, 2}, {lowest_test_element, 0, 1, 100}); + Tensor out = tf.zeros({2, 2}); + + Tensor ret = op_hardtanh_out(in, lower_bound, 2, out); - Tensor ret = op_hardtanh_out(in, -2, 2, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {lower_bound, 0, 1, 2})); + } +}; - EXPECT_TENSOR_EQ(out, ret); - EXPECT_TENSOR_EQ(out, tf.ones({2, 2})); +TEST_F(OpHardTanhTest, SanityCheck) { +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY }