Skip to content

Commit 46a5269

Browse files
committed
test: refactor + add new MD view helpers
Refactored implementation of view builder helpers and added new ones, all now MD-capable. - Refactored `ViewBuilder` struct into a free function that allows building views of arbitrary type and dimension, contiguous or not. - Added view initialization and error counting helpers. Basically, these were the problematic pieces of code that forced us to have dedicated test functions for each view dimension we wanted to test. These now abstract this away and are MD-capable. Theses changes make it possible to write only the high-level test function, and pass different kind of views to it (which better mirrors user code that Kokkos Comm is designed to handle correctly). It's now trivial to test views of higher dimensions (>2), contiguous or not, provided that we support them (which is not the case when they aren't contiguous, due to limitations caused by a similar implementation strategy as what was refactored here).
1 parent e594754 commit 46a5269

File tree

5 files changed

+200
-155
lines changed

5 files changed

+200
-155
lines changed

unit_tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../cmake")
4646
include(kc-test)
4747

4848
# --- Core API unit tests --- #
49-
kc_add_unit_test(test.core.p2p CORE NUM_PES 2 FILES test_main.cpp test_sendrecv.cpp)
49+
kc_add_unit_test(test.core.send-recv CORE NUM_PES 2 FILES test_main.cpp test_send_recv.cpp)
5050
kc_add_unit_test(test.core.broadcast CORE NUM_PES 2 FILES test_main.cpp test_broadcast.cpp)
5151
kc_add_unit_test(test.core.all-gather CORE NUM_PES 2 FILES test_main.cpp test_allgather.cpp)
5252
kc_add_unit_test(test.core.reduce CORE NUM_PES 2 FILES test_main.cpp test_reduce.cpp)

unit_tests/test_send_recv.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#include <gtest/gtest.h>
5+
#include <KokkosComm/KokkosComm.hpp>
6+
7+
#include "view_utils.hpp"
8+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
9+
#include "nccl/utils.hpp"
10+
#endif
11+
12+
namespace {
13+
14+
template <typename T>
15+
class SendRecv : public testing::Test {
16+
public:
17+
using Scalar = T;
18+
};
19+
20+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
21+
using ScalarTypes = ::testing::Types<float, double, int, int64_t>;
22+
#else
23+
using ScalarTypes =
24+
::testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
25+
#endif
26+
TYPED_TEST_SUITE(SendRecv, ScalarTypes);
27+
28+
template <KokkosComm::KokkosView View>
29+
void test_core_send_recv(const View& v) {
30+
using Exec = Kokkos::DefaultExecutionSpace;
31+
using Comm = KokkosComm::DefaultCommunicationSpace;
32+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
33+
auto nccl_ctx = test_utils::nccl::Ctx::init();
34+
auto raw_comm = nccl_ctx.comm();
35+
#else
36+
auto raw_comm = MPI_COMM_WORLD;
37+
#endif
38+
auto exec = Exec();
39+
auto comm = KokkosComm::Communicator<Comm, Exec>::duplicate(raw_comm, exec).value();
40+
const int size = comm.size();
41+
const int rank = comm.rank();
42+
if (size < 2) {
43+
GTEST_SKIP() << "Requires >= 2 ranks (" << size << " provided)";
44+
}
45+
const int src = 0;
46+
const int dst = 1;
47+
48+
if (rank == src) {
49+
test_utils::init_view(exec, v);
50+
exec.fence();
51+
KokkosComm::send(comm, v, dst).wait();
52+
} else if (rank == dst) {
53+
KokkosComm::recv(comm, v, src).wait();
54+
int errs = test_utils::count_errors(v);
55+
ASSERT_EQ(errs, 0);
56+
}
57+
}
58+
59+
TYPED_TEST(SendRecv, Contig1D) {
60+
auto v = test_utils::build_view<typename TestFixture::Scalar, 1>(test_utils::Contig{}, "v", 1013);
61+
test_core_send_recv(v);
62+
}
63+
TYPED_TEST(SendRecv, NonContig1D) {
64+
auto v = test_utils::build_view<typename TestFixture::Scalar, 1>(test_utils::NonContig{}, "v", 1013);
65+
test_core_send_recv(v);
66+
}
67+
TYPED_TEST(SendRecv, Contig2D) {
68+
auto v = test_utils::build_view<typename TestFixture::Scalar, 2>(test_utils::Contig{}, "v", 137, 17);
69+
test_core_send_recv(v);
70+
}
71+
TYPED_TEST(SendRecv, NonContig2D) {
72+
auto v = test_utils::build_view<typename TestFixture::Scalar, 2>(test_utils::NonContig{}, "v", 137, 17);
73+
test_core_send_recv(v);
74+
}
75+
76+
} // namespace

unit_tests/test_sendrecv.cpp

Lines changed: 0 additions & 121 deletions
This file was deleted.

unit_tests/view_builder.hpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

