From 9162daabe482042fffd55abf952bb5920c30c19d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 23 Jan 2025 10:18:54 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_gelu.cpp | 2 +- kernels/test/op_gelu_test.cpp | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/kernels/portable/cpu/op_gelu.cpp b/kernels/portable/cpu/op_gelu.cpp index db5d9cbfe71..14468923a5b 100644 --- a/kernels/portable/cpu/op_gelu.cpp +++ b/kernels/portable/cpu/op_gelu.cpp @@ -37,7 +37,7 @@ Tensor& gelu_out( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() { + ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, "gelu.out", CTYPE, [&]() { if (approximate == "tanh") { apply_unary_map_fn( [](const CTYPE x) { diff --git a/kernels/test/op_gelu_test.cpp b/kernels/test/op_gelu_test.cpp index 7155bfb1b7b..3334b9acffb 100644 --- a/kernels/test/op_gelu_test.cpp +++ b/kernels/test/op_gelu_test.cpp @@ -70,6 +70,14 @@ TEST_F(OpGeluTest, FloatTensors) { test_gelu_execution(); } +TEST_F(OpGeluTest, HalfTensors) { + test_gelu_execution(); +} + +TEST_F(OpGeluTest, BFloat16Tensors) { + test_gelu_execution(); +} + TEST_F(OpGeluTest, DoubleTensors) { if (!SupportedFeatures::get()->op_gelu_dtype_double) { GTEST_SKIP();