Skip to content

Commit 9126092

Browse files
committed
Communication mode rework
1 parent 0b2423d commit 9126092

File tree

10 files changed

+123
-94
lines changed

10 files changed

+123
-94
lines changed

perf_tests/test_2dhalo.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,12 @@ void benchmark_2dhalo(benchmark::State &state) {
8282
const int ry = rank / rs;
8383

8484
if (rank < rs * rs) {
85-
auto mode = KokkosComm::DefaultCommMode();
8685
auto space = Kokkos::DefaultExecutionSpace();
8786
// grid of elements, each with 3 properties, and a radius-1 halo
8887
grid_type grid("", nx + 2, ny + 2, nprops);
8988
while (state.KeepRunning()) {
90-
do_iteration(state, MPI_COMM_WORLD,
91-
send_recv<KokkosComm::DefaultCommMode, Kokkos::DefaultExecutionSpace, grid_type>, mode, space, nx,
92-
ny, rx, ry, rs, grid);
89+
do_iteration(state, MPI_COMM_WORLD, send_recv<Kokkos::DefaultExecutionSpace, grid_type>, space, nx, ny, rx, ry,
90+
rs, grid);
9391
}
9492
} else {
9593
while (state.KeepRunning()) {

perf_tests/test_osu_latency.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ void benchmark_osu_latency_Kokkos_Comm_mpi_sendrecv(benchmark::State &state) {
6363
state.SkipWithError("benchmark_osu_latency_KokkosComm needs exactly 2 ranks");
6464
}
6565

66-
auto mode = KokkosComm::DefaultCommMode();
6766
auto space = Kokkos::DefaultExecutionSpace();
6867
using view_type = Kokkos::View<char *>;
6968
view_type a("A", state.range(0));

perf_tests/test_sendrecv.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818

1919
#include "KokkosComm.hpp"
2020

21-
template <KokkosComm::CommunicationMode Mode, typename Space, typename View>
22-
void send_recv(benchmark::State &, MPI_Comm comm, const Mode &mode, const Space &space, int rank, const View &v) {
21+
template <KokkosComm::mpi::CommunicationMode Mode, typename Space, typename View>
22+
void send_recv(benchmark::State &, MPI_Comm comm, const Space &space, int rank, const View &v) {
2323
if (0 == rank) {
24-
KokkosComm::mpi::send(space, v, 1, 0, comm);
24+
KokkosComm::mpi::send(space, v, 1, 0, comm, Mode{});
2525
KokkosComm::mpi::recv(space, v, 1, 0, comm);
2626
} else if (1 == rank) {
2727
KokkosComm::mpi::recv(space, v, 0, 0, comm);
28-
KokkosComm::mpi::send(space, v, 0, 0, comm);
28+
KokkosComm::mpi::send(space, v, 0, 0, comm, Mode{});
2929
}
3030
}
3131

@@ -39,15 +39,13 @@ void benchmark_sendrecv(benchmark::State &state) {
3939

4040
using Scalar = double;
4141

42-
auto mode = KokkosComm::DefaultCommMode();
42+
using Mode = KokkosComm::mpi::DefaultCommMode;
4343
auto space = Kokkos::DefaultExecutionSpace();
4444
using view_type = Kokkos::View<Scalar *>;
4545
view_type a("", 1000000);
4646

4747
while (state.KeepRunning()) {
48-
do_iteration(state, MPI_COMM_WORLD,
49-
send_recv<KokkosComm::DefaultCommMode, Kokkos::DefaultExecutionSpace, view_type>, mode, space, rank,
50-
a);
48+
do_iteration(state, MPI_COMM_WORLD, send_recv<Mode, Kokkos::DefaultExecutionSpace, view_type>, space, rank, a);
5149
}
5250

5351
state.SetBytesProcessed(sizeof(Scalar) * state.iterations() * a.size() * 2);

src/KokkosComm_collective.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626
namespace KokkosComm {
2727

28-
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, CommunicationSpace CommSpace = DefaultCommunicationSpace>
28+
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
29+
CommunicationSpace CommSpace = DefaultCommunicationSpace>
2930
void barrier(Handle<ExecSpace, CommSpace> &&h) {
3031
Impl::Barrier<ExecSpace, CommSpace>{std::forward<Handle<ExecSpace, CommSpace>>(h)};
3132
}

src/KokkosComm_fwd.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ using FallbackCommunicationSpace = Mpi;
3333
template <CommunicationSpace CommSpace = DefaultCommunicationSpace>
3434
class Req;
3535

36-
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, CommunicationSpace CommSpace = DefaultCommunicationSpace>
36+
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
37+
CommunicationSpace CommSpace = DefaultCommunicationSpace>
3738
class Handle;
3839

3940
namespace Impl {
@@ -44,7 +45,8 @@ struct Recv;
4445
template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
4546
CommunicationSpace CommSpace = DefaultCommunicationSpace>
4647
struct Send;
47-
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, CommunicationSpace CommSpace = DefaultCommunicationSpace>
48+
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
49+
CommunicationSpace CommSpace = DefaultCommunicationSpace>
4850
struct Barrier;
4951

5052
} // namespace Impl

src/mpi/KokkosComm_mpi_commmode.hpp

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,52 @@
1616

1717
#pragma once
1818

19-
namespace KokkosComm::mpi {
20-
// Scoped enumeration to specify the communication mode of a sending operation.
19+
#include <type_traits>
20+
2121
// See section 3.4 of the MPI standard for a complete specification.
22-
enum class CommMode {
23-
// Default mode: lets the user override the send operations behavior at
24-
// compile-time. E.g., this can be set to mode "Synchronous" for debug
25-
// builds by defining KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE.
26-
Default,
27-
// Standard mode: MPI implementation decides whether outgoing messages will
28-
// be buffered. Send operations can be started whether or not a matching
29-
// receive has been started. They may complete before a matching receive is
30-
// started. Standard mode is non-local: successful completion of the send
31-
// operation may depend on the occurrence of a matching receive.
32-
Standard,
33-
// Ready mode: Send operations may be started only if the matching receive is
34-
// already started.
35-
Ready,
36-
// Synchronous mode: Send operations complete successfully only if a matching
37-
// receive is started, and the receive operation has started to receive the
38-
// message sent.
39-
Synchronous,
40-
};
22+
23+
namespace KokkosComm::mpi {
24+
// Standard mode: MPI implementation decides whether outgoing messages will
25+
// be buffered. Send operations can be started whether or not a matching
26+
// receive has been started. They may complete before a matching receive is
27+
// started. Standard mode is non-local: successful completion of the send
28+
// operation may depend on the occurrence of a matching receive.
29+
struct CommModeStandard {};
30+
31+
// Ready mode: Send operations may be started only if the matching receive is
32+
// already started.
33+
struct CommModeReady {};
34+
35+
// Synchronous mode: Send operations complete successfully only if a matching
36+
// receive is started, and the receive operation has started to receive the
37+
// message sent.
38+
struct CommModeSynchronous {};
39+
40+
// Default mode: lets the user override the send operations behavior at
41+
// compile-time. E.g., this can be set to mode "Synchronous" for debug
42+
// builds by defining KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE.
43+
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
44+
using DefaultCommMode = CommModeSynchronous;
45+
#else
46+
using DefaultCommMode = CommModeStandard;
47+
#endif
48+
49+
template <typename T>
50+
struct is_communication_mode : std::false_type {};
51+
52+
template <>
53+
struct is_communication_mode<CommModeStandard> : std::true_type {};
54+
55+
template <>
56+
struct is_communication_mode<CommModeSynchronous> : std::true_type {};
57+
58+
template <>
59+
struct is_communication_mode<CommModeReady> : std::true_type {};
60+
61+
template <typename T>
62+
inline constexpr bool is_communication_mode_v = is_communication_mode<T>::value;
63+
64+
template <typename T>
65+
concept CommunicationMode = is_communication_mode_v<T>;
66+
4167
} // namespace KokkosComm::mpi

src/mpi/KokkosComm_mpi_isend.hpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,18 @@ namespace KokkosComm {
2626

2727
namespace Impl {
2828

29-
template <mpi::CommMode SendMode, KokkosExecutionSpace ExecSpace, KokkosView SendView>
30-
Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag) {
29+
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, mpi::CommunicationMode SendMode>
30+
Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag, SendMode) {
3131
auto mpi_isend_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag,
3232
MPI_Comm mpi_comm, MPI_Request *mpi_req) {
33-
if constexpr (SendMode == mpi::CommMode::Standard) {
33+
if constexpr (std::is_same_v<SendMode, mpi::CommModeStandard>) {
3434
MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req);
35-
} else if constexpr (SendMode == mpi::CommMode::Ready) {
35+
} else if constexpr (std::is_same_v<SendMode, mpi::CommModeReady>) {
3636
MPI_Irsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req);
37-
} else if constexpr (SendMode == mpi::CommMode::Synchronous) {
37+
} else if constexpr (std::is_same_v<SendMode, mpi::CommModeSynchronous>) {
3838
MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req);
39-
} else if constexpr (SendMode == mpi::CommMode::Default) {
40-
#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE
41-
MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req);
42-
#else
43-
MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req);
44-
#endif
39+
} else {
40+
static_assert(std::is_void_v<SendMode>, "unexpected communication mode");
4541
}
4642
};
4743

@@ -67,17 +63,22 @@ Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int
6763
template <KokkosExecutionSpace ExecSpace, KokkosView SendView>
6864
struct Send<SendView, ExecSpace, Mpi> {
6965
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag) {
70-
return isend_impl<mpi::CommMode::Standard, ExecSpace, SendView>(h, sv, dest, tag);
66+
return isend_impl<ExecSpace, SendView>(h, sv, dest, tag, mpi::DefaultCommMode{});
7167
}
7268
};
7369

7470
} // namespace Impl
7571

7672
namespace mpi {
7773

78-
template <CommMode SendMode, KokkosExecutionSpace ExecSpace, KokkosView SendView>
74+
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, CommunicationMode SendMode>
75+
Req<Mpi> isend(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag, SendMode) {
76+
return KokkosComm::Impl::isend_impl<ExecSpace, SendView>(h, sv, dest, tag, SendMode{});
77+
}
78+
79+
template <KokkosExecutionSpace ExecSpace, KokkosView SendView>
7980
Req<Mpi> isend(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag) {
80-
return KokkosComm::Impl::isend_impl<SendMode, ExecSpace, SendView>(h, sv, dest, tag);
81+
return isend<ExecSpace, SendView>(h, sv, dest, tag, DefaultCommMode{});
8182
}
8283

8384
template <KokkosView SendView>

src/mpi/KokkosComm_mpi_send.hpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,21 @@
2424

2525
namespace KokkosComm::mpi {
2626

27-
template <CommunicationMode SendMode, KokkosView SendView>
28-
void send(const SendMode &, const SendView &sv, int dest, int tag, MPI_Comm comm) {
27+
template <KokkosView SendView, CommunicationMode SendMode>
28+
void send(const SendView &sv, int dest, int tag, MPI_Comm comm, SendMode) {
2929
Kokkos::Tools::pushRegion("KokkosComm::Impl::send");
3030
using KCT = typename KokkosComm::Traits<SendView>;
3131

3232
auto mpi_send_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag,
3333
MPI_Comm mpi_comm) {
34-
if constexpr (std::is_same_v<SendMode, StandardCommMode>) {
34+
if constexpr (std::is_same_v<SendMode, CommModeStandard>) {
3535
MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
36-
} else if constexpr (std::is_same_v<SendMode, ReadyCommMode>) {
36+
} else if constexpr (std::is_same_v<SendMode, CommModeReady>) {
3737
MPI_Rsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
38-
} else if constexpr (std::is_same_v<SendMode, SynchronousCommMode>) {
38+
} else if constexpr (std::is_same_v<SendMode, CommModeSynchronous>) {
3939
MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
40+
} else {
41+
static_assert(std::is_void_v<SendMode>, "unexpected communication mode");
4042
}
4143
};
4244

@@ -50,25 +52,22 @@ void send(const SendMode &, const SendView &sv, int dest, int tag, MPI_Comm comm
5052
Kokkos::Tools::popRegion();
5153
}
5254

53-
template <KokkosView SendView>
54-
void send(const SendView &sv, int dest, int tag, MPI_Comm comm) {
55-
send(KokkosComm::DefaultCommMode(), sv, dest, tag, comm);
56-
}
57-
58-
template <CommunicationMode SendMode, KokkosExecutionSpace ExecSpace, KokkosView SendView>
59-
void send(const SendMode &, const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) {
55+
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, CommunicationMode SendMode>
56+
void send(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm, SendMode) {
6057
Kokkos::Tools::pushRegion("KokkosComm::Impl::send");
6158

6259
using Packer = typename KokkosComm::PackTraits<SendView>::packer_type;
6360

6461
auto mpi_send_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag,
6562
MPI_Comm mpi_comm) {
66-
if constexpr (std::is_same_v<SendMode, StandardCommMode>) {
63+
if constexpr (std::is_same_v<SendMode, CommModeStandard>) {
6764
MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
68-
} else if constexpr (std::is_same_v<SendMode, ReadyCommMode>) {
65+
} else if constexpr (std::is_same_v<SendMode, CommModeReady>) {
6966
MPI_Rsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
70-
} else if constexpr (std::is_same_v<SendMode, SynchronousCommMode>) {
67+
} else if constexpr (std::is_same_v<SendMode, CommModeSynchronous>) {
7168
MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm);
69+
} else {
70+
static_assert(std::is_void_v<SendMode>, "unexpected communication mode");
7271
}
7372
};
7473

@@ -84,4 +83,9 @@ void send(const SendMode &, const ExecSpace &space, const SendView &sv, int dest
8483
Kokkos::Tools::popRegion();
8584
}
8685

86+
template <KokkosExecutionSpace ExecSpace, KokkosView SendView>
87+
void send(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) {
88+
send(space, sv, dest, tag, comm, DefaultCommMode{});
89+
}
90+
8791
} // namespace KokkosComm::mpi

unit_tests/mpi/test_isendrecv.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
namespace {
2323

24+
using namespace KokkosComm::mpi;
25+
2426
template <typename T>
2527
class IsendRecv : public testing::Test {
2628
public:
@@ -31,9 +33,9 @@ using ScalarTypes =
3133
::testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
3234
TYPED_TEST_SUITE(IsendRecv, ScalarTypes);
3335

34-
template <KokkosComm::mpi::CommMode IsendMode, typename Scalar>
36+
template <CommunicationMode IsendMode, typename Scalar>
3537
void isend_comm_mode_1d_contig() {
36-
if (IsendMode == KokkosComm::mpi::CommMode::Ready) {
38+
if constexpr (std::is_same_v<IsendMode, CommModeReady>) {
3739
GTEST_SKIP() << "Skipping test for ready-mode send";
3840
}
3941

@@ -48,7 +50,7 @@ void isend_comm_mode_1d_contig() {
4850
int dst = 1;
4951
Kokkos::parallel_for(
5052
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
51-
KokkosComm::Req req = KokkosComm::mpi::isend<IsendMode>(h, a, dst, 0);
53+
KokkosComm::Req req = KokkosComm::mpi::isend(h, a, dst, 0, IsendMode{});
5254
KokkosComm::wait(req);
5355
} else if (1 == h.rank()) {
5456
int src = 0;
@@ -60,9 +62,9 @@ void isend_comm_mode_1d_contig() {
6062
}
6163
}
6264

63-
template <KokkosComm::mpi::CommMode IsendMode, typename Scalar>
65+
template <CommunicationMode IsendMode, typename Scalar>
6466
void isend_comm_mode_1d_noncontig() {
65-
if (IsendMode == KokkosComm::mpi::CommMode::Ready) {
67+
if constexpr (std::is_same_v<IsendMode, CommModeReady>) {
6668
GTEST_SKIP() << "Skipping test for ready-mode send";
6769
}
6870

@@ -79,7 +81,7 @@ void isend_comm_mode_1d_noncontig() {
7981
int dst = 1;
8082
Kokkos::parallel_for(
8183
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
82-
KokkosComm::Req req = KokkosComm::mpi::isend<IsendMode>(h, a, dst, 0);
84+
KokkosComm::Req req = KokkosComm::mpi::isend(h, a, dst, 0, IsendMode{});
8385
KokkosComm::wait(req);
8486
} else if (1 == h.rank()) {
8587
int src = 0;
@@ -92,27 +94,25 @@ void isend_comm_mode_1d_noncontig() {
9294
}
9395

9496
TYPED_TEST(IsendRecv, 1D_contig_standard) {
95-
isend_comm_mode_1d_contig<KokkosComm::mpi::CommMode::Standard, typename TestFixture::Scalar>();
97+
isend_comm_mode_1d_contig<CommModeStandard, typename TestFixture::Scalar>();
9698
}
9799

98-
TYPED_TEST(IsendRecv, 1D_contig_ready) {
99-
isend_comm_mode_1d_contig<KokkosComm::mpi::CommMode::Ready, typename TestFixture::Scalar>();
100-
}
100+
TYPED_TEST(IsendRecv, 1D_contig_ready) { isend_comm_mode_1d_contig<CommModeReady, typename TestFixture::Scalar>(); }
101101

102102
TYPED_TEST(IsendRecv, 1D_contig_synchronous) {
103-
isend_comm_mode_1d_contig<KokkosComm::mpi::CommMode::Synchronous, typename TestFixture::Scalar>();
103+
isend_comm_mode_1d_contig<CommModeSynchronous, typename TestFixture::Scalar>();
104104
}
105105

106106
TYPED_TEST(IsendRecv, 1D_noncontig_standard) {
107-
isend_comm_mode_1d_noncontig<KokkosComm::mpi::CommMode::Standard, typename TestFixture::Scalar>();
107+
isend_comm_mode_1d_noncontig<CommModeStandard, typename TestFixture::Scalar>();
108108
}
109109

110110
TYPED_TEST(IsendRecv, 1D_noncontig_ready) {
111-
isend_comm_mode_1d_noncontig<KokkosComm::mpi::CommMode::Ready, typename TestFixture::Scalar>();
111+
isend_comm_mode_1d_noncontig<CommModeReady, typename TestFixture::Scalar>();
112112
}
113113

114114
TYPED_TEST(IsendRecv, 1D_noncontig_synchronous) {
115-
isend_comm_mode_1d_noncontig<KokkosComm::mpi::CommMode::Synchronous, typename TestFixture::Scalar>();
115+
isend_comm_mode_1d_noncontig<CommModeSynchronous, typename TestFixture::Scalar>();
116116
}
117117

118118
} // namespace

0 commit comments

Comments
 (0)