From 25243a9f04730c9710936d63cc9ccce7fc760af3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Wed, 31 Jul 2024 10:36:05 +0200 Subject: [PATCH 1/7] added count to mhp algorithms --- include/dr/mhp.hpp | 1 + include/dr/mhp/algorithms/count.hpp | 162 ++++++++++++++++++++++++++++ test/gtest/common/count.cpp | 38 +++++++ 3 files changed, 201 insertions(+) create mode 100644 include/dr/mhp/algorithms/count.hpp create mode 100644 test/gtest/common/count.cpp diff --git a/include/dr/mhp.hpp b/include/dr/mhp.hpp index b13d5007ef..8bad2caeee 100644 --- a/include/dr/mhp.hpp +++ b/include/dr/mhp.hpp @@ -75,5 +75,6 @@ #include #include #include +#include #include #include diff --git a/include/dr/mhp/algorithms/count.hpp b/include/dr/mhp/algorithms/count.hpp new file mode 100644 index 0000000000..3ec838df76 --- /dev/null +++ b/include/dr/mhp/algorithms/count.hpp @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: Intel Corporation +// +// SPDX-License-Identifier: BSD-3-Clause + +#pragma once + +namespace dr::mhp::__detail { + +inline auto add_counts(rng::forward_range auto &&r) { + rng::range_difference_t zero{}; + + return std::accumulate(rng::begin(r), rng::end(r), zero); +} + +inline auto std_count_if(rng::forward_range auto &&r, auto &&pred) { + using count_type = rng::range_difference_t; + + if (rng::empty(r)) { + return count_type{}; + } + + return std::count_if(std::execution::par_unseq, + dr::__detail::direct_iterator(rng::begin(r)), + dr::__detail::direct_iterator(rng::end(r)), + pred); +} + +inline auto dpl_count_if(rng::forward_range auto &&r, auto &&pred) { + using count_type = rng::range_difference_t; + +#ifdef SYCL_LANGUAGE_VERSION + if (rng::empty(r)) { + return count_type{}; + } + + return std::count_if(dpl_policy(), + dr::__detail::direct_iterator(rng::begin(r)), + dr::__detail::direct_iterator(rng::end(r)), + pred); +#else + assert(false); + return count_type{}; +#endif +} + +template +auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { + using count_type = rng::range_difference_t; + auto comm = default_comm(); + + if (rng::empty(dr)) { + return count_type{}; + } + + if (aligned(dr)) { + dr::drlog.debug("Parallel count\n"); + + // Count within the local segments + auto count = [=](auto &&r) { + assert(rng::size(r) > 0); + if (mhp::use_sycl()) { + dr::drlog.debug(" with DPL\n"); + return dpl_count_if(r, pred); + } else { + dr::drlog.debug(" with CPU\n"); + return std_count_if(r, pred); + } + }; + + auto locals = rng::views::transform(local_segments(dr), count); + auto local = add_counts(locals); + + std::vector all(comm.size()); + if (root_provided) { + // Everyone gathers to root, only root adds up the counts + comm.gather(local, std::span{all}, root); + if (root == comm.rank()) { + return add_counts(all); + } else { + return count_type{}; + } + } else { + // Everyone gathers and everyone adds up the counts + comm.all_gather(local, all); + return add_counts(all); + } + } else { + dr::drlog.debug("Serial count\n"); + count_type result{}; + if (!root_provided || root == comm.rank()) { + result = add_counts(dr); + } + barrier(); + return result; + } +} + +} // namespace dr::mhp::__detail + +namespace dr::mhp { + +// +// Ranges +// + +// range, elem, w/wo root + +template +auto count(std::size_t root, DR &&dr, const T& value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, dr, pred); +} + +template +auto count(DR &&dr, const T& value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, dr, pred); +} + +// range, predicate, w/wo root + +template +auto count_if(std::size_t root, DR &&dr, auto &&pred) { + return __detail::count_if(root, true, dr, pred); +} + +template +auto count_if(DR &&dr, auto &&pred) { + return __detail::count_if(0, false, dr, pred); +} + +// +// Iterators +// + +// range, elem, w/wo root + +template +auto count(std::size_t root, DI first, DI last, const T& value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, rng::subrange(first, last), pred); +} + +template +auto count(DI first, DI last, const T& value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, rng::subrange(first, last), pred); +} + +// range, predicate, w/wo root + +template +auto count_if(std::size_t root, DI first, DI last, auto &&pred) { + return __detail::count_if(root, true, rng::subrange(first, last), pred); +} + +template +auto count_if(DI first, DI last, auto &&pred) { + return __detail::count_if(0, false, rng::subrange(first, last), pred); +} + +}; // namespace dr::mhp diff --git a/test/gtest/common/count.cpp b/test/gtest/common/count.cpp new file mode 100644 index 0000000000..371f33f8c9 --- /dev/null +++ b/test/gtest/common/count.cpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: Intel Corporation +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "xhp-tests.hpp" + +// Fixture +template class Count : public testing::Test { +protected: +}; + +TYPED_TEST_SUITE(Count, AllTypes); + +TYPED_TEST(Count, BasicFirstElem) { + Ops1 ops(10); + auto value = *ops.vec.begin(); + + EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value), + xhp::count(ops.dist_vec, value)); +} + +TYPED_TEST(Count, BasicFirstElemIf) { + Ops1 ops(10); + auto value = *ops.vec.begin(); + auto pred = [=](auto &&v) { v == value; } + + EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), + xhp::count_if(ops.dist_vec, pred)); +} + +TYPED_TEST(Count, FirstElemsIf) { + Ops1 ops(10); + auto value = *ops.vec.begin(); + auto pred = [=](auto &&v) { v < 5; } + + EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), + xhp::count_if(ops.dist_vec, pred)); +} From 7eec868759d304352e7b79052ae8b92d81733a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Wed, 31 Jul 2024 10:44:27 +0200 Subject: [PATCH 2/7] minor fix --- include/dr/mp.hpp | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/include/dr/mp.hpp b/include/dr/mp.hpp index 4fa7001089..f9598bbcd8 100644 --- a/include/dr/mp.hpp +++ b/include/dr/mp.hpp @@ -52,34 +52,6 @@ #include #include -<<<<<<< HEAD:include/dr/mhp.hpp -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -======= #include #include #include @@ -93,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -106,4 +79,3 @@ #include #include #include ->>>>>>> a9468e93d71f48ca7b977472fd01d63965c21d75:include/dr/mp.hpp From 6090fdcb1dc15e45dd9b0638412fd060b8ad7dbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Wed, 31 Jul 2024 11:15:57 +0200 Subject: [PATCH 3/7] minor fixes --- include/dr/mp/algorithms/count.hpp | 63 ++++++++++++++---------------- test/gtest/common/count.cpp | 13 +++--- test/gtest/mp/CMakeLists.txt | 3 +- 3 files changed, 38 insertions(+), 41 deletions(-) diff --git a/include/dr/mp/algorithms/count.hpp b/include/dr/mp/algorithms/count.hpp index 3ec838df76..75d438d8f0 100644 --- a/include/dr/mp/algorithms/count.hpp +++ b/include/dr/mp/algorithms/count.hpp @@ -4,11 +4,11 @@ #pragma once -namespace dr::mhp::__detail { +namespace dr::mp::__detail { inline auto add_counts(rng::forward_range auto &&r) { rng::range_difference_t zero{}; - + return std::accumulate(rng::begin(r), rng::end(r), zero); } @@ -19,10 +19,9 @@ inline auto std_count_if(rng::forward_range auto &&r, auto &&pred) { return count_type{}; } - return std::count_if(std::execution::par_unseq, + return std::count_if(std::execution::par_unseq, dr::__detail::direct_iterator(rng::begin(r)), - dr::__detail::direct_iterator(rng::end(r)), - pred); + dr::__detail::direct_iterator(rng::end(r)), pred); } inline auto dpl_count_if(rng::forward_range auto &&r, auto &&pred) { @@ -33,10 +32,9 @@ inline auto dpl_count_if(rng::forward_range auto &&r, auto &&pred) { return count_type{}; } - return std::count_if(dpl_policy(), + return std::count_if(mp::dpl_policy(), dr::__detail::direct_iterator(rng::begin(r)), - dr::__detail::direct_iterator(rng::end(r)), - pred); + dr::__detail::direct_iterator(rng::end(r)), pred); #else assert(false); return count_type{}; @@ -46,7 +44,7 @@ inline auto dpl_count_if(rng::forward_range auto &&r, auto &&pred) { template auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { using count_type = rng::range_difference_t; - auto comm = default_comm(); + auto comm = mp::default_comm(); if (rng::empty(dr)) { return count_type{}; @@ -58,7 +56,7 @@ auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { // Count within the local segments auto count = [=](auto &&r) { assert(rng::size(r) > 0); - if (mhp::use_sycl()) { + if (mp::use_sycl()) { dr::drlog.debug(" with DPL\n"); return dpl_count_if(r, pred); } else { @@ -90,14 +88,14 @@ auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { if (!root_provided || root == comm.rank()) { result = add_counts(dr); } - barrier(); + mp::barrier(); return result; } } - -} // namespace dr::mhp::__detail -namespace dr::mhp { +} // namespace dr::mp::__detail + +namespace dr::mp { // // Ranges @@ -106,27 +104,26 @@ namespace dr::mhp { // range, elem, w/wo root template -auto count(std::size_t root, DR &&dr, const T& value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(root, true, dr, pred); +auto count(std::size_t root, DR &&dr, const T &value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, dr, pred); } template -auto count(DR &&dr, const T& value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(0, false, dr, pred); +auto count(DR &&dr, const T &value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, dr, pred); } // range, predicate, w/wo root template auto count_if(std::size_t root, DR &&dr, auto &&pred) { - return __detail::count_if(root, true, dr, pred); + return __detail::count_if(root, true, dr, pred); } -template -auto count_if(DR &&dr, auto &&pred) { - return __detail::count_if(0, false, dr, pred); +template auto count_if(DR &&dr, auto &&pred) { + return __detail::count_if(0, false, dr, pred); } // @@ -136,27 +133,27 @@ auto count_if(DR &&dr, auto &&pred) { // range, elem, w/wo root template -auto count(std::size_t root, DI first, DI last, const T& value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(root, true, rng::subrange(first, last), pred); +auto count(std::size_t root, DI first, DI last, const T &value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, rng::subrange(first, last), pred); } template -auto count(DI first, DI last, const T& value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(0, false, rng::subrange(first, last), pred); +auto count(DI first, DI last, const T &value) { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, rng::subrange(first, last), pred); } // range, predicate, w/wo root template auto count_if(std::size_t root, DI first, DI last, auto &&pred) { - return __detail::count_if(root, true, rng::subrange(first, last), pred); + return __detail::count_if(root, true, rng::subrange(first, last), pred); } template auto count_if(DI first, DI last, auto &&pred) { - return __detail::count_if(0, false, rng::subrange(first, last), pred); + return __detail::count_if(0, false, rng::subrange(first, last), pred); } -}; // namespace dr::mhp +}; // namespace dr::mp diff --git a/test/gtest/common/count.cpp b/test/gtest/common/count.cpp index 371f33f8c9..6727c0bae3 100644 --- a/test/gtest/common/count.cpp +++ b/test/gtest/common/count.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: BSD-3-Clause -#include "xhp-tests.hpp" +#include "xp-tests.hpp" // Fixture template class Count : public testing::Test { @@ -16,23 +16,22 @@ TYPED_TEST(Count, BasicFirstElem) { auto value = *ops.vec.begin(); EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value), - xhp::count(ops.dist_vec, value)); + xp::count(ops.dist_vec, value)); } TYPED_TEST(Count, BasicFirstElemIf) { Ops1 ops(10); auto value = *ops.vec.begin(); - auto pred = [=](auto &&v) { v == value; } + auto pred = [=](auto &&v) { return v == value; }; EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), - xhp::count_if(ops.dist_vec, pred)); + xp::count_if(ops.dist_vec, pred)); } TYPED_TEST(Count, FirstElemsIf) { Ops1 ops(10); - auto value = *ops.vec.begin(); - auto pred = [=](auto &&v) { v < 5; } + auto pred = [=](auto &&v) { return v < 5; }; EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), - xhp::count_if(ops.dist_vec, pred)); + xp::count_if(ops.dist_vec, pred)); } diff --git a/test/gtest/mp/CMakeLists.txt b/test/gtest/mp/CMakeLists.txt index cef65af431..8cefdecc48 100644 --- a/test/gtest/mp/CMakeLists.txt +++ b/test/gtest/mp/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable( mp-tests.cpp ../common/all.cpp ../common/copy.cpp + ../common/count.cpp ../common/counted.cpp ../common/distributed_vector.cpp ../common/drop.cpp @@ -57,7 +58,7 @@ add_executable( add_executable(mp-quick-test mp-tests.cpp - ../common/equal.cpp + ../common/count.cpp ) # cmake-format: on From 167702dc3003a51ca82be38df7aebb7ec35b5c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Fri, 2 Aug 2024 19:42:45 +0200 Subject: [PATCH 4/7] code review fixes --- include/dr/mp/algorithms/count.hpp | 140 ++++++++++++----------------- test/gtest/common/count.cpp | 18 +++- 2 files changed, 73 insertions(+), 85 deletions(-) diff --git a/include/dr/mp/algorithms/count.hpp b/include/dr/mp/algorithms/count.hpp index 75d438d8f0..e5f606bcd4 100644 --- a/include/dr/mp/algorithms/count.hpp +++ b/include/dr/mp/algorithms/count.hpp @@ -12,33 +12,23 @@ inline auto add_counts(rng::forward_range auto &&r) { return std::accumulate(rng::begin(r), rng::end(r), zero); } -inline auto std_count_if(rng::forward_range auto &&r, auto &&pred) { - using count_type = rng::range_difference_t; - - if (rng::empty(r)) { - return count_type{}; - } - - return std::count_if(std::execution::par_unseq, - dr::__detail::direct_iterator(rng::begin(r)), - dr::__detail::direct_iterator(rng::end(r)), pred); -} - -inline auto dpl_count_if(rng::forward_range auto &&r, auto &&pred) { - using count_type = rng::range_difference_t; - +inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) { + if (mp::use_sycl()) { + dr::drlog.debug(" with DPL\n"); #ifdef SYCL_LANGUAGE_VERSION - if (rng::empty(r)) { - return count_type{}; - } - - return std::count_if(mp::dpl_policy(), - dr::__detail::direct_iterator(rng::begin(r)), - dr::__detail::direct_iterator(rng::end(r)), pred); + return std::count_if(mp::dpl_policy(), + dr::__detail::direct_iterator(rng::begin(r)), + dr::__detail::direct_iterator(rng::end(r)), pred); #else - assert(false); - return count_type{}; + assert(false); + return rng::range_difference_t{}; #endif + } else { + dr::drlog.debug(" with CPU\n"); + return std::count_if(std::execution::par_unseq, + dr::__detail::direct_iterator(rng::begin(r)), + dr::__detail::direct_iterator(rng::end(r)), pred); + } } template @@ -56,15 +46,8 @@ auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { // Count within the local segments auto count = [=](auto &&r) { assert(rng::size(r) > 0); - if (mp::use_sycl()) { - dr::drlog.debug(" with DPL\n"); - return dpl_count_if(r, pred); - } else { - dr::drlog.debug(" with CPU\n"); - return std_count_if(r, pred); - } + return count_if_local(r, pred); }; - auto locals = rng::views::transform(local_segments(dr), count); auto local = add_counts(locals); @@ -97,63 +80,58 @@ auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { namespace dr::mp { -// -// Ranges -// - -// range, elem, w/wo root - -template -auto count(std::size_t root, DR &&dr, const T &value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(root, true, dr, pred); -} - -template -auto count(DR &&dr, const T &value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(0, false, dr, pred); -} - -// range, predicate, w/wo root +class count_fn_ { +public: + template + auto operator()(std::size_t root, DR &&dr, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, dr, pred); + } -template -auto count_if(std::size_t root, DR &&dr, auto &&pred) { - return __detail::count_if(root, true, dr, pred); -} + template + auto operator()(DR &&dr, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, dr, pred); + } -template auto count_if(DR &&dr, auto &&pred) { - return __detail::count_if(0, false, dr, pred); -} + template + auto operator()(std::size_t root, DI first, DI last, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, rng::subrange(first, last), pred); + } -// -// Iterators -// + template + auto operator()(DI first, DI last, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, rng::subrange(first, last), pred); + } +}; -// range, elem, w/wo root +inline constexpr count_fn_ count; -template -auto count(std::size_t root, DI first, DI last, const T &value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(root, true, rng::subrange(first, last), pred); -} +class count_if_fn_ { +public: + template + auto operator()(std::size_t root, DR &&dr, auto &&pred) const { + return __detail::count_if(root, true, dr, pred); + } -template -auto count(DI first, DI last, const T &value) { - auto pred = [=](auto &&v) { return v == value; }; - return __detail::count_if(0, false, rng::subrange(first, last), pred); -} + template + auto operator()(DR &&dr, auto &&pred) const { + return __detail::count_if(0, false, dr, pred); + } -// range, predicate, w/wo root + template + auto operator()(std::size_t root, DI first, DI last, auto &&pred) const { + return __detail::count_if(root, true, rng::subrange(first, last), pred); + } -template -auto count_if(std::size_t root, DI first, DI last, auto &&pred) { - return __detail::count_if(root, true, rng::subrange(first, last), pred); -} + template + auto operator()(DI first, DI last, auto &&pred) const { + return __detail::count_if(0, false, rng::subrange(first, last), pred); + } +}; -template -auto count_if(DI first, DI last, auto &&pred) { - return __detail::count_if(0, false, rng::subrange(first, last), pred); -} +inline constexpr count_if_fn_ count_if; }; // namespace dr::mp diff --git a/test/gtest/common/count.cpp b/test/gtest/common/count.cpp index 6727c0bae3..96feef0f0a 100644 --- a/test/gtest/common/count.cpp +++ b/test/gtest/common/count.cpp @@ -12,7 +12,12 @@ template class Count : public testing::Test { TYPED_TEST_SUITE(Count, AllTypes); TYPED_TEST(Count, BasicFirstElem) { - Ops1 ops(10); + std::vector vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 }; + + Ops1 ops(vec.size()); + ops.vec = vec; + xp::copy(ops.vec, ops.dist_vec.begin()); + auto value = *ops.vec.begin(); EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value), @@ -20,8 +25,13 @@ TYPED_TEST(Count, BasicFirstElem) { } TYPED_TEST(Count, BasicFirstElemIf) { - Ops1 ops(10); - auto value = *ops.vec.begin(); + std::vector vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 }; + + Ops1 ops(vec.size()); + ops.vec = vec; + xp::copy(ops.vec, ops.dist_vec.begin()); + + auto value = *vec.begin(); auto pred = [=](auto &&v) { return v == value; }; EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), @@ -29,7 +39,7 @@ TYPED_TEST(Count, BasicFirstElemIf) { } TYPED_TEST(Count, FirstElemsIf) { - Ops1 ops(10); + Ops1 ops(20); auto pred = [=](auto &&v) { return v < 5; }; EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), From 755c8965fbfde23c57824e7c5d4de2ef52feb288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Mon, 5 Aug 2024 18:57:03 +0200 Subject: [PATCH 5/7] more code review fixes --- include/dr/mp/algorithms/count.hpp | 1 - test/gtest/common/count.cpp | 23 ++++++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/include/dr/mp/algorithms/count.hpp b/include/dr/mp/algorithms/count.hpp index e5f606bcd4..a66c7196a7 100644 --- a/include/dr/mp/algorithms/count.hpp +++ b/include/dr/mp/algorithms/count.hpp @@ -21,7 +21,6 @@ inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) { dr::__detail::direct_iterator(rng::end(r)), pred); #else assert(false); - return rng::range_difference_t{}; #endif } else { dr::drlog.debug(" with CPU\n"); diff --git a/test/gtest/common/count.cpp b/test/gtest/common/count.cpp index 96feef0f0a..0fcb3b2d92 100644 --- a/test/gtest/common/count.cpp +++ b/test/gtest/common/count.cpp @@ -11,6 +11,18 @@ template class Count : public testing::Test { TYPED_TEST_SUITE(Count, AllTypes); +TYPED_TEST(Count, EmptyIf) { + std::vector vec; + + Ops1 ops(0); + + auto pred = [=](auto &&v) { return true; }; + + EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 0); + EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), + xp::count_if(ops.dist_vec, pred)); +} + TYPED_TEST(Count, BasicFirstElem) { std::vector vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 }; @@ -20,6 +32,7 @@ TYPED_TEST(Count, BasicFirstElem) { auto value = *ops.vec.begin(); + EXPECT_EQ(xp::count(ops.dist_vec, value), 4); EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value), xp::count(ops.dist_vec, value)); } @@ -34,14 +47,22 @@ TYPED_TEST(Count, BasicFirstElemIf) { auto value = *vec.begin(); auto pred = [=](auto &&v) { return v == value; }; + EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 4); EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), xp::count_if(ops.dist_vec, pred)); } TYPED_TEST(Count, FirstElemsIf) { - Ops1 ops(20); + std::vector vec(20); + std::iota(vec.begin(), vec.end(), 0); + + Ops1 ops(vec.size()); + ops.vec = vec; + xp::copy(ops.vec, ops.dist_vec.begin()); + auto pred = [=](auto &&v) { return v < 5; }; + EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 5); EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), xp::count_if(ops.dist_vec, pred)); } From e98de3bc478b902909ed07e05e1631c24d030eb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Fri, 16 Aug 2024 13:14:33 +0200 Subject: [PATCH 6/7] removed redundant conditional --- include/dr/mp/algorithms/count.hpp | 50 ++++++++++++------------------ 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/include/dr/mp/algorithms/count.hpp b/include/dr/mp/algorithms/count.hpp index a66c7196a7..78e33557db 100644 --- a/include/dr/mp/algorithms/count.hpp +++ b/include/dr/mp/algorithms/count.hpp @@ -39,39 +39,29 @@ auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { return count_type{}; } - if (aligned(dr)) { - dr::drlog.debug("Parallel count\n"); - - // Count within the local segments - auto count = [=](auto &&r) { - assert(rng::size(r) > 0); - return count_if_local(r, pred); - }; - auto locals = rng::views::transform(local_segments(dr), count); - auto local = add_counts(locals); - - std::vector all(comm.size()); - if (root_provided) { - // Everyone gathers to root, only root adds up the counts - comm.gather(local, std::span{all}, root); - if (root == comm.rank()) { - return add_counts(all); - } else { - return count_type{}; - } - } else { - // Everyone gathers and everyone adds up the counts - comm.all_gather(local, all); + dr::drlog.debug("Parallel count\n"); + + // Count within the local segments + auto count = [=](auto &&r) { + assert(rng::size(r) > 0); + return count_if_local(r, pred); + }; + auto locals = rng::views::transform(local_segments(dr), count); + auto local = add_counts(locals); + + std::vector all(comm.size()); + if (root_provided) { + // Everyone gathers to root, only root adds up the counts + comm.gather(local, std::span{all}, root); + if (root == comm.rank()) { return add_counts(all); + } else { + return count_type{}; } } else { - dr::drlog.debug("Serial count\n"); - count_type result{}; - if (!root_provided || root == comm.rank()) { - result = add_counts(dr); - } - mp::barrier(); - return result; + // Everyone gathers and everyone adds up the counts + comm.all_gather(local, all); + return add_counts(all); } } From f31b80ce104d992508db41430e5fdcfc7f0713ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20G=C5=82=C4=99bocki?= Date: Sun, 25 Aug 2024 17:07:16 +0200 Subject: [PATCH 7/7] fixes according to pre-commit checks --- include/dr/mp/algorithms/count.hpp | 6 +++--- test/gtest/common/count.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/dr/mp/algorithms/count.hpp b/include/dr/mp/algorithms/count.hpp index 78e33557db..3616cded70 100644 --- a/include/dr/mp/algorithms/count.hpp +++ b/include/dr/mp/algorithms/count.hpp @@ -71,13 +71,13 @@ namespace dr::mp { class count_fn_ { public: - template + template auto operator()(std::size_t root, DR &&dr, const T &value) const { auto pred = [=](auto &&v) { return v == value; }; return __detail::count_if(root, true, dr, pred); } - template + template auto operator()(DR &&dr, const T &value) const { auto pred = [=](auto &&v) { return v == value; }; return __detail::count_if(0, false, dr, pred); @@ -105,7 +105,7 @@ class count_if_fn_ { return __detail::count_if(root, true, dr, pred); } - template + template auto operator()(DR &&dr, auto &&pred) const { return __detail::count_if(0, false, dr, pred); } diff --git a/test/gtest/common/count.cpp b/test/gtest/common/count.cpp index 0fcb3b2d92..f6442055fd 100644 --- a/test/gtest/common/count.cpp +++ b/test/gtest/common/count.cpp @@ -24,7 +24,7 @@ TYPED_TEST(Count, EmptyIf) { } TYPED_TEST(Count, BasicFirstElem) { - std::vector vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 }; + std::vector vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7}; Ops1 ops(vec.size()); ops.vec = vec; @@ -38,7 +38,7 @@ TYPED_TEST(Count, BasicFirstElem) { } TYPED_TEST(Count, BasicFirstElemIf) { - std::vector vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 }; + std::vector vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7}; Ops1 ops(vec.size()); ops.vec = vec;