diff --git a/kernels/test/op_logit_test.cpp b/kernels/test/op_logit_test.cpp index a027bc547af..a5e78e8aa6b 100644 --- a/kernels/test/op_logit_test.cpp +++ b/kernels/test/op_logit_test.cpp @@ -57,10 +57,17 @@ class OpLogitOutTest : public OperatorTest { op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0.1, out); - // Check that it matches (or close to) the expected output. - EXPECT_TENSOR_CLOSE( - out, - tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224})); + auto expected = + tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224}); + if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out, + expected, + 1e-2, + executorch::runtime::testing::internal::kDefaultAtol); + } else { + EXPECT_TENSOR_CLOSE(out, expected); + } } // Unhandled output dtypes.