Skip to content

Commit da51abe

Browse files
committed
test: refactor + add new MD view helpers
Refactored implementation of view builder helpers and added new ones, all now MD-capable. - Refactored `ViewBuilder` struct into a free function that allows building views of arbitrary type and dimension, contiguous or not. - Added view initialization and error counting helpers. Basically, these were the problematic pieces of code that forced us to have dedicated test functions for each view dimension we wanted to test. These now abstract this away and are MD-capable. Theses changes make it possible to write only the high-level test function, and pass different kind of views to it (which better mirrors user code that Kokkos Comm is designed to handle correctly). It's now trivial to test views of higher dimensions (>2), contiguous or not, provided that we support them (which is not the case when they aren't contiguous, due to limitations caused by a similar implementation strategy as what was refactored here).
1 parent e594754 commit da51abe

File tree

3 files changed

+146
-101
lines changed

3 files changed

+146
-101
lines changed

unit_tests/test_sendrecv.cpp

Lines changed: 23 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
#include <gtest/gtest.h>
55
#include <KokkosComm/KokkosComm.hpp>
66

7-
#include "view_builder.hpp"
7+
#include "view_utils.hpp"
88
#if defined(KOKKOSCOMM_ENABLE_NCCL)
99
#include "nccl/utils.hpp"
1010
#endif
1111