unit_tests/view_utils.hpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#pragma once
5+
6+
#include <cstddef>
7+
#include <string>
8+
#include <utility>
9+
10+
#include <Kokkos_Core.hpp>
11+
12+
namespace test_utils {
13+
14+
struct Contig {};
15+
struct NonContig {};
16+
17+
namespace Impl {
18+
19+
template <typename T, size_t N>
20+
struct AddPtrs {
21+
using type = typename AddPtrs<T*, N - 1>::type;
22+
};
23+
template <typename T>
24+
struct AddPtrs<T, 0> {
25+
using type = T;
26+
};
27+
template <typename T, size_t N>
28+
using Stars = typename AddPtrs<T, N>::type;
29+
30+
template <size_t Rank>
31+
auto all_then_one() {
32+
return [&]<size_t... Is>(std::index_sequence<Is...>) {
33+
return std::make_tuple((void(Is), Kokkos::ALL)..., 1);
34+
}(std::make_index_sequence<Rank>{});
35+
}
36+
template <typename ViewType, typename Tuple, size_t... Is>
37+
auto apply_subview(ViewType& v, Tuple&& t, std::index_sequence<Is...>) {
38+
return Kokkos::subview(v, std::get<Is>(std::forward<Tuple>(t))...);
39+
}
40+
41+
template <typename View, size_t... Is>
42+
struct InitFunctor {
43+
View v;
44+
std::array<int, sizeof...(Is)> exts;
45+
46+
KOKKOS_FUNCTION void operator()(decltype(static_cast<int>(std::declval<View>().extent(Is)))... idxs) const {
47+
int val = 0;
48+
int ids[] = {int(idxs)...};
49+
((val = ids[Is] + exts[Is] * val), ...);
50+
[&]<size_t... Js>(std::index_sequence<Js...>) { v(ids[Js]...) = val; }(std::index_sequence<Is...>{});
51+
}
52+
};
53+
54+
template <typename View, size_t... Is>
55+
struct CountFunctor {
56+
using Scalar = typename View::value_type;
57+
View v;
58+
std::array<int, sizeof...(Is)> exts;
59+
60+
KOKKOS_FUNCTION void operator()(
61+
decltype(static_cast<int>(std::declval<View>().extent(Is)))... idxs, int& lsum
62+
) const {
63+
int val = 0;
64+
int ids[] = {int(idxs)...};
65+
((val = ids[Is] + exts[Is] * val), ...);
66+
[&]<size_t... Js>(std::index_sequence<Js...>) {
67+
lsum += v(ids[Js]...) != Scalar(val);
68+
}(std::index_sequence<Is...>{});
69+
}
70+
};
71+
72+
} // namespace Impl
73+
74+
template <typename T, size_t R, typename... Extents>
75+
auto build_view(Contig, const std::string& name, Extents... exts) {
76+
static_assert(sizeof...(exts) == R, "Number of extents must match Rank");
77+
return Kokkos::View<Impl::Stars<T, R>>(name, exts...);
78+
}
79+
template <typename T, size_t R, typename... Extents>
80+
auto build_view(NonContig, const std::string& name, Extents... exts) {
81+
static_assert(sizeof...(exts) == R, "Number of extents must match Rank");
82+
Kokkos::View<Impl::Stars<T, R + 1>, Kokkos::LayoutRight> v(name, exts..., 2);
83+
auto args = Impl::all_then_one<R>();
84+
return Impl::apply_subview(v, args, std::make_index_sequence<R + 1>{});
85+
}
86+
87+
template <typename Exec, typename View>
88+
auto init_view(const Exec& exec, const View& v) -> void {
89+
constexpr size_t R = v.rank();
90+
if constexpr (R == 1) {
91+
Kokkos::parallel_for(Kokkos::RangePolicy<Exec>(exec, 0, v.extent(0)), KOKKOS_LAMBDA(const int i) { v(i) = i; });
92+
} else {
93+
[&]<size_t... Is>(std::index_sequence<Is...>) {
94+
std::array<int, R> exts = {static_cast<int>(v.extent(Is))...};
95+
Kokkos::parallel_for(
96+
Kokkos::MDRangePolicy<Exec, Kokkos::Rank<R>>(exec, {(void(Is), 0)...}, {static_cast<int>(v.extent(Is))...}),
97+
Impl::InitFunctor<View, Is...>{v, exts}
98+
);
99+
}(std::make_index_sequence<R>{});
100+
}
101+
exec.fence();
102+
}
103+
104+
template <typename View>
105+
auto count_errors(const View& v) -> int {
106+
constexpr size_t R = v.rank();
107+
int errs = 0;
108+
using Scalar = typename View::value_type;
109+
if constexpr (R == 1) {
110+
Kokkos::parallel_reduce(v.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += v(i) != Scalar(i); }, errs);
111+
} else {
112+
[&]<size_t... Is>(std::index_sequence<Is...>) {
113+
std::array<int, R> exts = {static_cast<int>(v.extent(Is))...};
114+
Kokkos::parallel_reduce(
115+
Kokkos::MDRangePolicy<Kokkos::Rank<R>>({(void(Is), 0)...}, {static_cast<int>(v.extent(Is))...}),
116+
Impl::CountFunctor<View, Is...>{v, exts}, errs
117+
);
118+
}(std::make_index_sequence<R>{});
119+
}
120+
return errs;
121+
}
122+
123+
} // namespace test_utils

0 commit comments

Comments
 (0)