From 70c4a0075b9ad76fb5ac1548107b1f500990c3c2 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 14:32:55 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_pdist_forward.cpp | 2 +- kernels/test/op_pdist_forward_test.cpp | 87 ++++++++++++++--------- 2 files changed, 56 insertions(+), 33 deletions(-) diff --git a/kernels/portable/cpu/op_pdist_forward.cpp b/kernels/portable/cpu/op_pdist_forward.cpp index 04217cc8eb4..1aa53e0cdd3 100644 --- a/kernels/portable/cpu/op_pdist_forward.cpp +++ b/kernels/portable/cpu/op_pdist_forward.cpp @@ -42,7 +42,7 @@ Tensor& _pdist_forward_out( ScalarType in_type = in.scalar_type(); constexpr auto name = "_pdist_forward.out"; - ET_SWITCH_FLOAT_TYPES( + ET_SWITCH_FLOATHBF16_TYPES( in_type, ctx, name, CTYPE, [&] { pdist(in, out, p); }); return out; diff --git a/kernels/test/op_pdist_forward_test.cpp b/kernels/test/op_pdist_forward_test.cpp index f022c9af94f..e8fa0f4c742 100644 --- a/kernels/test/op_pdist_forward_test.cpp +++ b/kernels/test/op_pdist_forward_test.cpp @@ -33,45 +33,68 @@ class OpPdistForwardOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } -}; -TEST_F(OpPdistForwardOutTest, SmokeTest) { - TensorFactory tfFloat; + template + void test_dtype() { + TensorFactory tf; + + Tensor in = tf.make({4, 5}, {0, 1, 2, 3, 5, 4, 3, 2, -1, 5, + 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); + Tensor out = tf.zeros({6}); - Tensor in = tfFloat.make( - {4, 5}, {0, 1, 2, 3, 5, 4, 3, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); - Tensor out = tfFloat.zeros({6}); + Tensor l0 = tf.make({6}, {3., 3., 3., 4., 0., 4.}); + op_pdist_forward_out(in, 0.0, out); + EXPECT_TENSOR_CLOSE(out, l0); - Tensor l0 = tfFloat.make({6}, {3., 3., 3., 4., 0., 4.}); - op_pdist_forward_out(in, 0.0, out); - EXPECT_TENSOR_CLOSE(out, l0); + Tensor l0p5 = tf.make( + {6}, + {29.31370926, 19.48528290, 29.31370926, 43.03986740, 0.0, 43.03986740}); + op_pdist_forward_out(in, 0.5, out); + if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out, + l0p5, + 1e-2, + executorch::runtime::testing::internal::kDefaultAtol); + } else { + EXPECT_TENSOR_CLOSE(out, l0p5); + } - Tensor l0p5 = tfFloat.make( - {6}, - {29.31370926, 19.48528290, 29.31370926, 43.03986740, 0.0, 43.03986740}); - op_pdist_forward_out(in, 0.5, out); - EXPECT_TENSOR_CLOSE(out, l0p5); + Tensor l1 = tf.make({6}, {10., 7., 10., 11., 0., 11.}); + op_pdist_forward_out(in, 1.0, out); + EXPECT_TENSOR_CLOSE(out, l1); - Tensor l1 = tfFloat.make({6}, {10., 7., 10., 11., 0., 11.}); - op_pdist_forward_out(in, 1.0, out); - EXPECT_TENSOR_CLOSE(out, l1); + Tensor l1p5 = tf.make( + {6}, {7.07743692, 5.19140196, 7.07743692, 7.08359480, 0.0, 7.08359480}); + op_pdist_forward_out(in, 1.5, out); + if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out, + l1p5, + 1e-2, + executorch::runtime::testing::internal::kDefaultAtol); + } else { + EXPECT_TENSOR_CLOSE(out, l1p5); + } - Tensor l1p5 = tfFloat.make( - {6}, {7.07743692, 5.19140196, 7.07743692, 7.08359480, 0.0, 7.08359480}); - op_pdist_forward_out(in, 1.5, out); - EXPECT_TENSOR_CLOSE(out, l1p5); + Tensor l2 = + tf.make({6}, {6.0, 4.58257580, 6.0, 5.74456263, 0.0, 5.74456263}); + op_pdist_forward_out(in, 2.0, out); + EXPECT_TENSOR_CLOSE(out, l2); - Tensor l2 = - tfFloat.make({6}, {6.0, 4.58257580, 6.0, 5.74456263, 0.0, 5.74456263}); - op_pdist_forward_out(in, 2.0, out); - EXPECT_TENSOR_CLOSE(out, l2); + Tensor l3 = tf.make( + {6}, {5.14256334, 4.17933941, 5.14256334, 4.74745941, 0.0, 4.74745941}); + op_pdist_forward_out(in, 3.0, out); + EXPECT_TENSOR_CLOSE(out, l3); - Tensor l3 = tfFloat.make( - {6}, {5.14256334, 4.17933941, 5.14256334, 4.74745941, 0.0, 4.74745941}); - op_pdist_forward_out(in, 3.0, out); - EXPECT_TENSOR_CLOSE(out, l3); + Tensor linf = tf.make({6}, {4., 4., 4., 4., 0., 4.}); + op_pdist_forward_out(in, INFINITY, out); + EXPECT_TENSOR_CLOSE(out, linf); + } +}; - Tensor linf = tfFloat.make({6}, {4., 4., 4., 4., 0., 4.}); - op_pdist_forward_out(in, INFINITY, out); - EXPECT_TENSOR_CLOSE(out, linf); +TEST_F(OpPdistForwardOutTest, SmokeTest) { +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY) +#undef TEST_ENTRY }