From f538e6c3acbc0d3d02dadc1af8fc59a6dbf8b4a4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 21 Jan 2025 11:31:56 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_cdist_forward.cpp | 2 +- kernels/test/op_cdist_forward_test.cpp | 170 ++++++++++-------- .../exec_aten/testing_util/tensor_util.cpp | 3 + 3 files changed, 97 insertions(+), 78 deletions(-) diff --git a/kernels/portable/cpu/op_cdist_forward.cpp b/kernels/portable/cpu/op_cdist_forward.cpp index 657de86ac1a..4b4c9a154f1 100644 --- a/kernels/portable/cpu/op_cdist_forward.cpp +++ b/kernels/portable/cpu/op_cdist_forward.cpp @@ -162,7 +162,7 @@ Tensor& _cdist_forward_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "_cdist_forward.out"; - ET_SWITCH_FLOAT_TYPES( + ET_SWITCH_FLOATHBF16_TYPES( out_type, ctx, name, CTYPE, [&] { cdist(x1, x2, out, p); }); return out; diff --git a/kernels/test/op_cdist_forward_test.cpp b/kernels/test/op_cdist_forward_test.cpp index c8c18c36add..2436c448f82 100644 --- a/kernels/test/op_cdist_forward_test.cpp +++ b/kernels/test/op_cdist_forward_test.cpp @@ -40,89 +40,105 @@ class OpCdistForwardOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } -}; -TEST_F(OpCdistForwardOutTest, SmokeTest) { - TensorFactory tfFloat; + template + void test_dtype() { + TensorFactory tf; + + Tensor x1 = tf.make({2, 1, 4, 3}, {0, 1, 2, 3, 5, 4, 3, -3, 7, 1, 6, 2, + -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); + Tensor x2 = tf.make( + {1, 2, 5, 3}, {0, 1, 2, 3, 5, -3, 7, 1, 6, 2, -1, 5, 1, 1, -2, + 4, 3, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); + optional compute_mode = optional(); - Tensor x1 = - tfFloat.make({2, 1, 4, 3}, {0, 1, 2, 3, 5, 4, 3, -3, 7, 1, 6, 2, - -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); - Tensor x2 = tfFloat.make( - {1, 2, 5, 3}, {0, 1, 2, 3, 5, -3, 7, 1, 6, 2, -1, 5, 1, 1, -2, - 4, 3, 2, -1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5}); - optional compute_mode = optional(); + Tensor out = tf.zeros({2, 2, 4, 5}); - Tensor out = tfFloat.zeros({2, 2, 4, 5}); + Tensor l0 = tf.make( + {2, 2, 4, 5}, + {0., 3., 2., 3., 2., 3., 1., 3., 3., 3., 3., 2., 3., 3., 3., 2., + 3., 3., 3., 2., 2., 3., 3., 3., 3., 3., 2., 3., 3., 3., 3., 3., + 3., 3., 3., 2., 3., 2., 3., 3., 3., 2., 3., 3., 3., 3., 3., 3., + 3., 2., 3., 3., 3., 3., 3., 3., 3., 3., 0., 3., 3., 0., 2., 3., + 3., 3., 2., 0., 3., 3., 3., 3., 3., 0., 3., 3., 3., 3., 3., 0.}); + op_cdist_forward_out(x1, x2, 0.0, compute_mode, out); + EXPECT_TENSOR_CLOSE(out, l0); - Tensor l0 = tfFloat.make( - {2, 2, 4, 5}, - {0., 3., 2., 3., 2., 3., 1., 3., 3., 3., 3., 2., 3., 3., 3., 2., - 3., 3., 3., 2., 2., 3., 3., 3., 3., 3., 2., 3., 3., 3., 3., 3., - 3., 3., 3., 2., 3., 2., 3., 3., 3., 2., 3., 3., 3., 3., 3., 3., - 3., 2., 3., 3., 3., 3., 3., 3., 3., 3., 0., 3., 3., 0., 2., 3., - 3., 3., 2., 0., 3., 3., 3., 3., 3., 0., 3., 3., 3., 3., 3., 0.}); - op_cdist_forward_out(x1, x2, 0.0, compute_mode, out); - EXPECT_TENSOR_CLOSE(out, l0); + Tensor l1 = tf.make( + {2, 2, 4, 5}, + {0., 12., 11., 7., 5., 9., 7., 10., 8., 12., 12., 18., 9., 5., + 15., 6., 8., 15., 11., 9., 6., 6., 5., 9., 7., 5., 7., 12., + 4., 8., 12., 18., 9., 13., 5., 6., 4., 9., 7., 11., 6., 8., + 17., 13., 9., 5., 13., 14., 6., 6., 9., 9., 8., 10., 12., 7., + 15., 8., 0., 10., 8., 0., 9., 9., 13., 9., 9., 0., 12., 6., + 3., 9., 12., 0., 10., 9., 13., 6., 10., 0.}); + op_cdist_forward_out(x1, x2, 1.0, compute_mode, out); + EXPECT_TENSOR_CLOSE(out, l1); - Tensor l1 = tfFloat.make( - {2, 2, 4, 5}, - {0., 12., 11., 7., 5., 9., 7., 10., 8., 12., 12., 18., 9., 5., - 15., 6., 8., 15., 11., 9., 6., 6., 5., 9., 7., 5., 7., 12., - 4., 8., 12., 18., 9., 13., 5., 6., 4., 9., 7., 11., 6., 8., - 17., 13., 9., 5., 13., 14., 6., 6., 9., 9., 8., 10., 12., 7., - 15., 8., 0., 10., 8., 0., 9., 9., 13., 9., 9., 0., 12., 6., - 3., 9., 12., 0., 10., 9., 13., 6., 10., 0.}); - op_cdist_forward_out(x1, x2, 1.0, compute_mode, out); - EXPECT_TENSOR_CLOSE(out, l1); + Tensor l2 = tf.make( + {2, 2, 4, 5}, + {0.00000000, 7.07106781, 8.06225777, 4.12310553, 4.12310553, + 5.38516474, 7.00000000, 6.00000000, 6.16441393, 7.48331499, + 7.07106781, 12.80624866, 5.74456263, 3.00000000, 10.04987526, + 5.09901953, 5.47722578, 8.77496433, 7.68114567, 6.40312433, + 4.47213602, 4.24264050, 3.31662488, 5.91608000, 4.12310553, + 3.00000000, 5.00000000, 7.87400770, 2.44948983, 6.16441393, + 7.87400770, 10.77032948, 6.40312433, 8.30662346, 3.00000000, + 4.24264050, 2.44948983, 8.06225777, 4.58257580, 7.68114567, + 4.24264050, 5.65685415, 10.24695110, 7.81024981, 5.38516474, + 3.31662488, 8.30662346, 8.36660004, 4.24264050, 4.24264050, + 5.91608000, 6.40312433, 4.69041586, 6.16441393, 7.07106781, + 4.12310553, 10.04987526, 5.47722578, 0.00000000, 7.34846926, + 5.47722578, 0.00000000, 7.28010988, 6.40312433, 7.81024981, + 5.91608000, 7.28010988, 0.00000000, 7.48331499, 4.24264050, + 1.73205078, 6.40312433, 7.48331499, 0.00000000, 6.16441393, + 5.38516474, 7.81024981, 4.24264050, 6.16441393, 0.00000000}); + op_cdist_forward_out(x1, x2, 2.0, compute_mode, out); + EXPECT_TENSOR_CLOSE(out, l2); - Tensor l2 = tfFloat.make( - {2, 2, 4, 5}, - {0.00000000, 7.07106781, 8.06225777, 4.12310553, 4.12310553, - 5.38516474, 7.00000000, 6.00000000, 6.16441393, 7.48331499, - 7.07106781, 12.80624866, 5.74456263, 3.00000000, 10.04987526, - 5.09901953, 5.47722578, 8.77496433, 7.68114567, 6.40312433, - 4.47213602, 4.24264050, 3.31662488, 5.91608000, 4.12310553, - 3.00000000, 5.00000000, 7.87400770, 2.44948983, 6.16441393, - 7.87400770, 10.77032948, 6.40312433, 8.30662346, 3.00000000, - 4.24264050, 2.44948983, 8.06225777, 4.58257580, 7.68114567, - 4.24264050, 5.65685415, 10.24695110, 7.81024981, 5.38516474, - 3.31662488, 8.30662346, 8.36660004, 4.24264050, 4.24264050, - 5.91608000, 6.40312433, 4.69041586, 6.16441393, 7.07106781, - 4.12310553, 10.04987526, 5.47722578, 0.00000000, 7.34846926, - 5.47722578, 0.00000000, 7.28010988, 6.40312433, 7.81024981, - 5.91608000, 7.28010988, 0.00000000, 7.48331499, 4.24264050, - 1.73205078, 6.40312433, 7.48331499, 0.00000000, 6.16441393, - 5.38516474, 7.81024981, 4.24264050, 6.16441393, 0.00000000}); - op_cdist_forward_out(x1, x2, 2.0, compute_mode, out); - EXPECT_TENSOR_CLOSE(out, l2); + Tensor l3 = tf.make( + {2, 2, 4, 5}, + {0.00000000, 6.00000000, 7.41079521, 3.50339794, 4.02072573, + 4.62606478, 7.00000000, 5.14256334, 6.01846170, 6.60385466, + 6.00000000, 11.47758675, 5.05277443, 2.57128167, 9.28704357, + 5.01329803, 5.11722994, 7.39863634, 7.18551636, 5.73879337, + 4.16016769, 4.04124022, 3.07231688, 5.34848118, 3.50339794, + 2.57128167, 4.49794149, 7.23042679, 2.15443468, 6.01846170, + 6.99319077, 9.25212955, 6.08220196, 7.45903587, 2.57128167, + 3.77976322, 2.15443468, 8.00520515, 4.17933941, 7.18551636, + 4.04124022, 5.03968430, 8.88326645, 6.74599648, 4.62606478, + 3.07231688, 7.45903587, 7.16609573, 4.04124022, 3.77976322, + 5.34848118, 6.08220196, 3.95789170, 5.42883539, 6.00000000, + 3.50339794, 9.00000000, 5.11722994, 0.00000000, 7.06069660, + 5.11722994, 0.00000000, 7.05400419, 6.08220196, 6.74599648, + 5.34848118, 7.05400419, 0.00000000, 6.60385466, 4.04124022, + 1.44224954, 6.08220196, 6.60385466, 0.00000000, 5.42883539, + 4.62606478, 6.74599648, 4.04124022, 5.42883539, 0.00000000}); + op_cdist_forward_out(x1, x2, 3.0, compute_mode, out); + if (DTYPE == ScalarType::BFloat16) { + EXPECT_TENSOR_CLOSE_WITH_TOL( + out, + l3, + 1e-2, + executorch::runtime::testing::internal::kDefaultBFloat16Atol); + } else { + EXPECT_TENSOR_CLOSE(out, l3); + } - Tensor l3 = tfFloat.make( - {2, 2, 4, 5}, - {0.00000000, 6.00000000, 7.41079521, 3.50339794, 4.02072573, 4.62606478, - 7.00000000, 5.14256334, 6.01846170, 6.60385466, 6.00000000, 11.47758675, - 5.05277443, 2.57128167, 9.28704357, 5.01329803, 5.11722994, 7.39863634, - 7.18551636, 5.73879337, 4.16016769, 4.04124022, 3.07231688, 5.34848118, - 3.50339794, 2.57128167, 4.49794149, 7.23042679, 2.15443468, 6.01846170, - 6.99319077, 9.25212955, 6.08220196, 7.45903587, 2.57128167, 3.77976322, - 2.15443468, 8.00520515, 4.17933941, 7.18551636, 4.04124022, 5.03968430, - 8.88326645, 6.74599648, 4.62606478, 3.07231688, 7.45903587, 7.16609573, - 4.04124022, 3.77976322, 5.34848118, 6.08220196, 3.95789170, 5.42883539, - 6.00000000, 3.50339794, 9.00000000, 5.11722994, 0.00000000, 7.06069660, - 5.11722994, 0.00000000, 7.05400419, 6.08220196, 6.74599648, 5.34848118, - 7.05400419, 0.00000000, 6.60385466, 4.04124022, 1.44224954, 6.08220196, - 6.60385466, 0.00000000, 5.42883539, 4.62606478, 6.74599648, 4.04124022, - 5.42883539, 0.00000000}); - op_cdist_forward_out(x1, x2, 3.0, compute_mode, out); - EXPECT_TENSOR_CLOSE(out, l3); + Tensor linf = tf.make( + {2, 2, 4, 5}, + {0., 5., 7., 3., 4., 4., 7., 4., 6., 6., 5., 10., 4., 2., 9., 5., + 5., 6., 7., 5., 4., 4., 3., 5., 3., 2., 4., 7., 2., 6., 6., 8., + 6., 7., 2., 3., 2., 8., 4., 7., 4., 4., 8., 6., 4., 3., 7., 6., + 4., 3., 5., 6., 3., 5., 5., 3., 8., 5., 0., 7., 5., 0., 7., 6., + 6., 5., 7., 0., 6., 4., 1., 6., 6., 0., 5., 4., 6., 4., 5., 0.}); + op_cdist_forward_out(x1, x2, INFINITY, compute_mode, out); + EXPECT_TENSOR_CLOSE(out, linf); + } +}; - Tensor linf = tfFloat.make( - {2, 2, 4, 5}, - {0., 5., 7., 3., 4., 4., 7., 4., 6., 6., 5., 10., 4., 2., 9., 5., - 5., 6., 7., 5., 4., 4., 3., 5., 3., 2., 4., 7., 2., 6., 6., 8., - 6., 7., 2., 3., 2., 8., 4., 7., 4., 4., 8., 6., 4., 3., 7., 6., - 4., 3., 5., 6., 3., 5., 5., 3., 8., 5., 0., 7., 5., 0., 7., 6., - 6., 5., 7., 0., 6., 4., 1., 6., 6., 0., 5., 4., 6., 4., 5., 0.}); - op_cdist_forward_out(x1, x2, INFINITY, compute_mode, out); - EXPECT_TENSOR_CLOSE(out, linf); +TEST_F(OpCdistForwardOutTest, SmokeTest) { +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY } diff --git a/runtime/core/exec_aten/testing_util/tensor_util.cpp b/runtime/core/exec_aten/testing_util/tensor_util.cpp index f1c25eb9fe6..0e97c3c245c 100644 --- a/runtime/core/exec_aten/testing_util/tensor_util.cpp +++ b/runtime/core/exec_aten/testing_util/tensor_util.cpp @@ -80,6 +80,9 @@ double default_atol_for_type(ScalarType t) { if (t == ScalarType::Half) { return internal::kDefaultHalfAtol; } + if (t == ScalarType::BFloat16) { + return internal::kDefaultBFloat16Atol; + } return internal::kDefaultAtol; } } // namespace