Skip to content

Commit 02bbffb

Browse files
quazuoFilip Głębocki
andauthored
added count to mp algorithms (#812)
* added count to mp algorithms --------- Co-authored-by: Filip Głębocki <[email protected]>
1 parent 104e374 commit 02bbffb

File tree

4 files changed

+197
-1
lines changed

4 files changed

+197
-1
lines changed

include/dr/mp.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
#include <dr/mp/views/mdspan_view.hpp>
6666
#include <dr/mp/views/submdspan_view.hpp>
6767
#include <dr/mp/algorithms/copy.hpp>
68+
#include <dr/mp/algorithms/count.hpp>
6869
#include <dr/mp/algorithms/equal.hpp>
6970
#include <dr/mp/algorithms/fill.hpp>
7071
#include <dr/mp/algorithms/for_each.hpp>

include/dr/mp/algorithms/count.hpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// SPDX-FileCopyrightText: Intel Corporation
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#pragma once
6+
7+
namespace dr::mp::__detail {
8+
9+
inline auto add_counts(rng::forward_range auto &&r) {
10+
rng::range_difference_t<decltype(r)> zero{};
11+
12+
return std::accumulate(rng::begin(r), rng::end(r), zero);
13+
}
14+
15+
inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) {
16+
if (mp::use_sycl()) {
17+
dr::drlog.debug(" with DPL\n");
18+
#ifdef SYCL_LANGUAGE_VERSION
19+
return std::count_if(mp::dpl_policy(),
20+
dr::__detail::direct_iterator(rng::begin(r)),
21+
dr::__detail::direct_iterator(rng::end(r)), pred);
22+
#else
23+
assert(false);
24+
#endif
25+
} else {
26+
dr::drlog.debug(" with CPU\n");
27+
return std::count_if(std::execution::par_unseq,
28+
dr::__detail::direct_iterator(rng::begin(r)),
29+
dr::__detail::direct_iterator(rng::end(r)), pred);
30+
}
31+
}
32+
33+
template <dr::distributed_range DR>
34+
auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) {
35+
using count_type = rng::range_difference_t<decltype(dr)>;
36+
auto comm = mp::default_comm();
37+
38+
if (rng::empty(dr)) {
39+
return count_type{};
40+
}
41+
42+
dr::drlog.debug("Parallel count\n");
43+
44+
// Count within the local segments
45+
auto count = [=](auto &&r) {
46+
assert(rng::size(r) > 0);
47+
return count_if_local(r, pred);
48+
};
49+
auto locals = rng::views::transform(local_segments(dr), count);
50+
auto local = add_counts(locals);
51+
52+
std::vector<count_type> all(comm.size());
53+
if (root_provided) {
54+
// Everyone gathers to root, only root adds up the counts
55+
comm.gather(local, std::span{all}, root);
56+
if (root == comm.rank()) {
57+
return add_counts(all);
58+
} else {
59+
return count_type{};
60+
}
61+
} else {
62+
// Everyone gathers and everyone adds up the counts
63+
comm.all_gather(local, all);
64+
return add_counts(all);
65+
}
66+
}
67+
68+
} // namespace dr::mp::__detail
69+
70+
namespace dr::mp {
71+
72+
class count_fn_ {
73+
public:
74+
template <typename T, dr::distributed_range DR>
75+
auto operator()(std::size_t root, DR &&dr, const T &value) const {
76+
auto pred = [=](auto &&v) { return v == value; };
77+
return __detail::count_if(root, true, dr, pred);
78+
}
79+
80+
template <typename T, dr::distributed_range DR>
81+
auto operator()(DR &&dr, const T &value) const {
82+
auto pred = [=](auto &&v) { return v == value; };
83+
return __detail::count_if(0, false, dr, pred);
84+
}
85+
86+
template <typename T, dr::distributed_iterator DI>
87+
auto operator()(std::size_t root, DI first, DI last, const T &value) const {
88+
auto pred = [=](auto &&v) { return v == value; };
89+
return __detail::count_if(root, true, rng::subrange(first, last), pred);
90+
}
91+
92+
template <typename T, dr::distributed_iterator DI>
93+
auto operator()(DI first, DI last, const T &value) const {
94+
auto pred = [=](auto &&v) { return v == value; };
95+
return __detail::count_if(0, false, rng::subrange(first, last), pred);
96+
}
97+
};
98+
99+
inline constexpr count_fn_ count;
100+
101+
class count_if_fn_ {
102+
public:
103+
template <dr::distributed_range DR>
104+
auto operator()(std::size_t root, DR &&dr, auto &&pred) const {
105+
return __detail::count_if(root, true, dr, pred);
106+
}
107+
108+
template <dr::distributed_range DR>
109+
auto operator()(DR &&dr, auto &&pred) const {
110+
return __detail::count_if(0, false, dr, pred);
111+
}
112+
113+
template <dr::distributed_iterator DI>
114+
auto operator()(std::size_t root, DI first, DI last, auto &&pred) const {
115+
return __detail::count_if(root, true, rng::subrange(first, last), pred);
116+
}
117+
118+
template <dr::distributed_iterator DI>
119+
auto operator()(DI first, DI last, auto &&pred) const {
120+
return __detail::count_if(0, false, rng::subrange(first, last), pred);
121+
}
122+
};
123+
124+
inline constexpr count_if_fn_ count_if;
125+
126+
}; // namespace dr::mp

