diff --git a/include/dr/mp.hpp b/include/dr/mp.hpp index 5e9729b16a..f9598bbcd8 100644 --- a/include/dr/mp.hpp +++ b/include/dr/mp.hpp @@ -65,6 +65,7 @@ #include #include #include +#include #include #include #include diff --git a/include/dr/mp/algorithms/count.hpp b/include/dr/mp/algorithms/count.hpp new file mode 100644 index 0000000000..3616cded70 --- /dev/null +++ b/include/dr/mp/algorithms/count.hpp @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: Intel Corporation +// +// SPDX-License-Identifier: BSD-3-Clause + +#pragma once + +namespace dr::mp::__detail { + +inline auto add_counts(rng::forward_range auto &&r) { + rng::range_difference_t zero{}; + + return std::accumulate(rng::begin(r), rng::end(r), zero); +} + +inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) { + if (mp::use_sycl()) { + dr::drlog.debug(" with DPL\n"); +#ifdef SYCL_LANGUAGE_VERSION + return std::count_if(mp::dpl_policy(), + dr::__detail::direct_iterator(rng::begin(r)), + dr::__detail::direct_iterator(rng::end(r)), pred); +#else + assert(false); +#endif + } else { + dr::drlog.debug(" with CPU\n"); + return std::count_if(std::execution::par_unseq, + dr::__detail::direct_iterator(rng::begin(r)), + dr::__detail::direct_iterator(rng::end(r)), pred); + } +} + +template +auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) { + using count_type = rng::range_difference_t; + auto comm = mp::default_comm(); + + if (rng::empty(dr)) { + return count_type{}; + } + + dr::drlog.debug("Parallel count\n"); + + // Count within the local segments + auto count = [=](auto &&r) { + assert(rng::size(r) > 0); + return count_if_local(r, pred); + }; + auto locals = rng::views::transform(local_segments(dr), count); + auto local = add_counts(locals); + + std::vector all(comm.size()); + if (root_provided) { + // Everyone gathers to root, only root adds up the counts + comm.gather(local, std::span{all}, root); + if (root == comm.rank()) { + return add_counts(all); + } else { + return count_type{}; + } + } else { + // Everyone gathers and everyone adds up the counts + comm.all_gather(local, all); + return add_counts(all); + } +} + +} // namespace dr::mp::__detail + +namespace dr::mp { + +class count_fn_ { +public: + template + auto operator()(std::size_t root, DR &&dr, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, dr, pred); + } + + template + auto operator()(DR &&dr, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, dr, pred); + } + + template + auto operator()(std::size_t root, DI first, DI last, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(root, true, rng::subrange(first, last), pred); + } + + template + auto operator()(DI first, DI last, const T &value) const { + auto pred = [=](auto &&v) { return v == value; }; + return __detail::count_if(0, false, rng::subrange(first, last), pred); + } +}; + +inline constexpr count_fn_ count; + +class count_if_fn_ { +public: + template + auto operator()(std::size_t root, DR &&dr, auto &&pred) const { + return __detail::count_if(root, true, dr, pred); + } + + template + auto operator()(DR &&dr, auto &&pred) const { + return __detail::count_if(0, false, dr, pred); + } + + template + auto operator()(std::size_t root, DI first, DI last, auto &&pred) const { + return __detail::count_if(root, true, rng::subrange(first, last), pred); + } + + template + auto operator()(DI first, DI last, auto &&pred) const { + return __detail::count_if(0, false, rng::subrange(first, last), pred); + } +}; + +inline constexpr count_if_fn_ count_if; + +}; // namespace dr::mp diff --git a/test/gtest/common/count.cpp b/test/gtest/common/count.cpp new file mode 100644 index 0000000000..f6442055fd --- /dev/null +++ b/test/gtest/common/count.cpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: Intel Corporation +// +// SPDX-License-Identifier: BSD-3-Clause + +#include "xp-tests.hpp" + +// Fixture +template class Count : public testing::Test { +protected: +}; + +TYPED_TEST_SUITE(Count, AllTypes); + +TYPED_TEST(Count, EmptyIf) { + std::vector vec; + + Ops1 ops(0); + + auto pred = [=](auto &&v) { return true; }; + + EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 0); + EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), + xp::count_if(ops.dist_vec, pred)); +} + +TYPED_TEST(Count, BasicFirstElem) { + std::vector vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7}; + + Ops1 ops(vec.size()); + ops.vec = vec; + xp::copy(ops.vec, ops.dist_vec.begin()); + + auto value = *ops.vec.begin(); + + EXPECT_EQ(xp::count(ops.dist_vec, value), 4); + EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value), + xp::count(ops.dist_vec, value)); +} + +TYPED_TEST(Count, BasicFirstElemIf) { + std::vector vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7}; + + Ops1 ops(vec.size()); + ops.vec = vec; + xp::copy(ops.vec, ops.dist_vec.begin()); + + auto value = *vec.begin(); + auto pred = [=](auto &&v) { return v == value; }; + + EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 4); + EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), + xp::count_if(ops.dist_vec, pred)); +} + +TYPED_TEST(Count, FirstElemsIf) { + std::vector vec(20); + std::iota(vec.begin(), vec.end(), 0); + + Ops1 ops(vec.size()); + ops.vec = vec; + xp::copy(ops.vec, ops.dist_vec.begin()); + + auto pred = [=](auto &&v) { return v < 5; }; + + EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 5); + EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred), + xp::count_if(ops.dist_vec, pred)); +} diff --git a/test/gtest/mp/CMakeLists.txt b/test/gtest/mp/CMakeLists.txt index 8138b35fb0..32f26d120a 100644 --- a/test/gtest/mp/CMakeLists.txt +++ b/test/gtest/mp/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable( mp-tests.cpp ../common/all.cpp ../common/copy.cpp + ../common/count.cpp ../common/counted.cpp ../common/distributed_vector.cpp ../common/drop.cpp @@ -57,7 +58,7 @@ add_executable( add_executable(mp-quick-test mp-tests.cpp - ../common/equal.cpp + ../common/count.cpp ) # cmake-format: on