Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/dr/mp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include <dr/mp/views/mdspan_view.hpp>
#include <dr/mp/views/submdspan_view.hpp>
#include <dr/mp/algorithms/copy.hpp>
#include <dr/mp/algorithms/count.hpp>
#include <dr/mp/algorithms/equal.hpp>
#include <dr/mp/algorithms/fill.hpp>
#include <dr/mp/algorithms/for_each.hpp>
Expand Down
126 changes: 126 additions & 0 deletions include/dr/mp/algorithms/count.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause

#pragma once

namespace dr::mp::__detail {

inline auto add_counts(rng::forward_range auto &&r) {
rng::range_difference_t<decltype(r)> zero{};

return std::accumulate(rng::begin(r), rng::end(r), zero);
}

inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) {
if (mp::use_sycl()) {
dr::drlog.debug(" with DPL\n");
#ifdef SYCL_LANGUAGE_VERSION
return std::count_if(mp::dpl_policy(),
dr::__detail::direct_iterator(rng::begin(r)),
dr::__detail::direct_iterator(rng::end(r)), pred);
#else
assert(false);
#endif
} else {
dr::drlog.debug(" with CPU\n");
return std::count_if(std::execution::par_unseq,
dr::__detail::direct_iterator(rng::begin(r)),
dr::__detail::direct_iterator(rng::end(r)), pred);
}
}

template <dr::distributed_range DR>
auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) {
using count_type = rng::range_difference_t<decltype(dr)>;
auto comm = mp::default_comm();

if (rng::empty(dr)) {
return count_type{};
}

dr::drlog.debug("Parallel count\n");

// Count within the local segments
auto count = [=](auto &&r) {
assert(rng::size(r) > 0);
return count_if_local(r, pred);
};
auto locals = rng::views::transform(local_segments(dr), count);
auto local = add_counts(locals);

std::vector<count_type> all(comm.size());
if (root_provided) {
// Everyone gathers to root, only root adds up the counts
comm.gather(local, std::span{all}, root);
if (root == comm.rank()) {
return add_counts(all);
} else {
return count_type{};
}
} else {
// Everyone gathers and everyone adds up the counts
comm.all_gather(local, all);
return add_counts(all);
}
}

} // namespace dr::mp::__detail

namespace dr::mp {

class count_fn_ {
public:
template <typename T, dr::distributed_range DR>
auto operator()(std::size_t root, DR &&dr, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(root, true, dr, pred);
}

template <typename T, dr::distributed_range DR>
auto operator()(DR &&dr, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(0, false, dr, pred);
}

template <typename T, dr::distributed_iterator DI>
auto operator()(std::size_t root, DI first, DI last, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(root, true, rng::subrange(first, last), pred);
}

template <typename T, dr::distributed_iterator DI>
auto operator()(DI first, DI last, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(0, false, rng::subrange(first, last), pred);
}
};

inline constexpr count_fn_ count;

class count_if_fn_ {
public:
template <dr::distributed_range DR>
auto operator()(std::size_t root, DR &&dr, auto &&pred) const {
return __detail::count_if(root, true, dr, pred);
}

template <dr::distributed_range DR>
auto operator()(DR &&dr, auto &&pred) const {
return __detail::count_if(0, false, dr, pred);
}

template <dr::distributed_iterator DI>
auto operator()(std::size_t root, DI first, DI last, auto &&pred) const {
return __detail::count_if(root, true, rng::subrange(first, last), pred);
}

template <dr::distributed_iterator DI>
auto operator()(DI first, DI last, auto &&pred) const {
return __detail::count_if(0, false, rng::subrange(first, last), pred);
}
};

inline constexpr count_if_fn_ count_if;

}; // namespace dr::mp
68 changes: 68 additions & 0 deletions test/gtest/common/count.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause

#include "xp-tests.hpp"

// Fixture
template <typename T> class Count : public testing::Test {
protected:
};

TYPED_TEST_SUITE(Count, AllTypes);

TYPED_TEST(Count, EmptyIf) {
std::vector<int> vec;

Ops1<TypeParam> ops(0);

auto pred = [=](auto &&v) { return true; };

EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 0);
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
xp::count_if(ops.dist_vec, pred));
}

TYPED_TEST(Count, BasicFirstElem) {
std::vector<int> vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7};

Ops1<TypeParam> ops(vec.size());
ops.vec = vec;
xp::copy(ops.vec, ops.dist_vec.begin());

auto value = *ops.vec.begin();

EXPECT_EQ(xp::count(ops.dist_vec, value), 4);
EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value),
xp::count(ops.dist_vec, value));
}

TYPED_TEST(Count, BasicFirstElemIf) {
std::vector<int> vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7};

Ops1<TypeParam> ops(vec.size());
ops.vec = vec;
xp::copy(ops.vec, ops.dist_vec.begin());

auto value = *vec.begin();
auto pred = [=](auto &&v) { return v == value; };

EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 4);
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
xp::count_if(ops.dist_vec, pred));
}

TYPED_TEST(Count, FirstElemsIf) {
std::vector<int> vec(20);
std::iota(vec.begin(), vec.end(), 0);

Ops1<TypeParam> ops(vec.size());
ops.vec = vec;
xp::copy(ops.vec, ops.dist_vec.begin());

auto pred = [=](auto &&v) { return v < 5; };

EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 5);
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
xp::count_if(ops.dist_vec, pred));
}
3 changes: 2 additions & 1 deletion test/gtest/mp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_executable(
mp-tests.cpp
../common/all.cpp
../common/copy.cpp
../common/count.cpp
../common/counted.cpp
../common/distributed_vector.cpp
../common/drop.cpp
Expand Down Expand Up @@ -57,7 +58,7 @@ add_executable(

add_executable(mp-quick-test
mp-tests.cpp
../common/equal.cpp
../common/count.cpp
)
# cmake-format: on

Expand Down