Skip to content

Commit f98e15a

Browse files
committed
tests: refactor all tests to pass with the new Communicators
Signed-off-by: Gabriel Dos Santos <gabriel.dossantos@cea.fr>
1 parent 40d9bd8 commit f98e15a

17 files changed

+341
-269
lines changed

perf_tests/mpi/test_2d_halo.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
22
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
33

4+
#include "KokkosComm/fwd.hpp"
5+
#include "KokkosComm/mpi/mpi_space.hpp"
46
#include "test_utils.hpp"
57

68
#include <iostream>
@@ -10,9 +12,10 @@
1012
void noop(benchmark::State, MPI_Comm) {}
1113

1214
template <typename Space, typename View>
13-
void send_recv(benchmark::State&, MPI_Comm comm, const Space& space, int nx, int ny, int rx, int ry, int rs,
14-
const View& v) {
15-
KokkosComm::Handle<> h{space, comm};
15+
void send_recv(
16+
benchmark::State&, MPI_Comm comm, const Space& space, int nx, int ny, int rx, int ry, int rs, const View& v
17+
) {
18+
auto h = KokkosComm::Communicator<KokkosComm::MpiSpace, Space>::from_raw(comm, space).value();
1619

1720
// 2D index of nbrs in minus and plus direction (periodic)
1821
const int xm1 = (rx + rs - 1) % rs;
@@ -73,8 +76,9 @@ void benchmark_2dhalo(benchmark::State& state) {
7376
// grid of elements, each with 3 properties, and a radius-1 halo
7477
grid_type grid("", nx + 2, ny + 2, nprops);
7578
while (state.KeepRunning()) {
76-
do_iteration(state, MPI_COMM_WORLD, send_recv<Kokkos::DefaultExecutionSpace, grid_type>, space, nx, ny, rx, ry,
77-
rs, grid);
79+
do_iteration(
80+
state, MPI_COMM_WORLD, send_recv<Kokkos::DefaultExecutionSpace, grid_type>, space, nx, ny, rx, ry, rs, grid
81+
);
7882
}
7983
} else {
8084
while (state.KeepRunning()) {

perf_tests/mpi/test_osu_latency.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
// Copyright (c) 2002-2024 the Network-Based Computing Laboratory
66
// (NBCL), The Ohio State University.
77

8+
#include "KokkosComm/mpi/mpi_space.hpp"
89
#include "test_utils.hpp"
910

1011
#include <KokkosComm/KokkosComm.hpp>
1112

1213
template <typename Space, typename View>
13-
void osu_latency_Kokkos_Comm_sendrecv(benchmark::State &, MPI_Comm, KokkosComm::Handle<> &h, const View &v) {
14+
void osu_latency_Kokkos_Comm_sendrecv(
15+
benchmark::State &, MPI_Comm, KokkosComm::Communicator<KokkosComm::MpiSpace, Space> &h, const View &v
16+
) {
1417
if (h.rank() == 0) {
1518
KokkosComm::wait(KokkosComm::send(h, v, 1));
1619
} else if (h.rank() == 1) {
@@ -19,7 +22,10 @@ void osu_latency_Kokkos_Comm_sendrecv(benchmark::State &, MPI_Comm, KokkosComm::
1922
}
2023

2124
void benchmark_osu_latency_KokkosComm_sendrecv(benchmark::State &state) {
22-
KokkosComm::Handle<> h;
25+
auto h = KokkosComm::Communicator<KokkosComm::MpiSpace, Kokkos::DefaultExecutionSpace>::from_raw(
26+
MPI_COMM_WORLD, Kokkos::DefaultExecutionSpace()
27+
)
28+
.value();
2329
if (h.size() != 2) {
2430
state.SkipWithError("benchmark_osu_latency_KokkosComm needs exactly 2 ranks");
2531
}
@@ -28,14 +34,15 @@ void benchmark_osu_latency_KokkosComm_sendrecv(benchmark::State &state) {
2834
view_type a("A", state.range(0));
2935

3036
while (state.KeepRunning()) {
31-
do_iteration(state, h.mpi_comm(), osu_latency_Kokkos_Comm_sendrecv<Kokkos::DefaultExecutionSpace, view_type>, h, a);
37+
do_iteration(state, h.comm(), osu_latency_Kokkos_Comm_sendrecv<Kokkos::DefaultExecutionSpace, view_type>, h, a);
3238
}
3339
state.counters["bytes"] = a.size() * 2;
3440
}
3541

3642
template <typename Space, typename View>
37-
void osu_latency_Kokkos_Comm_mpi_sendrecv(benchmark::State &, MPI_Comm comm, const Space &space, int rank,
38-
const View &v) {
43+
void osu_latency_Kokkos_Comm_mpi_sendrecv(
44+
benchmark::State &, MPI_Comm comm, const Space &space, int rank, const View &v
45+
) {
3946
if (rank == 0) {
4047
KokkosComm::mpi::send(space, v, 1, 0, comm);
4148
} else if (rank == 1) {
@@ -56,8 +63,10 @@ void benchmark_osu_latency_Kokkos_Comm_mpi_sendrecv(benchmark::State &state) {
5663
view_type a("A", state.range(0));
5764

5865
while (state.KeepRunning()) {
59-
do_iteration(state, MPI_COMM_WORLD, osu_latency_Kokkos_Comm_mpi_sendrecv<Kokkos::DefaultExecutionSpace, view_type>,
60-
space, rank, a);
66+
do_iteration(
67+
state, MPI_COMM_WORLD, osu_latency_Kokkos_Comm_mpi_sendrecv<Kokkos::DefaultExecutionSpace, view_type>, space,
68+
rank, a
69+
);
6170
}
6271
state.counters["bytes"] = a.size() * 2;
6372
}
@@ -66,12 +75,16 @@ template <typename View>
6675
void osu_latency_MPI_isendirecv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
6776
MPI_Request sendreq, recvreq;
6877
if (rank == 0) {
69-
MPI_Irecv(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
70-
&recvreq);
78+
MPI_Irecv(
79+
v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
80+
&recvreq
81+
);
7182
MPI_Wait(&recvreq, MPI_STATUS_IGNORE);
7283
} else if (rank == 1) {
73-
MPI_Isend(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm,
74-
&sendreq);
84+
MPI_Isend(
85+
v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm,
86+
&sendreq
87+
);
7588
MPI_Wait(&sendreq, MPI_STATUS_IGNORE);
7689
}
7790
}
@@ -96,8 +109,10 @@ void benchmark_osu_latency_MPI_isendirecv(benchmark::State &state) {
96109
template <typename View>
97110
void osu_latency_MPI_sendrecv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
98111
if (rank == 0) {
99-
MPI_Recv(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
100-
MPI_STATUS_IGNORE);
112+
MPI_Recv(
113+
v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
114+
MPI_STATUS_IGNORE
115+
);
101116
} else if (rank == 1) {
102117
MPI_Send(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm);
103118
}

unit_tests/mpi/test_barrier.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
namespace {
99

10-
TEST(Barrier, 0) {
11-
KokkosComm::Handle h;
12-
KokkosComm::mpi::barrier(h.mpi_comm());
13-
}
10+
TEST(Barrier, 0) { KokkosComm::mpi::barrier(MPI_COMM_WORLD); }
1411

1512
} // namespace

unit_tests/mpi/test_isendrecv.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
namespace {
1111

12+
using Ex = Kokkos::DefaultExecutionSpace;
13+
using Co = KokkosComm::DefaultCommunicationSpace;
14+
1215
using namespace KokkosComm::mpi;
1316

1417
template <typename T>
@@ -29,22 +32,24 @@ void isend_comm_mode_1d_contig() {
2932

3033
Kokkos::View<Scalar*> a("a", 1000);
3134

32-
KokkosComm::Handle<> h;
35+
auto h = KokkosComm::Communicator<Co, Ex>::from_raw(MPI_COMM_WORLD, Ex()).value();
3336
if (h.size() < 2) {
3437
GTEST_SKIP() << "Requires >= 2 ranks (" << h.size() << " provided)";
3538
}
3639

3740
if (0 == h.rank()) {
3841
int dst = 1;
3942
Kokkos::parallel_for(
40-
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
43+
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }
44+
);
4145
KokkosComm::mpi::isend(h, a, dst, 0, IsendMode{}).wait();
4246
} else if (1 == h.rank()) {
4347
int src = 0;
44-
KokkosComm::mpi::recv(h.space(), a, src, 0, h.mpi_comm());
48+
KokkosComm::mpi::recv(h.exec(), a, src, 0, h.comm());
4549
int errs;
4650
Kokkos::parallel_reduce(
47-
a.extent(0), KOKKOS_LAMBDA(const int& i, int& lsum) { lsum += a(i) != Scalar(i); }, errs);
51+
a.extent(0), KOKKOS_LAMBDA(const int& i, int& lsum) { lsum += a(i) != Scalar(i); }, errs
52+
);
4853
ASSERT_EQ(errs, 0);
4954
}
5055
}
@@ -59,22 +64,24 @@ void isend_comm_mode_1d_noncontig() {
5964
Kokkos::View<Scalar**, Kokkos::LayoutRight> b("a", 10, 10);
6065
auto a = Kokkos::subview(b, Kokkos::ALL, 2); // take column 2 (non-contiguous)
6166

62-
KokkosComm::Handle<> h;
67+
auto h = KokkosComm::Communicator<Co, Ex>::from_raw(MPI_COMM_WORLD, Ex()).value();
6368
if (h.size() < 2) {
6469
GTEST_SKIP() << "Requires >= 2 ranks (" << h.size() << " provided)";
6570
}
6671

6772
if (0 == h.rank()) {
6873
int dst = 1;
6974
Kokkos::parallel_for(
70-
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
75+
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }
76+
);
7177
KokkosComm::mpi::isend(h, a, dst, 0, IsendMode{}).wait();
7278
} else if (1 == h.rank()) {
7379
int src = 0;
74-
KokkosComm::mpi::recv(h.space(), a, src, 0, h.mpi_comm());
80+
KokkosComm::mpi::recv(h.exec(), a, src, 0, h.comm());
7581
int errs;
7682
Kokkos::parallel_reduce(
77-
a.extent(0), KOKKOS_LAMBDA(const int& i, int& lsum) { lsum += a(i) != Scalar(i); }, errs);
83+
a.extent(0), KOKKOS_LAMBDA(const int& i, int& lsum) { lsum += a(i) != Scalar(i); }, errs
84+
);
7885
ASSERT_EQ(errs, 0);
7986
}
8087
}

unit_tests/nccl/test_allgather.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010

1111
namespace {
1212

13-
using ExecSpace = Kokkos::Cuda;
14-
using CommSpace = KokkosComm::Experimental::NcclSpace;
15-
1613
template <typename T>
1714
class AllGather : public testing::Test {
1815
public:
@@ -24,46 +21,51 @@ TYPED_TEST_SUITE(AllGather, ScalarTypes);
2421

2522
template <typename Scalar>
2623
auto allgather_0d() -> void {
27-
auto nccl_ctx = test_utils::nccl::Ctx::init();
28-
ExecSpace space(nccl_ctx.stream());
29-
KokkosComm::Handle<ExecSpace, CommSpace> h(space, nccl_ctx.comm());
30-
int rank = h.rank();
31-
int size = h.size();
24+
auto nccl_ctx = test_utils::nccl::Ctx::init();
25+
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
26+
const auto comm = nccl_ctx.comm();
27+
const int size = nccl_ctx.size();
28+
const int rank = nccl_ctx.rank();
29+
const int root = 0;
3230

3331
Kokkos::View<Scalar> sv("sv");
3432
Kokkos::View<Scalar *> rv("rv", size);
3533

3634
// Prepare send view, 1 element per sender: their rank
3735
Kokkos::parallel_for(
38-
Kokkos::RangePolicy(space, 0, sv.extent(0)), KOKKOS_LAMBDA(const int) { sv() = rank; });
36+
Kokkos::RangePolicy(exec, 0, sv.extent(0)), KOKKOS_LAMBDA(const int) { sv() = rank; }
37+
);
38+
3939
// Using the same execution space for both operations lets us not need an explicit `fence`
40-
auto req = KokkosComm::Experimental::allgather(h, sv, rv);
41-
KokkosComm::wait(req);
40+
KokkosComm::Experimental::nccl::allgather(exec, sv, rv, comm).wait();
4241

4342
int errs;
4443
Kokkos::parallel_reduce(
45-
rv.extent(0), KOKKOS_LAMBDA(const int src, int &lsum) { lsum += rv(src) != src; }, errs);
44+
rv.extent(0), KOKKOS_LAMBDA(const int src, int &lsum) { lsum += rv(src) != src; }, errs
45+
);
4646
EXPECT_EQ(errs, 0);
4747
}
4848

4949
template <typename Scalar>
5050
auto allgather_contig_1d() -> void {
51-
auto nccl_ctx = test_utils::nccl::Ctx::init();
52-
ExecSpace space(nccl_ctx.stream());
53-
KokkosComm::Handle<ExecSpace, CommSpace> h(space, nccl_ctx.comm());
54-
int rank = h.rank();
55-
int size = h.size();
51+
auto nccl_ctx = test_utils::nccl::Ctx::init();
52+
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
53+
const auto comm = nccl_ctx.comm();
54+
const int size = nccl_ctx.size();
55+
const int rank = nccl_ctx.rank();
56+
const int root = 0;
5657

5758
const int n_contrib = 100;
5859
Kokkos::View<Scalar *> sv("sv", n_contrib);
5960
Kokkos::View<Scalar *> rv("rv", size * n_contrib);
6061

6162
// Prepare send view
6263
Kokkos::parallel_for(
63-
Kokkos::RangePolicy(space, 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });
64+
Kokkos::RangePolicy(exec, 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; }
65+
);
66+
6467
// Using the same execution space for both operations lets us not need an explicit `fence`
65-
auto req = KokkosComm::Experimental::allgather(h, sv, rv);
66-
KokkosComm::wait(req);
68+
KokkosComm::Experimental::allgather(exec, sv, rv, comm).wait();
6769

6870
int errs;
6971
Kokkos::parallel_reduce(
@@ -73,7 +75,8 @@ auto allgather_contig_1d() -> void {
7375
const int j = i % n_contrib;
7476
lsum += rv(i) != src + j;
7577
},
76-
errs);
78+
errs
79+
);
7780
EXPECT_EQ(errs, 0);
7881
}
7982

unit_tests/nccl/test_allreduce.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010

1111
namespace {
1212

13-
using ExecSpace = Kokkos::Cuda;
14-
using CommSpace = KokkosComm::Experimental::NcclSpace;
15-
1613
template <typename T>
1714
class AllReduce : public testing::Test {
1815
public:
@@ -24,51 +21,56 @@ TYPED_TEST_SUITE(AllReduce, ScalarTypes);
2421

2522
template <typename Scalar>
2623
auto allreduce_0d() -> void {
27-
auto nccl_ctx = test_utils::nccl::Ctx::init();
28-
ExecSpace space(nccl_ctx.stream());
29-
KokkosComm::Handle<ExecSpace, CommSpace> h(space, nccl_ctx.comm());
30-
int rank = h.rank();
31-
int size = h.size();
24+
auto nccl_ctx = test_utils::nccl::Ctx::init();
25+
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
26+
const auto comm = nccl_ctx.comm();
27+
const int size = nccl_ctx.size();
28+
const int rank = nccl_ctx.rank();
29+
const int root = 0;
3230

3331
Kokkos::View<Scalar> sv("sv");
34-
Kokkos::View<Scalar> rv("rv", size);
32+
Kokkos::View<Scalar> rv("rv");
3533

3634
// Prepare send buffer
3735
Kokkos::parallel_for(
38-
Kokkos::RangePolicy(space, 0, sv.extent(0)), KOKKOS_LAMBDA(const int) { sv() = rank; });
36+
Kokkos::RangePolicy(exec, 0, sv.extent(0)), KOKKOS_LAMBDA(const int) { sv() = rank; }
37+
);
38+
3939
// Using the same execution space for both operations lets us not need an explicit `fence`
40-
auto req = KokkosComm::Experimental::allreduce(h, sv, rv, KokkosComm::Sum{});
41-
KokkosComm::wait(req);
40+
KokkosComm::Experimental::nccl::allreduce(exec, sv, rv, ncclSum, comm).wait();
4241

4342
int errs;
4443
Kokkos::parallel_reduce(
45-
rv.extent(0), KOKKOS_LAMBDA(const int, int &lsum) { lsum += (rv() != size * (size - 1) / 2); }, errs);
44+
rv.extent(0), KOKKOS_LAMBDA(const int, int &lsum) { lsum += (rv() != size * (size - 1) / 2); }, errs
45+
);
4646
EXPECT_EQ(errs, 0);
4747
}
4848

4949
template <typename Scalar>
5050
auto allreduce_contig_1d() -> void {
51-
auto nccl_ctx = test_utils::nccl::Ctx::init();
52-
ExecSpace space(nccl_ctx.stream());
53-
KokkosComm::Handle<ExecSpace, CommSpace> h(space, nccl_ctx.comm());
54-
int rank = h.rank();
55-
int size = h.size();
56-
57-
int n_contrib = 10;
51+
auto nccl_ctx = test_utils::nccl::Ctx::init();
52+
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
53+
const auto comm = nccl_ctx.comm();
54+
const int size = nccl_ctx.size();
55+
const int rank = nccl_ctx.rank();
56+
const int root = 0;
57+
58+
const int n_contrib = 10;
5859
Kokkos::View<Scalar *> sv("sv", n_contrib);
59-
Kokkos::View<Scalar *> rv("rv", size);
60+
Kokkos::View<Scalar *> rv("rv", n_contrib);
6061

6162
// Prepare send buffer
6263
Kokkos::parallel_for(
63-
Kokkos::RangePolicy(space, 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });
64+
Kokkos::RangePolicy(exec, 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; }
65+
);
66+
6467
// Using the same execution space for both operations lets us not need an explicit `fence`
65-
auto req = KokkosComm::Experimental::allreduce(h, sv, rv, KokkosComm::Sum{});
66-
KokkosComm::wait(req);
68+
KokkosComm::Experimental::nccl::allreduce(exec, sv, rv, ncclSum, comm).wait();
6769

6870
int errs;
6971
Kokkos::parallel_reduce(
70-
rv.extent(0), KOKKOS_LAMBDA(const int i, int &lsum) { lsum += (rv(i) != size * (size - 1) / 2 + size * i); },
71-
errs);
72+
rv.extent(0), KOKKOS_LAMBDA(const int i, int &lsum) { lsum += (rv(i) != size * (size - 1) / 2 + size * i); }, errs
73+
);
7274
EXPECT_EQ(errs, 0);
7375
}
7476

0 commit comments

Comments (0)