1212
namespace {
1313

14-
using Ex = Kokkos::DefaultExecutionSpace;
15-
using Co = KokkosComm::DefaultCommunicationSpace;
16-
1714
template <typename T>
1815
class SendRecv : public testing::Test {
1916
public:
@@ -28,56 +25,18 @@ using ScalarTypes =
2825
#endif
2926
TYPED_TEST_SUITE(SendRecv, ScalarTypes);
3027

31-
template <KokkosComm::KokkosView View1D>
32-
void test_1d(const View1D &v) {
33-
static_assert(View1D::rank == 1, "");
34-
using Scalar = typename View1D::non_const_value_type;
35-
36-
#if defined(KOKKOSCOMM_ENABLE_NCCL)
37-
auto nccl_ctx = test_utils::nccl::Ctx::init();
38-
auto raw_comm = nccl_ctx.comm();
39-
#else
40-
auto raw_comm = MPI_COMM_WORLD;
41-
#endif
42-
auto exec = Ex();
43-
auto comm = KokkosComm::Communicator<Co, Ex>::duplicate(raw_comm, exec).value();
44-
const int size = comm.size();
45-
const int rank = comm.rank();
46-
if (size < 2) {
47-
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
48-
}
49-
const int src = 0;
50-
const int dst = 1;
51-
52-
if (rank == src) {
53-
Kokkos::parallel_for(
54-
v.extent(0), KOKKOS_LAMBDA(const int i) { v(i) = i; }
55-
);
56-
KokkosComm::send(comm, v, dst).wait();
57-
} else if (rank == dst) {
58-
KokkosComm::recv(comm, v, src).wait();
59-
60-
int errs;
61-
Kokkos::parallel_reduce(
62-
v.extent(0), KOKKOS_LAMBDA(const int i, int &lsum) { lsum += v(i) != Scalar(i); }, errs
63-
);
64-
ASSERT_EQ(errs, 0);
65-
}
66-
}
67-
68-
template <KokkosComm::KokkosView View2D>
69-
void test_2d(const View2D &v) {
70-
static_assert(View2D::rank == 2, "");
71-
using Scalar = typename View2D::non_const_value_type;
72-
28+
template <KokkosComm::KokkosView View>
29+
void test_core_send_recv(const View& v) {
30+
using Exec = Kokkos::DefaultExecutionSpace;
31+
using Comm = KokkosComm::DefaultCommunicationSpace;
7332
#if defined(KOKKOSCOMM_ENABLE_NCCL)
7433
auto nccl_ctx = test_utils::nccl::Ctx::init();
7534
auto raw_comm = nccl_ctx.comm();
7635
#else
7736
auto raw_comm = MPI_COMM_WORLD;
7837
#endif
79-
auto exec = Ex();
80-
auto comm = KokkosComm::Communicator<Co, Ex>::duplicate(raw_comm, exec).value();
38+
auto exec = Exec();
39+
auto comm = KokkosComm::Communicator<Comm, Exec>::duplicate(raw_comm, exec).value();
8140
const int size = comm.size();
8241
const int rank = comm.rank();
8342
if (size < 2) {
@@ -86,36 +45,32 @@ void test_2d(const View2D &v) {
8645
const int src = 0;
8746
const int dst = 1;
8847

89-
using Policy = Kokkos::MDRangePolicy<Kokkos::Rank<2>>;
90-
Policy policy(exec, {0, 0}, {v.extent(0), v.extent(1)});
91-
9248
if (rank == src) {
93-
Kokkos::parallel_for(
94-
policy, KOKKOS_LAMBDA(const int i, const int j) { v(i, j) = i * v.extent(0) + j; }
95-
);
49+
test_utils::init_view(exec, v);
9650
exec.fence();
97-
9851
KokkosComm::send(comm, v, dst).wait();
9952
} else if (rank == dst) {
10053
KokkosComm::recv(comm, v, src).wait();
101-
102-
int errs;
103-
Kokkos::parallel_reduce(
104-
policy, KOKKOS_LAMBDA(const int i, const int j, int &lsum) { lsum += v(i, j) != Scalar(i * v.extent(0) + j); },
105-
errs
106-
);
107-
exec.fence();
54+
int errs = test_utils::count_errors(v);
10855
ASSERT_EQ(errs, 0);
10956
}
11057
}
11158

112-
TYPED_TEST(SendRecv, 1D_contig) {
113-
auto v = ViewBuilder<typename TestFixture::Scalar, 1>::view(contig{}, "v", 1013);
114-
test_1d(v);
59+
TYPED_TEST(SendRecv, Contig1D) {
60+
auto v = test_utils::build_view<typename TestFixture::Scalar, 1>(test_utils::Contig{}, "v", 1013);
61+
test_core_send_recv(v);
62+
}
63+
TYPED_TEST(SendRecv, NonContig1D) {
64+
auto v = test_utils::build_view<typename TestFixture::Scalar, 1>(test_utils::NonContig{}, "v", 1013);
65+
test_core_send_recv(v);
66+
}
67+
TYPED_TEST(SendRecv, Contig2D) {
68+
auto v = test_utils::build_view<typename TestFixture::Scalar, 2>(test_utils::Contig{}, "v", 137, 17);
69+
test_core_send_recv(v);
11570
}
116-
TYPED_TEST(SendRecv, 2D_contig) {
117-
auto v = ViewBuilder<typename TestFixture::Scalar, 2>::view(contig{}, "v", 137, 17);
118-
test_2d(v);
71+
TYPED_TEST(SendRecv, NonContig2D) {
72+
auto v = test_utils::build_view<typename TestFixture::Scalar, 2>(test_utils::NonContig{}, "v", 137, 17);
73+
test_core_send_recv(v);
11974
}
12075

12176
} // namespace

unit_tests/view_builder.hpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

unit_tests/view_utils.hpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#pragma once
5+
6+
#include <cstddef>
7+
#include <string>
8+
#include <utility>
9+
10+
#include <Kokkos_Core.hpp>
11+
12+
namespace test_utils {
13+
14+
struct Contig {};
15+
struct NonContig {};
16+
17+
namespace Impl {
18+
19+
template <typename T, size_t N>
20+
struct AddPtrs {
21+
using type = typename AddPtrs<T*, N - 1>::type;
22+
};
23+
template <typename T>
24+
struct AddPtrs<T, 0> {
25+
using type = T;
26+
};
27+
template <typename T, size_t N>
28+
using Stars = typename AddPtrs<T, N>::type;
29+
30+
template <size_t Rank>
31+
auto all_then_one() {
32+
return [&]<size_t... Is>(std::index_sequence<Is...>) {
33+
return std::make_tuple((void(Is), Kokkos::ALL)..., 1);
34+
}(std::make_index_sequence<Rank>{});
35+
}
36+
template <typename ViewType, typename Tuple, size_t... Is>
37+
auto apply_subview(ViewType& v, Tuple&& t, std::index_sequence<Is...>) {
38+
return Kokkos::subview(v, std::get<Is>(std::forward<Tuple>(t))...);
39+
}
40+
41+
template <typename View, size_t... Is>
42+
struct InitFunctor {
43+
View v;
44+
std::array<int, sizeof...(Is)> exts;
45+
46+
KOKKOS_FUNCTION void operator()(decltype(static_cast<int>(std::declval<View>().extent(Is)))... idxs) const {
47+
int val = 0;
48+
int ids[] = {int(idxs)...};
49+
((val = ids[Is] + exts[Is] * val), ...);
50+
[&]<size_t... Js>(std::index_sequence<Js...>) { v(ids[Js]...) = val; }(std::index_sequence<Is...>{});
51+
}
52+
};
53+
54+
template <typename View, size_t... Is>
55+
struct CountFunctor {
56+
using Scalar = typename View::value_type;
57+
View v;
58+
std::array<int, sizeof...(Is)> exts;
59+
60+
KOKKOS_FUNCTION void operator()(
61+
decltype(static_cast<int>(std::declval<View>().extent(Is)))... idxs, int& lsum
62+
) const {
63+
int val = 0;
64+
int ids[] = {int(idxs)...};
65+
((val = ids[Is] + exts[Is] * val), ...);
66+
[&]<size_t... Js>(std::index_sequence<Js...>) {
67+
lsum += v(ids[Js]...) != Scalar(val);
68+
}(std::index_sequence<Is...>{});
69+
}
70+
};
71+
72+
} // namespace Impl
73+
74+
template <typename T, size_t R, typename... Extents>
75+
auto build_view(Contig, const std::string& name, Extents... exts) {
76+
static_assert(sizeof...(exts) == R, "Number of extents must match Rank");
77+
return Kokkos::View<Impl::Stars<T, R>>(name, exts...);
78+
}
79+
template <typename T, size_t R, typename... Extents>
80+
auto build_view(NonContig, const std::string& name, Extents... exts) {
81+
static_assert(sizeof...(exts) == R, "Number of extents must match Rank");
82+
Kokkos::View<Impl::Stars<T, R + 1>, Kokkos::LayoutRight> v(name, exts..., 2);
83+
auto args = Impl::all_then_one<R>();
84+
return Impl::apply_subview(v, args, std::make_index_sequence<R + 1>{});
85+
}
86+
87+
template <typename Exec, typename View>
88+
auto init_view(const Exec& exec, const View& v) -> void {
89+
constexpr size_t R = v.rank();
90+
if constexpr (R == 1) {
91+
Kokkos::parallel_for(Kokkos::RangePolicy<Exec>(exec, 0, v.extent(0)), KOKKOS_LAMBDA(const int i) { v(i) = i; });
92+
} else {
93+
[&]<size_t... Is>(std::index_sequence<Is...>) {
94+
std::array<int, R> exts = {static_cast<int>(v.extent(Is))...};
95+
Kokkos::parallel_for(
96+
Kokkos::MDRangePolicy<Exec, Kokkos::Rank<R>>(exec, {(void(Is), 0)...}, {static_cast<int>(v.extent(Is))...}),
97+
Impl::InitFunctor<View, Is...>{v, exts}
98+
);
99+
}(std::make_index_sequence<R>{});
100+
}
101+
exec.fence();
102+
}
103+
104+
template <typename View>
105+
auto count_errors(const View& v) -> int {
106+
constexpr size_t R = v.rank();
107+
int errs = 0;
108+
using Scalar = typename View::value_type;
109+
if constexpr (R == 1) {
110+
Kokkos::parallel_reduce(v.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += v(i) != Scalar(i); }, errs);
111+
} else {
112+
[&]<size_t... Is>(std::index_sequence<Is...>) {
113+
std::array<int, R> exts = {static_cast<int>(v.extent(Is))...};
114+
Kokkos::parallel_reduce(
115+
Kokkos::MDRangePolicy<Kokkos::Rank<R>>({(void(Is), 0)...}, {static_cast<int>(v.extent(Is))...}),
116+
Impl::CountFunctor<View, Is...>{v, exts}, errs
117+
);
118+
}(std::make_index_sequence<R>{});
119+
}
120+
return errs;
121+
}
122+
123+
} // namespace test_utils

0 commit comments

Comments
 (0)