Skip to content

Commit 4bc3385

Browse files
committed
Stream-ordered wait{_any,_all}, fix wait_any implementation
1 parent b153676 commit 4bc3385

File tree

3 files changed

+150
-10
lines changed

3 files changed

+150
-10
lines changed

src/KokkosComm/mpi/req.hpp

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,34 +64,68 @@ class Req<Mpi> {
6464
private:
6565
std::shared_ptr<Record> record_;
6666

67+
template <KokkosExecutionSpace ExecSpace>
68+
friend void wait(const ExecSpace &space, Req<Mpi> req);
6769
friend void wait(Req<Mpi> req);
70+
template <KokkosExecutionSpace ExecSpace>
71+
friend void wait_all(const ExecSpace &space, std::vector<Req<Mpi>> &reqs);
6872
friend void wait_all(std::vector<Req<Mpi>> &reqs);
73+
template <KokkosExecutionSpace ExecSpace>
74+
friend int wait_any(const ExecSpace &space, std::vector<Req<Mpi>> &reqs);
6975
friend int wait_any(std::vector<Req<Mpi>> &reqs);
7076
};
7177

72-
inline void wait(Req<Mpi> req) {
78+
template <KokkosExecutionSpace ExecSpace>
79+
void wait(const ExecSpace &space, Req<Mpi> req) {
80+
/* Semantically this only guarantees that `space` is waiting for request to complete. For the MPI host API, we have no
81+
* choice but to fence the space before waiting on the host.*/
82+
space.fence();
7383
MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);
7484
for (auto &f : req.record_->postWaits_) {
7585
f();
7686
}
7787
req.record_->postWaits_.clear();
7888
}
7989

80-
inline void wait_all(std::vector<Req<Mpi>> &reqs) {
90+
inline void wait(Req<Mpi> req) { wait(Kokkos::DefaultExecutionSpace(), req); }
91+
92+
template <KokkosExecutionSpace ExecSpace>
93+
void wait_all(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
94+
space.fence();
8195
for (Req<Mpi> &req : reqs) {
82-
wait(req);
96+
MPI_Wait(&req.mpi_request(), MPI_STATUS_IGNORE);
97+
for (auto &f : req.record_->postWaits_) {
98+
f();
99+
}
100+
req.record_->postWaits_.clear();
83101
}
84102
}
85103

86-
inline int wait_any(std::vector<Req<Mpi>> &reqs) {
87-
for (size_t i = 0; i < reqs.size(); ++i) {
88-
int completed;
89-
MPI_Test(&(reqs[i].mpi_request()), &completed, MPI_STATUS_IGNORE);
90-
if (completed) {
91-
return true;
104+
inline void wait_all(std::vector<Req<Mpi>> &reqs) { wait_all(Kokkos::DefaultExecutionSpace(), reqs); }
105+
106+
template <KokkosExecutionSpace ExecSpace>
107+
int wait_any(const ExecSpace &space, std::vector<Req<Mpi>> &reqs) {
108+
if (reqs.empty()) {
109+
return -1;
110+
}
111+
112+
space.fence();
113+
while (true) { // wait until something is done
114+
for (size_t i = 0; i < reqs.size(); ++i) {
115+
int completed;
116+
Req<Mpi> &req = reqs[i];
117+
MPI_Test(&(req.mpi_request()), &completed, MPI_STATUS_IGNORE);
118+
if (completed) {
119+
for (auto &f : req.record_->postWaits_) {
120+
f();
121+
}
122+
req.record_->postWaits_.clear();
123+
return i;
124+
}
92125
}
93126
}
94-
return false;
95127
}
96128

129+
inline int wait_any(std::vector<Req<Mpi>> &reqs) { return wait_any(Kokkos::DefaultExecutionSpace(), reqs); }
130+
97131
} // namespace KokkosComm

unit_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ target_sources(
8686
mpi/test_alltoall.cpp
8787
mpi/test_reduce.cpp
8888
mpi/test_allgather.cpp
89+
mpi/test_waitany.cpp
8990
)
9091
target_link_libraries(
9192
test-main

unit_tests/mpi/test_waitany.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2022) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#include <gtest/gtest.h>
18+
#include <type_traits>
19+
#include <algorithm> // iota
20+
#include <random>
21+
22+
#include "KokkosComm/KokkosComm.hpp"
23+
24+
namespace {
25+
26+
using namespace KokkosComm::mpi;
27+
28+
template <typename T>
29+
class MpiWaitAny : public testing::Test {
30+
public:
31+
using Scalar = T;
32+
};
33+
34+
using ScalarTypes = ::testing::Types<int, double, Kokkos::complex<float>>;
35+
TYPED_TEST_SUITE(MpiWaitAny, ScalarTypes);
36+
37+
template <KokkosComm::KokkosExecutionSpace ExecSpace, typename Scalar>
38+
void wait_any() {
39+
using TestView = Kokkos::View<Scalar *>;
40+
41+
int rank, size;
42+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
43+
MPI_Comm_size(MPI_COMM_WORLD, &size);
44+
if (size < 2) {
45+
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
46+
}
47+
48+
constexpr size_t numMsg = 128;
49+
ExecSpace space;
50+
std::vector<KokkosComm::Req<>> reqs;
51+
std::vector<TestView> views;
52+
53+
for (size_t i = 0; i < numMsg; ++i) {
54+
views.push_back(TestView(std::to_string(i), i));
55+
}
56+
57+
constexpr unsigned int SEED = 31337;
58+
std::random_device rd;
59+
std::mt19937 g(SEED);
60+
61+
// random send/recv order
62+
std::vector<size_t> order(numMsg);
63+
std::iota(order.begin(), order.end(), size_t(0));
64+
std::shuffle(order.begin(), order.end(), g);
65+
66+
KokkosComm::Handle<ExecSpace, KokkosComm::Mpi> h(space, MPI_COMM_WORLD);
67+
68+
if (0 == rank) {
69+
constexpr int dst = 1;
70+
71+
for (size_t i : order) {
72+
reqs.push_back(KokkosComm::send(h, views[i], dst));
73+
}
74+
75+
for (size_t i = 0; i < numMsg; ++i) {
76+
reqs.erase(reqs.begin() + KokkosComm::wait_any(reqs));
77+
}
78+
} else if (1 == rank) {
79+
constexpr int src = 0;
80+
81+
for (size_t i : order) {
82+
reqs.push_back(KokkosComm::recv(h, views[i], src));
83+
}
84+
85+
for (size_t i = 0; i < numMsg; ++i) {
86+
reqs.erase(reqs.begin() + KokkosComm::wait_any(reqs));
87+
}
88+
}
89+
}
90+
91+
// TODO: test call on no requests
92+
93+
TYPED_TEST(MpiWaitAny, default_execution_space) {
94+
wait_any<Kokkos::DefaultExecutionSpace, typename TestFixture::Scalar>();
95+
}
96+
97+
TYPED_TEST(MpiWaitAny, default_host_execution_space) {
98+
if constexpr (std::is_same_v<Kokkos::DefaultHostExecutionSpace, Kokkos::DefaultExecutionSpace>) {
99+
GTEST_SKIP() << "Skipping test: DefaultHostExecSpace = DefaultExecSpace";
100+
} else {
101+
wait_any<Kokkos::DefaultHostExecutionSpace, typename TestFixture::Scalar>();
102+
}
103+
}
104+
105+
} // namespace

0 commit comments

Comments
 (0)