test/gtest/common/count.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// SPDX-FileCopyrightText: Intel Corporation
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include "xp-tests.hpp"
6+
7+
// Fixture
8+
template <typename T> class Count : public testing::Test {
9+
protected:
10+
};
11+
12+
TYPED_TEST_SUITE(Count, AllTypes);
13+
14+
TYPED_TEST(Count, EmptyIf) {
15+
std::vector<int> vec;
16+
17+
Ops1<TypeParam> ops(0);
18+
19+
auto pred = [=](auto &&v) { return true; };
20+
21+
EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 0);
22+
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
23+
xp::count_if(ops.dist_vec, pred));
24+
}
25+
26+
TYPED_TEST(Count, BasicFirstElem) {
27+
std::vector<int> vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7};
28+
29+
Ops1<TypeParam> ops(vec.size());
30+
ops.vec = vec;
31+
xp::copy(ops.vec, ops.dist_vec.begin());
32+
33+
auto value = *ops.vec.begin();
34+
35+
EXPECT_EQ(xp::count(ops.dist_vec, value), 4);
36+
EXPECT_EQ(std::count(ops.vec.begin(), ops.vec.end(), value),
37+
xp::count(ops.dist_vec, value));
38+
}
39+
40+
TYPED_TEST(Count, BasicFirstElemIf) {
41+
std::vector<int> vec{1, 2, 3, 1, 1, 3, 4, 1, 5, 6, 7};
42+
43+
Ops1<TypeParam> ops(vec.size());
44+
ops.vec = vec;
45+
xp::copy(ops.vec, ops.dist_vec.begin());
46+
47+
auto value = *vec.begin();
48+
auto pred = [=](auto &&v) { return v == value; };
49+
50+
EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 4);
51+
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
52+
xp::count_if(ops.dist_vec, pred));
53+
}
54+
55+
TYPED_TEST(Count, FirstElemsIf) {
56+
std::vector<int> vec(20);
57+
std::iota(vec.begin(), vec.end(), 0);
58+
59+
Ops1<TypeParam> ops(vec.size());
60+
ops.vec = vec;
61+
xp::copy(ops.vec, ops.dist_vec.begin());
62+
63+
auto pred = [=](auto &&v) { return v < 5; };
64+
65+
EXPECT_EQ(xp::count_if(ops.dist_vec, pred), 5);
66+
EXPECT_EQ(std::count_if(ops.vec.begin(), ops.vec.end(), pred),
67+
xp::count_if(ops.dist_vec, pred));
68+
}

test/gtest/mp/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_executable(
1111
mp-tests.cpp
1212
../common/all.cpp
1313
../common/copy.cpp
14+
../common/count.cpp
1415
../common/counted.cpp
1516
../common/distributed_vector.cpp
1617
../common/drop.cpp
@@ -57,7 +58,7 @@ add_executable(
5758

5859
add_executable(mp-quick-test
5960
mp-tests.cpp
60-
../common/equal.cpp
61+
../common/count.cpp
6162
)
6263
# cmake-format: on
6364

0 commit comments

Comments
 (0)