diff --git a/kernels/portable/cpu/op_log_softmax.cpp b/kernels/portable/cpu/op_log_softmax.cpp
index 096fb4ff9c1..cbe5f2139fd 100644
--- a/kernels/portable/cpu/op_log_softmax.cpp
+++ b/kernels/portable/cpu/op_log_softmax.cpp
@@ -42,7 +42,7 @@ Tensor& log_softmax_out(
   // Adjust for negative dim
   dim = dim < 0 ? dim + nonzero_dim(in) : dim;
 
-  ET_SWITCH_FLOAT_TYPES(
+  ET_SWITCH_FLOATHBF16_TYPES(
       in.scalar_type(), ctx, "_log_softmax.out", CTYPE, [&]() {
         const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
         CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
diff --git a/kernels/test/op_log_softmax_test.cpp b/kernels/test/op_log_softmax_test.cpp
index 6efaa1c08ed..ca1f5e7ae64 100644
--- a/kernels/test/op_log_softmax_test.cpp
+++ b/kernels/test/op_log_softmax_test.cpp
@@ -62,7 +62,15 @@ class OpLogSoftmaxOutTest : public OperatorTest {
     });
     // clang-format on
 
-    EXPECT_TENSOR_CLOSE(out, expected);
+    if constexpr (DTYPE == ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out,
+          expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out, expected);
+    }
   }
 };
 
@@ -88,11 +96,9 @@ TEST_F(OpLogSoftmaxOutTest, AllDtypesSupported) {
     GTEST_SKIP() << "This kernel does not support dtype double";
   }
 
-  test_dtype<float, ScalarType::Float>();
-  test_dtype<double, ScalarType::Double>();
-  // TODO: Also add tests for half, complex, quantized, and other types. Easiest
-  // way to do that would be to make TensorFactory support zeros() and ones()
-  // for those types.
+#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
+#undef TEST_ENTRY
 }
 
 TEST_F(OpLogSoftmaxOutTest, MismatchedDimensionsDies) {
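
A quick sketch of what the two macros do, to read alongside the patch: `ET_SWITCH_FLOATHBF16_TYPES` widens the kernel's dtype dispatch from {Float, Double} to also cover Half and BFloat16, and `ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)` stamps out one `test_dtype` instantiation per dtype in that same set. Assuming the macro follows ExecuTorch's usual `_(ctype, dtype)` forall pattern (the Half/BFloat16 ctype spellings below are assumptions, not copied from its definition), the new test loop expands to roughly:

```cpp
// Hypothetical expansion of ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY),
// where TEST_ENTRY(ctype, dtype) is test_dtype<ctype, dtype>();
test_dtype<float, ScalarType::Float>();
test_dtype<double, ScalarType::Double>();
test_dtype<executorch::aten::Half, ScalarType::Half>();          // assumed ctype spelling
test_dtype<executorch::aten::BFloat16, ScalarType::BFloat16>();  // assumed ctype spelling
```

The relaxed 1e-2 rtol for BFloat16 matches the format's precision: with an 8-bit significand, machine epsilon is 2^-8 ≈ 0.4%, so comparing against float-level expected values at the default tolerance would fail spuriously on correct results.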