From 99ea1ab942db33547f8d8699a6ecc6b729a5b8a9 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 30 Jan 2025 16:19:48 -0800 Subject: [PATCH] Fix ATen mode op_logit_test Was broken, now it's not. Differential Revision: [D68929577](https://our.internmc.facebook.com/intern/diff/D68929577/) [ghstack-poisoned] --- kernels/test/op_logit_test.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/kernels/test/op_logit_test.cpp b/kernels/test/op_logit_test.cpp index a027bc547af..a5e78e8aa6b 100644 --- a/kernels/test/op_logit_test.cpp +++ b/kernels/test/op_logit_test.cpp @@ -57,10 +57,17 @@ class OpLogitOutTest : public OperatorTest { op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0.1, out); - // Check that it matches (or close to) the expected output. - EXPECT_TENSOR_CLOSE( - out, - tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224})); + auto expected = + tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224}); + if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out, + expected, + 1e-2, + executorch::runtime::testing::internal::kDefaultAtol); + } else { + EXPECT_TENSOR_CLOSE(out, expected); + } } // Unhandled output dtypes.