Skip to content
This repository was archived by the owner on Sep 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/dr/mp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include <dr/mp/views/mdspan_view.hpp>
#include <dr/mp/views/submdspan_view.hpp>
#include <dr/mp/algorithms/copy.hpp>
#include <dr/mp/algorithms/count.hpp>
#include <dr/mp/algorithms/equal.hpp>
#include <dr/mp/algorithms/fill.hpp>
#include <dr/mp/algorithms/for_each.hpp>
Expand Down
126 changes: 126 additions & 0 deletions include/dr/mp/algorithms/count.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause

#pragma once

namespace dr::mp::__detail {

inline auto add_counts(rng::forward_range auto &&r) {
rng::range_difference_t<decltype(r)> zero{};

return std::accumulate(rng::begin(r), rng::end(r), zero);
}

inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) {
if (mp::use_sycl()) {
dr::drlog.debug(" with DPL\n");
#ifdef SYCL_LANGUAGE_VERSION
return std::count_if(mp::dpl_policy(),
dr::__detail::direct_iterator(rng::begin(r)),
dr::__detail::direct_iterator(rng::end(r)), pred);
#else
assert(false);
#endif
} else {
dr::drlog.debug(" with CPU\n");
return std::count_if(std::execution::par_unseq,
dr::__detail::direct_iterator(rng::begin(r)),
dr::__detail::direct_iterator(rng::end(r)), pred);
}
}

template <dr::distributed_range DR>
auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) {
using count_type = rng::range_difference_t<decltype(dr)>;
auto comm = mp::default_comm();

if (rng::empty(dr)) {
return count_type{};
}

dr::drlog.debug("Parallel count\n");

// Count within the local segments
auto count = [=](auto &&r) {
assert(rng::size(r) > 0);
return count_if_local(r, pred);
};
auto locals = rng::views::transform(local_segments(dr), count);
auto local = add_counts(locals);

std::vector<count_type> all(comm.size());
if (root_provided) {
// Everyone gathers to root, only root adds up the counts
comm.gather(local, std::span{all}, root);
if (root == comm.rank()) {
return add_counts(all);
} else {
return count_type{};
}
} else {
// Everyone gathers and everyone adds up the counts
comm.all_gather(local, all);
return add_counts(all);
}
}

} // namespace dr::mp::__detail

namespace dr::mp {

class count_fn_ {
public:
template <typename T, dr::distributed_range DR>
auto operator()(std::size_t root, DR &&dr, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(root, true, dr, pred);
}

template <typename T, dr::distributed_range DR>
auto operator()(DR &&dr, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(0, false, dr, pred);
}

template <typename T, dr::distributed_iterator DI>
auto operator()(std::size_t root, DI first, DI last, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(root, true, rng::subrange(first, last), pred);
}

template <typename T, dr::distributed_iterator DI>
auto operator()(DI first, DI last, const T &value) const {
auto pred = [=](auto &&v) { return v == value; };
return __detail::count_if(0, false, rng::subrange(first, last), pred);
}
};

inline constexpr count_fn_ count;

class count_if_fn_ {
public:
template <dr::distributed_range DR>
auto operator()(std::size_t root, DR &&dr, auto &&pred) const {
return __detail::count_if(root, true, dr, pred);
}

template <dr::distributed_range DR>
auto operator()(DR &&dr, auto &&pred) const {
return __detail::count_if(0, false, dr, pred);
}

template <dr::distributed_iterator DI>
auto operator()(std::size_t root, DI first, DI last, auto &&pred) const {
return __detail::count_if(root, true, rng::subrange(first, last), pred);
}

template <dr::distributed_iterator DI>
auto operator()(DI first, DI last, auto &&pred) const {
return __detail::count_if(0, false, rng::subrange(first, last), pred);
}
};

inline constexpr count_if_fn_ count_if;

}; // namespace dr::mp
68 changes: 68 additions & 0 deletions test/gtest/common/count.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause

#include "xp-tests.hpp"

// Fixture
template <typename T> class Count : public testing::Test {
protected:
};

TYPED_TEST_SUITE(Count, AllTypes);

TYPED_TEST(Count, EmptyIf) {
std::vector<int> vec;

Ops1<TypeParam> ops(0);

auto pred = [=](auto &&v) { return true; };

EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 0);
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
xp::count_if(ops.dist_vec, pred));
}

TYPED_TEST(Count, BasicFirstElem) {
std::vector<int> vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 };

Ops1<TypeParam> ops(vec.size());
ops.vec = vec;
xp::copy(ops.vec, ops.dist_vec.begin());

auto value = *ops.vec.begin();

EXPECT_EQ(xp::count(ops.dist_vec, value), 4);
EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value),
xp::count(ops.dist_vec, value));
}

TYPED_TEST(Count, BasicFirstElemIf) {
std::vector<int> vec { 1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7 };

Ops1<TypeParam> ops(vec.size());
ops.vec = vec;
xp::copy(ops.vec, ops.dist_vec.begin());

auto value = *vec.begin();
auto pred = [=](auto &&v) { return v == value; };

EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 4);
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
xp::count_if(ops.dist_vec, pred));
}

TYPED_TEST(Count, FirstElemsIf) {
std::vector<int> vec(20);
std::iota(vec.begin(), vec.end(), 0);

Ops1<TypeParam> ops(vec.size());
ops.vec = vec;
xp::copy(ops.vec, ops.dist_vec.begin());

auto pred = [=](auto &&v) { return v < 5; };

EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 5);
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
xp::count_if(ops.dist_vec, pred));
}
3 changes: 2 additions & 1 deletion test/gtest/mp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_executable(
mp-tests.cpp
../common/all.cpp
../common/copy.cpp
../common/count.cpp
../common/counted.cpp
../common/distributed_vector.cpp
../common/drop.cpp
Expand Down Expand Up @@ -57,7 +58,7 @@ add_executable(

add_executable(mp-quick-test
mp-tests.cpp
../common/equal.cpp
../common/count.cpp
)
# cmake-format: on

Expand Down