diff --git a/CMakeLists.txt b/CMakeLists.txt index c044bb4e..bae7280f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,8 +6,7 @@ cmake_minimum_required(VERSION 3.23) project( KokkosComm - LANGUAGES - CXX + LANGUAGES CXX VERSION 0.2.0 DESCRIPTION "Experimental MPI interfaces (and more!) for the Kokkos C++ Performance Portability Programming ecosystem" HOMEPAGE_URL "https://kokkos.org/kokkos-comm/" @@ -21,12 +20,14 @@ option(KokkosComm_ENABLE_TESTS "Build KokkosComm tests" OFF) option(KokkosComm_ENABLE_MPI "Build KokkosComm with MPI transport" ON) option(KokkosComm_ENABLE_NCCL "Build KokkosComm with NCCL transport" OFF) option(KokkosComm_ABORT_ON_ERROR "Runtime error checks trigger a global abort" OFF) +option(KokkosComm_ENABLE_GPU_AWARE_MPI "Provided MPI is GPU-aware" OFF) # Resolve options set(KOKKOSCOMM_ENABLE_PERFTESTS ${KokkosComm_ENABLE_PERFTESTS} CACHE BOOL "" FORCE) set(KOKKOSCOMM_ENABLE_TESTS ${KokkosComm_ENABLE_TESTS} CACHE BOOL "" FORCE) set(KOKKOSCOMM_ENABLE_MPI ${KokkosComm_ENABLE_MPI} CACHE BOOL "" FORCE) set(KOKKOSCOMM_ENABLE_NCCL ${KokkosComm_ENABLE_NCCL} CACHE BOOL "" FORCE) +set(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI ${KokkosComm_ENABLE_GPU_AWARE_MPI} CACHE BOOL "" FORCE) find_package(Kokkos REQUIRED) if(KOKKOSCOMM_ENABLE_MPI) diff --git a/src/KokkosComm/CMakeLists.txt b/src/KokkosComm/CMakeLists.txt index c872251a..7b6b4df9 100644 --- a/src/KokkosComm/CMakeLists.txt +++ b/src/KokkosComm/CMakeLists.txt @@ -6,17 +6,17 @@ target_sources( KokkosComm INTERFACE FILE_SET kokkoscomm_public_headers - TYPE HEADERS - BASE_DIRS ${PROJECT_SOURCE_DIR}/src - FILES - KokkosComm.hpp - collective.hpp - concepts.hpp - fwd.hpp - point_to_point.hpp - traits.hpp - datatype.hpp - reduction_op.hpp + TYPE HEADERS + BASE_DIRS ${PROJECT_SOURCE_DIR}/src + FILES + KokkosComm.hpp + collective.hpp + concepts.hpp + fwd.hpp + point_to_point.hpp + traits.hpp + datatype.hpp + reduction_op.hpp ) # Implementation detail headers @@ -24,9 +24,9 @@ target_sources( KokkosComm INTERFACE 
FILE_SET kokkoscomm_impl_headers - TYPE HEADERS - BASE_DIRS ${PROJECT_SOURCE_DIR}/src - FILES impl/contiguous.hpp + TYPE HEADERS + BASE_DIRS ${PROJECT_SOURCE_DIR}/src + FILES impl/contiguous.hpp impl/host_staging.hpp ) # Configuration header @@ -34,9 +34,9 @@ target_sources( KokkosComm INTERFACE FILE_SET kokkoscomm_config_headers - TYPE HEADERS - BASE_DIRS ${CMAKE_BINARY_DIR}/src - FILES ${PROJECT_BINARY_DIR}/src/KokkosComm/config.hpp + TYPE HEADERS + BASE_DIRS ${CMAKE_BINARY_DIR}/src + FILES ${PROJECT_BINARY_DIR}/src/KokkosComm/config.hpp ) if(KOKKOSCOMM_ENABLE_MPI) @@ -45,30 +45,30 @@ if(KOKKOSCOMM_ENABLE_MPI) KokkosComm INTERFACE FILE_SET kokkoscomm_mpi_headers - TYPE HEADERS - BASE_DIRS ${PROJECT_SOURCE_DIR}/src - FILES - # Structures - mpi/mpi_space.hpp - mpi/comm_mode.hpp - mpi/handle.hpp - mpi/req.hpp - mpi/comm_mode.hpp - mpi/channel.hpp - # P2P - mpi/irecv.hpp - mpi/isend.hpp - mpi/recv.hpp - mpi/send.hpp - # Collectives - mpi/broadcast.hpp - mpi/allgather.hpp - mpi/allreduce.hpp - mpi/alltoall.hpp - mpi/reduce.hpp - mpi/scan.hpp - # Other/utilities - mpi/barrier.hpp + TYPE HEADERS + BASE_DIRS ${PROJECT_SOURCE_DIR}/src + FILES + # Structures + mpi/mpi_space.hpp + mpi/comm_mode.hpp + mpi/handle.hpp + mpi/req.hpp + mpi/comm_mode.hpp + mpi/channel.hpp + # P2P + mpi/irecv.hpp + mpi/isend.hpp + mpi/recv.hpp + mpi/send.hpp + # Collectives + mpi/broadcast.hpp + mpi/allgather.hpp + mpi/allreduce.hpp + mpi/alltoall.hpp + mpi/reduce.hpp + mpi/scan.hpp + # Other/utilities + mpi/barrier.hpp ) # Implementation detail MPI headers @@ -76,9 +76,9 @@ if(KOKKOSCOMM_ENABLE_MPI) KokkosComm INTERFACE FILE_SET kokkoscomm_mpi_impl_headers - TYPE HEADERS - BASE_DIRS ${PROJECT_SOURCE_DIR}/src - FILES mpi/impl/pack_traits.hpp mpi/impl/packer.hpp mpi/impl/tags.hpp mpi/impl/error_handling.hpp + TYPE HEADERS + BASE_DIRS ${PROJECT_SOURCE_DIR}/src + FILES mpi/impl/pack_traits.hpp mpi/impl/packer.hpp mpi/impl/tags.hpp mpi/impl/error_handling.hpp ) endif() @@ -88,22 +88,22 @@ 
if(KOKKOSCOMM_ENABLE_NCCL) KokkosComm INTERFACE FILE_SET kokkoscomm_nccl_headers - TYPE HEADERS - BASE_DIRS ${PROJECT_SOURCE_DIR}/src - FILES - # Structures - nccl/nccl_space.hpp - nccl/handle.hpp - nccl/req.hpp - # P2P - nccl/send.hpp - nccl/recv.hpp - # Collectives - nccl/broadcast.hpp - nccl/alltoall.hpp - nccl/allgather.hpp - nccl/allreduce.hpp - nccl/reduce.hpp + TYPE HEADERS + BASE_DIRS ${PROJECT_SOURCE_DIR}/src + FILES + # Structures + nccl/nccl_space.hpp + nccl/handle.hpp + nccl/req.hpp + # P2P + nccl/send.hpp + nccl/recv.hpp + # Collectives + nccl/broadcast.hpp + nccl/alltoall.hpp + nccl/allgather.hpp + nccl/allreduce.hpp + nccl/reduce.hpp ) # Implementation detail NCCL headers @@ -111,9 +111,9 @@ if(KOKKOSCOMM_ENABLE_NCCL) KokkosComm INTERFACE FILE_SET kokkoscomm_nccl_impl_headers - TYPE HEADERS - BASE_DIRS ${PROJECT_SOURCE_DIR}/src - FILES nccl/impl/pack_traits.hpp nccl/impl/packer.hpp nccl/impl/types.hpp + TYPE HEADERS + BASE_DIRS ${PROJECT_SOURCE_DIR}/src + FILES nccl/impl/pack_traits.hpp nccl/impl/packer.hpp nccl/impl/types.hpp ) endif() @@ -127,10 +127,7 @@ include(CheckCXXCompilerFlag) macro(kokkoscomm_check_and_add_compile_options) set(target ${ARGV0}) set(flag ${ARGV1}) - check_cxx_compiler_flag( - ${flag} - HAS_${flag} - ) + check_cxx_compiler_flag(${flag} HAS_${flag}) if(HAS_${flag}) target_compile_options(${target} INTERFACE ${flag}) endif() @@ -140,12 +137,7 @@ endmacro() add_library(KokkosCommFlags INTERFACE) add_library(KokkosComm::KokkosCommFlags ALIAS KokkosCommFlags) target_compile_features(KokkosCommFlags INTERFACE cxx_std_20) -set_target_properties( - KokkosCommFlags - PROPERTIES - CXX_EXTENSIONS - OFF -) +set_target_properties(KokkosCommFlags PROPERTIES CXX_EXTENSIONS OFF) kokkoscomm_check_and_add_compile_options(KokkosCommFlags -Wall) kokkoscomm_check_and_add_compile_options(KokkosCommFlags -Wextra) @@ -158,43 +150,24 @@ kokkoscomm_check_and_add_compile_options(KokkosCommFlags -Wmissing-include-dirs) 
kokkoscomm_check_and_add_compile_options(KokkosCommFlags -Wno-gnu-zero-variadic-macro-arguments) # Linking -target_link_libraries( - KokkosComm - INTERFACE - KokkosComm::KokkosCommFlags - Kokkos::kokkos -) +target_link_libraries(KokkosComm INTERFACE KokkosComm::KokkosCommFlags Kokkos::kokkos) if(KOKKOSCOMM_ENABLE_MPI) target_link_libraries(KokkosComm INTERFACE MPI::MPI_CXX) endif() if(KOKKOSCOMM_ENABLE_NCCL) target_link_libraries(KokkosComm INTERFACE NCCL::NCCL) endif() -target_link_libraries( - KokkosComm - INTERFACE - KokkosComm::KokkosCommFlags - Kokkos::kokkos -) +target_link_libraries(KokkosComm INTERFACE KokkosComm::KokkosCommFlags Kokkos::kokkos) # Install library install( - TARGETS - KokkosComm - KokkosCommFlags + TARGETS KokkosComm KokkosCommFlags EXPORT KokkosCommTargets - FILE_SET - kokkoscomm_public_headers - FILE_SET - kokkoscomm_impl_headers - FILE_SET - kokkoscomm_mpi_headers - FILE_SET - kokkoscomm_mpi_impl_headers - FILE_SET - kokkoscomm_nccl_headers - FILE_SET - kokkoscomm_nccl_impl_headers - FILE_SET - kokkoscomm_config_headers + FILE_SET kokkoscomm_public_headers + FILE_SET kokkoscomm_impl_headers + FILE_SET kokkoscomm_mpi_headers + FILE_SET kokkoscomm_mpi_impl_headers + FILE_SET kokkoscomm_nccl_headers + FILE_SET kokkoscomm_nccl_impl_headers + FILE_SET kokkoscomm_config_headers ) diff --git a/src/KokkosComm/impl/host_staging.hpp b/src/KokkosComm/impl/host_staging.hpp new file mode 100644 index 00000000..70e9a680 --- /dev/null +++ b/src/KokkosComm/impl/host_staging.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project + +#pragma once + +#include + +#include + +namespace KokkosComm::Impl { + +template +inline constexpr bool needs_staging_v = + !Kokkos::SpaceAccessibility::accessible; + +/// Stage view on the host for non-GPU-aware communications. +/// No-op if `view` is device-accessible from the host. 
+template +auto stage_for(const V& view) { + return Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, view); +} + +/// Copy back to device (e.g. for receive operations). +/// No-op if `dst` is device-accessible from the host. +template +auto copy_back(const E& space, Dst& dst, const Src& src) -> void { + if constexpr (needs_staging_v) { + Kokkos::deep_copy(space, dst, src); + } +} + +} // namespace KokkosComm::Impl diff --git a/src/KokkosComm/mpi/allgather.hpp b/src/KokkosComm/mpi/allgather.hpp index f2d54887..e340ed16 100644 --- a/src/KokkosComm/mpi/allgather.hpp +++ b/src/KokkosComm/mpi/allgather.hpp @@ -12,84 +12,111 @@ #include "mpi_space.hpp" #include "req.hpp" +#include #include "impl/error_handling.hpp" namespace KokkosComm { namespace mpi { -template -auto iallgather(const ExecSpace &space, const SView sv, RView rv, MPI_Comm comm) -> Req { - using ST = typename SView::non_const_value_type; - using RT = typename RView::non_const_value_type; +template +auto iallgather(const E& space, const SV& sv, RV& rv, MPI_Comm comm) -> Req { + using ST = typename SV::non_const_value_type; + using RT = typename RV::non_const_value_type; static_assert(std::is_same_v, "KokkosComm::mpi::iallgather: View value types must be identical"); Kokkos::Tools::pushRegion("KokkosComm::mpi::iallgather"); - fail_if(!is_contiguous(sv) || !is_contiguous(rv), "KokkosComm::mpi::iallgather: unimplemented for non-contiguous views"); - // Sync: Work in space may have been used to produce view data. - space.fence("fence before non-blocking all-gather"); + fail_if(!is_contiguous(sv) || !is_contiguous(rv), "KokkosComm::mpi::iallgather: unimplemented for non-contiguous views"); + const int comm_size = [&]() { int tmp; MPI_Comm_size(comm, &tmp); return tmp; }(); + fail_if(span(sv) * static_cast<size_t>(comm_size) != span(rv), "KokkosComm::mpi::iallgather: receive count must be send count times comm size"); + const int cnt = span(sv); Req req; - // All ranks send/recv same count - MPI_Iallgather(data_handle(sv), span(sv), datatype, data_handle(rv), span(sv), datatype, - comm, &req.mpi_request()); +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce view data.
+ space.fence("fence before GPU-aware `MPI_Iallgather`"); + MPI_Iallgather(data_handle(sv), cnt, datatype, data_handle(rv), cnt, datatype, comm, + &req.mpi_request()); req.extend_view_lifetime(sv); req.extend_view_lifetime(rv); +#else + auto host_sv = KokkosComm::Impl::stage_for(sv); + auto host_rv = KokkosComm::Impl::stage_for(rv); + space.fence("fence host staging before `MPI_Iallgather`"); + MPI_Iallgather(data_handle(host_sv), cnt, datatype, data_handle(host_rv), cnt, datatype, + comm, &req.mpi_request()); + // Implicitly extends lifetimes of `host_rv` and `rv` due to lambda capture + req.call_after_mpi_wait([=]() { + KokkosComm::Impl::copy_back(space, rv, host_rv); + space.fence("fence copy back after `MPI_Iallgather`"); + }); + req.extend_view_lifetime(host_sv); + req.extend_view_lifetime(sv); +#endif Kokkos::Tools::popRegion(); return req; } -template -void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) { - Kokkos::Tools::pushRegion("KokkosComm::Mpi::allgather"); - - using SendScalar = typename SendView::value_type; - using RecvScalar = typename RecvView::value_type; - - static_assert(KokkosComm::rank() <= 1, "allgather for SendView::rank > 1 not supported"); - static_assert(KokkosComm::rank() <= 1, "allgather for RecvView::rank > 1 not supported"); - - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(sv), "low-level allgather requires contiguous send view"); - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(rv), "low-level allgather requires contiguous recv view"); - - const int count = KokkosComm::span(sv); // all ranks send/recv same count - MPI_Allgather(KokkosComm::data_handle(sv), count, datatype(), KokkosComm::data_handle(rv), - count, datatype(), comm); +template +void allgather(const ExecSpace& space, const SendView& sv, const RecvView& rv, MPI_Comm comm) { + using ST = typename SendView::non_const_value_type; + using RT = typename RecvView::non_const_value_type; + Kokkos::Tools::pushRegion("KokkosComm::mpi::allgather"); + 
static_assert(std::is_same_v, "KokkosComm::mpi::allgather: View value types must be identical"); + fail_if(!is_contiguous(sv) || !is_contiguous(rv), + "KokkosComm::mpi::allgather: unimplemented for non-contiguous Views"); + + const int comm_size = [&]() { int tmp; MPI_Comm_size(comm, &tmp); return tmp; }(); + fail_if(span(sv) * static_cast<size_t>(comm_size) != span(rv), "KokkosComm::mpi::allgather: receive count must be send count times comm size"); + const int cnt = span(sv); + +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce send view data + space.fence("fence before GPU-aware `MPI_Allgather`"); + MPI_Allgather(data_handle(sv), cnt, datatype(), data_handle(rv), cnt, datatype(), comm); +#else + auto host_sv = KokkosComm::Impl::stage_for(sv); + auto host_rv = KokkosComm::Impl::stage_for(rv); + space.fence("fence host staging before `MPI_Allgather`"); + MPI_Allgather(data_handle(host_sv), cnt, datatype(), data_handle(host_rv), cnt, + datatype(), comm); + KokkosComm::Impl::copy_back(space, rv, host_rv); + space.fence("fence copy back after `MPI_Allgather`"); +#endif Kokkos::Tools::popRegion(); } -// in-place allgather -template -void allgather(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) { - Kokkos::Tools::pushRegion("KokkosComm::Mpi::allgather"); - - using RecvScalar = typename RecvView::value_type; - - static_assert(KokkosComm::rank() <= 1, "allgather for RecvView::rank > 1 not supported"); - - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(rv), "low-level allgather requires contiguous recv view"); +template +void allgather(const SendView& sv, const RecvView& rv, MPI_Comm comm) { + allgather(Kokkos::DefaultExecutionSpace{}, sv, rv, comm); +} - space.fence("fence before allgather"); // work in space may have been used to produce send view data - MPI_Allgather(MPI_IN_PLACE, 0 /*ignored*/, MPI_DATATYPE_NULL /*ignored*/, KokkosComm::data_handle(rv), recvCount, - datatype(), comm); +// In-place allgather +template +void allgather(const ExecSpace& space, View& v, size_t cnt, MPI_Comm comm) { +
using T = typename View::non_const_value_type; + Kokkos::Tools::pushRegion("KokkosComm::mpi::allgather"); + fail_if(!is_contiguous(v), "KokkosComm::mpi::allgather: unimplemented for non-contiguous view"); + +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce send view data + space.fence("fence before GPU-aware in-place `MPI_Allgather`"); + MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, data_handle(v), cnt, datatype(), comm); +#else + auto host_v = KokkosComm::Impl::stage_for(v); + space.fence("fence host staging before in-place `MPI_Allgather`"); + MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, data_handle(host_v), cnt, datatype(), comm); + KokkosComm::Impl::copy_back(space, v, host_v); + space.fence("fence copy back after in-place `MPI_Allgather`"); +#endif Kokkos::Tools::popRegion(); } -template -void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, MPI_Comm comm) { - Kokkos::Tools::pushRegion("KokkosComm::Mpi::allgather"); - - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(sv) || !KokkosComm::is_contiguous(rv), - "allgather for non-contiguous views not implemented"); - - space.fence("fence before allgather"); // work in space may have been used to produce send view data - allgather(sv, rv, comm); - - Kokkos::Tools::popRegion(); +// In-place allgather +template +void allgather(View& v, size_t cnt, MPI_Comm comm) { + allgather(Kokkos::DefaultExecutionSpace{}, v, cnt, comm); } } // namespace mpi @@ -97,7 +124,7 @@ namespace Experimental::Impl { template struct AllGather { - static auto execute(Handle &h, const SendView sv, RecvView rv) -> Req { + static auto execute(Handle& h, const SendView sv, RecvView rv) -> Req { return mpi::iallgather(h.space(), sv, rv, h.comm()); } }; diff --git a/src/KokkosComm/mpi/alltoall.hpp b/src/KokkosComm/mpi/alltoall.hpp index fcbcedc8..c80bf07d 100644 --- a/src/KokkosComm/mpi/alltoall.hpp +++ b/src/KokkosComm/mpi/alltoall.hpp @@ -12,6 +12,7 @@ 
#include "mpi_space.hpp" #include "req.hpp" +#include #include "impl/pack_traits.hpp" #include "impl/error_handling.hpp" @@ -19,94 +20,104 @@ namespace KokkosComm { namespace mpi { template -auto ialltoall(const ExecSpace &space, const SView sv, RView rv, int count, MPI_Comm comm) -> Req { +auto ialltoall(const ExecSpace& space, const SView sv, RView rv, int count, MPI_Comm comm) -> Req { using ST = typename SView::non_const_value_type; using RT = typename RView::non_const_value_type; static_assert(std::is_same_v, "KokkosComm::mpi::ialltoall: View value types must be identical"); Kokkos::Tools::pushRegion("KokkosComm::mpi::ialltoall"); - fail_if(!is_contiguous(sv) || !is_contiguous(rv), "KokkosComm::mpi::ialltoall: unimplemented for non-contiguous views"); + Req req; +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) // Sync: Work in space may have been used to produce view data. space.fence("fence before non-blocking all-gather"); - - Req req; // All ranks send/recv same count MPI_Ialltoall(data_handle(sv), count, datatype(), data_handle(rv), count, datatype(), comm, &req.mpi_request()); req.extend_view_lifetime(sv); req.extend_view_lifetime(rv); +#else + auto host_sv = KokkosComm::Impl::stage_for(sv); + auto host_rv = KokkosComm::Impl::stage_for(rv); + space.fence("fence host staging before `MPI_Ialltoall`"); + MPI_Ialltoall(data_handle(host_sv), count, datatype(), data_handle(host_rv), count, + datatype(), comm, &req.mpi_request()); + // Implicitly extends lifetimes of `host_rv` and `rv` due to lambda capture + req.call_after_mpi_wait([=]() { + KokkosComm::Impl::copy_back(space, rv, host_rv); + space.fence("fence copy back after `MPI_Ialltoall`"); + }); + req.extend_view_lifetime(host_sv); + req.extend_view_lifetime(sv); +#endif Kokkos::Tools::popRegion(); return req; } template -void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount, const RecvView &rv, - const size_t recvCount, MPI_Comm comm) { +void alltoall(const ExecSpace& space, const 
SendView& sv, size_t s_cnt, RecvView& rv, size_t r_cnt, MPI_Comm comm) { + using ST = typename SendView::non_const_value_type; + using RT = typename RecvView::non_const_value_type; + static_assert(std::is_same_v, "KokkosComm::mpi::alltoall: View value types must be identical"); Kokkos::Tools::pushRegion("KokkosComm::mpi::alltoall"); - - using SendScalar = typename SendView::value_type; - using RecvScalar = typename RecvView::value_type; - - static_assert(KokkosComm::rank() <= 1, "alltoall for SendView::rank > 1 not supported"); - static_assert(KokkosComm::rank() <= 1, "alltoall for RecvView::rank > 1 not supported"); - - // Make sure views are ready - space.fence("KokkosComm::mpi::alltoall"); - - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(sv) || !KokkosComm::is_contiguous(rv), - "alltoall for non-contiguous views not implemented"); - - int size; - MPI_Comm_size(comm, &size); - - if (sendCount * size > KokkosComm::extent(sv, 0)) { - std::stringstream ss; - ss << "alltoall sendCount * communicator size (" << sendCount << " * " << size - << ") is greater than send view size"; - KokkosComm::mpi::fail_if(true, ss.str().data()); - } - if (recvCount * size > KokkosComm::extent(rv, 0)) { - std::stringstream ss; - ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size - << ") is greater than recv view size"; - KokkosComm::mpi::fail_if(true, ss.str().data()); - } - - MPI_Alltoall(KokkosComm::data_handle(sv), sendCount, datatype(), KokkosComm::data_handle(rv), - recvCount, datatype(), comm); + fail_if(!is_contiguous(sv) || !is_contiguous(rv), + "KokkosComm::mpi::alltoall: unimplemented for non-contiguous views"); + + const int size = [&]() { + int tmp; + MPI_Comm_size(comm, &tmp); + return tmp; + }(); + fail_if(s_cnt * size > extent(sv, 0), + "KokkosComm::mpi::alltoall: send count * comm size is greater than send view size"); + fail_if(r_cnt * size > extent(rv, 0), + "KokkosComm::mpi::alltoall: receive count * comm size is greater than receive 
view size"); + +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce view data. + space.fence("fence before GPU-aware `MPI_Alltoall`"); + MPI_Alltoall(data_handle(sv), s_cnt, datatype(), data_handle(rv), r_cnt, datatype(), + comm); +#else + auto host_sv = KokkosComm::Impl::stage_for(sv); + auto host_rv = KokkosComm::Impl::stage_for(rv); + space.fence("fence host staging before `MPI_Alltoall`"); + MPI_Alltoall(data_handle(host_sv), s_cnt, datatype(), data_handle(host_rv), r_cnt, + datatype(), comm); + KokkosComm::Impl::copy_back(space, rv, host_rv); + space.fence("fence copy back after `MPI_Alltoall`"); +#endif Kokkos::Tools::popRegion(); } -// in-place alltoall -template -void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) { +// In-place alltoall +template +void alltoall(const ExecSpace& space, View& v, size_t cnt, MPI_Comm comm) { + using T = typename View::non_const_value_type; Kokkos::Tools::pushRegion("KokkosComm::mpi::alltoall"); + fail_if(!is_contiguous(v), "KokkosComm::mpi::alltoall: unimplemented for non-contiguous views"); - using RecvScalar = typename RecvView::value_type; - - static_assert(RecvView::rank <= 1, "alltoall for RecvView::rank > 1 not supported"); - - // Make sure views are ready - space.fence("KokkosComm::mpi::alltoall"); + const int size = [&]() { + int tmp; + MPI_Comm_size(comm, &tmp); + return tmp; + }(); + fail_if(cnt * size > extent(v, 0), "KokkosComm::mpi::alltoall: count * comm size is greater than view size"); - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(rv), "alltoall for non-contiguous views not implemented"); - - int size; - MPI_Comm_size(comm, &size); - - if (recvCount * size > KokkosComm::extent(rv, 0)) { - std::stringstream ss; - ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size - << ") is greater than recv view size"; - KokkosComm::mpi::fail_if(true, ss.str().data()); - } - - 
MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, KokkosComm::data_handle(rv), recvCount, - datatype(), comm); +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce view data. + space.fence("fence before GPU-aware in-place `MPI_Alltoall`"); + MPI_Alltoall(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, data_handle(v), cnt, datatype(), comm); +#else + auto host_v = KokkosComm::Impl::stage_for(v); + space.fence("fence host staging before in-place `MPI_Alltoall`"); + MPI_Alltoall(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, data_handle(host_v), cnt, datatype(), comm); + KokkosComm::Impl::copy_back(space, v, host_v); + space.fence("fence copy back after in-place `MPI_Alltoall`"); +#endif Kokkos::Tools::popRegion(); } @@ -116,7 +127,7 @@ namespace Experimental::Impl { template struct AllToAll { - static auto execute(Handle &h, const SendView sv, RecvView rv, int count) -> Req { + static auto execute(Handle& h, const SendView sv, RecvView rv, int count) -> Req { return mpi::ialltoall(h.space(), sv, rv, count, h.mpi_comm()); } }; diff --git a/src/KokkosComm/mpi/broadcast.hpp b/src/KokkosComm/mpi/broadcast.hpp index 6ce625aa..ed515cf2 100644 --- a/src/KokkosComm/mpi/broadcast.hpp +++ b/src/KokkosComm/mpi/broadcast.hpp @@ -12,6 +12,7 @@ #include "mpi_space.hpp" #include "req.hpp" +#include #include "impl/error_handling.hpp" namespace KokkosComm { @@ -23,40 +24,54 @@ auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req Kokkos::Tools::pushRegion("KokkosComm::mpi::ibroadcast"); fail_if(!is_contiguous(v), "KokkosComm::mpi::ibroadcast: unimplemented for non-contiguous views"); - // Sync: Work in space may have been used to produce view data. - space.fence("fence before non-blocking broadcast"); - Req req; - MPI_Ibcast(data_handle(v), span(v), datatype, root, comm, &req.mpi_request()); +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce view data. 
+ space.fence("fence before GPU-aware `MPI_Ibcast`"); + MPI_Ibcast(data_handle(v), span(v), datatype(), root, comm, &req.mpi_request()); req.extend_view_lifetime(v); +#else + auto host_v = KokkosComm::Impl::stage_for(v); + // Sync: Ensure that `host_v` is done being copied on the host + space.fence("fence host staging before `MPI_Ibcast`"); + MPI_Ibcast(data_handle(host_v), span(host_v), datatype(), root, comm, &req.mpi_request()); + // Implicitly extends lifetimes of `host_v` and `v` due to lambda capture + req.call_after_mpi_wait([=]() { + KokkosComm::Impl::copy_back(space, v, host_v); + space.fence("fence copy back after `MPI_Ibcast`"); + }); +#endif Kokkos::Tools::popRegion(); return req; } -template -void broadcast(View const& v, int root, MPI_Comm comm) { - Kokkos::Tools::pushRegion("KokkosComm::mpi::broadcast"); - - using Scalar = typename View::value_type; - - KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(v), "low-level broadcast requires contiguous view"); - - MPI_Bcast(KokkosComm::data_handle(v), KokkosComm::span(v), datatype(), root, comm); - - Kokkos::Tools::popRegion(); -} - template void broadcast(ExecSpace const& space, View const& v, int root, MPI_Comm comm) { + using T = typename View::non_const_value_type; Kokkos::Tools::pushRegion("KokkosComm::mpi::broadcast"); - - space.fence("fence before broadcast"); // work in space may have been used to produce view data - broadcast(v, root, comm); + fail_if(!is_contiguous(v), "KokkosComm::mpi::broadcast: unimplemented for non-contiguous views"); + +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + // Sync: Work in space may have been used to produce view data + space.fence("fence before GPU-aware `MPI_Bcast`"); + MPI_Bcast(data_handle(v), span(v), datatype(), root, comm); +#else + auto host_v = KokkosComm::Impl::stage_for(v); + space.fence("fence host staging before `MPI_Bcast`"); + MPI_Bcast(data_handle(host_v), span(host_v), datatype(), root, comm); + KokkosComm::Impl::copy_back(space, v, host_v); + 
space.fence("fence copy back after `MPI_Bcast`"); +#endif Kokkos::Tools::popRegion(); } +template +void broadcast(View const& v, int root, MPI_Comm comm) { + broadcast(Kokkos::DefaultExecutionSpace{}, v, root, comm); +} + } // namespace mpi namespace Experimental::Impl { diff --git a/src/KokkosComm/mpi/irecv.hpp b/src/KokkosComm/mpi/irecv.hpp index 5c56c345..f9a1479a 100644 --- a/src/KokkosComm/mpi/irecv.hpp +++ b/src/KokkosComm/mpi/irecv.hpp @@ -9,6 +9,7 @@ #include "mpi_space.hpp" #include "handle.hpp" +#include #include "impl/pack_traits.hpp" #include "impl/tags.hpp" #include "impl/error_handling.hpp" @@ -20,25 +21,43 @@ namespace Impl { template struct Recv { static Req execute(Handle &h, const RecvView &rv, int src) { - using KCPT = KokkosComm::PackTraits; - using Packer = typename KCPT::packer_type; - using Args = typename Packer::args_type; + using T = typename RecvView::non_const_value_type; + using Packer = typename KokkosComm::PackTraits::packer_type; const ExecSpace &space = h.space(); Req req; - if (KokkosComm::is_contiguous(rv)) { - space.fence("fence before irecv"); - MPI_Irecv(KokkosComm::data_handle(rv), KokkosComm::span(rv), datatype(), - src, POINTTOPOINT_TAG, h.mpi_comm(), &req.mpi_request()); - req.extend_view_lifetime(rv); +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + if (is_contiguous(rv)) { + space.fence("fence before GPU-aware `MPI_Irecv`"); + MPI_Irecv(data_handle(rv), span(rv), datatype(), src, POINTTOPOINT_TAG, h.mpi_comm(), + &req.mpi_request()); } else { - Args args = Packer::allocate_packed_for(space, "TODO", rv); - space.fence("fence before irecv"); + auto args = Packer::allocate_packed_for(space, "TODO", rv); + space.fence("fence packing before GPU-aware `MPI_Irecv`"); MPI_Irecv(args.view.data(), args.count, args.datatype, src, POINTTOPOINT_TAG, h.mpi_comm(), &req.mpi_request()); - // implicitly extends args.view and rv lifetime due to lambda capture + // Implicitly extends args.view and rv lifetime due to lambda capture 
req.call_after_mpi_wait([=]() { Packer::unpack_into(space, rv, args.view); }); + } + req.extend_view_lifetime(rv); +#else + auto host_rv = KokkosComm::Impl::stage_for(rv); + space.fence("fence host staging before `MPI_Irecv`"); + if (is_contiguous(host_rv)) { + MPI_Irecv(data_handle(host_rv), span(host_rv), datatype(), src, POINTTOPOINT_TAG, h.mpi_comm(), + &req.mpi_request()); + // Implicitly extends `host_rv` and `rv` lifetimes due to lambda capture + req.call_after_mpi_wait([=]() { KokkosComm::Impl::copy_back(space, rv, host_rv); }); + req.extend_view_lifetime(host_rv); + } else { + auto args = Packer::allocate_packed_for(space, "packed `MPI_Irecv`", host_rv); + space.fence("fence packing before `MPI_Irecv`"); + MPI_Irecv(args.view.data(), args.count, args.datatype, src, POINTTOPOINT_TAG, h.mpi_comm(), &req.mpi_request()); + // Implicitly extends `args.view`, `host_rv` and `rv` lifetimes due to lambda capture + // TODO: Can we unpack directly into `rv` instead of `host_rv`? + req.call_after_mpi_wait([=]() { Packer::unpack_into(space, host_rv, args.view); }); + req.call_after_mpi_wait([=]() { KokkosComm::Impl::copy_back(space, rv, host_rv); }); + } + req.extend_view_lifetime(rv); +#endif return req; } }; diff --git a/src/KokkosComm/mpi/isend.hpp b/src/KokkosComm/mpi/isend.hpp index f92fc248..e519261a 100644 --- a/src/KokkosComm/mpi/isend.hpp +++ b/src/KokkosComm/mpi/isend.hpp @@ -8,10 +8,12 @@ #include #include #include +#include "KokkosComm/impl/host_staging.hpp" #include "mpi_space.hpp" #include "comm_mode.hpp" #include "handle.hpp" +#include #include "impl/pack_traits.hpp" #include "impl/tags.hpp" #include "impl/error_handling.hpp" @@ -20,9 +22,11 @@ namespace KokkosComm { namespace Impl { template -Req isend_impl(Handle &h, const SendView &sv, int dest, int tag, SendMode) { - auto mpi_isend_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag, - MPI_Comm mpi_comm, MPI_Request *mpi_req) { +Req isend_impl(Handle& h, const SendView& sv, int dest, int tag, SendMode) { + using T = typename SendView::non_const_value_type; + + auto mpi_isend_fn = [](void* mpi_view, int mpi_count, MPI_Datatype
mpi_datatype, int mpi_dest, int mpi_tag, + MPI_Comm mpi_comm, MPI_Request* mpi_req) { if constexpr (std::is_same_v) { MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req); } else if constexpr (std::is_same_v) { @@ -35,28 +39,42 @@ Req isend_impl(Handle &h, const SendView &sv, int }; Req req; - if (KokkosComm::is_contiguous(sv)) { - h.space().fence("fence before isend"); - mpi_isend_fn(KokkosComm::data_handle(sv), KokkosComm::span(sv), datatype(), - dest, tag, h.mpi_comm(), &req.mpi_request()); - req.extend_view_lifetime(sv); +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + if (is_contiguous(sv)) { + h.space().fence("fence before GPU-aware `MPI_Isend`"); + mpi_isend_fn(data_handle(sv), span(sv), datatype(), dest, tag, h.mpi_comm(), &req.mpi_request()); } else { using Packer = typename KokkosComm::PackTraits::packer_type; - using Args = typename Packer::args_type; - - Args args = Packer::pack(h.space(), sv); - h.space().fence("fence before isend"); - mpi_isend_fn(args.view.data(), args.count, args.datatype, dest, tag, h.mpi_comm(), &req.mpi_request()); + auto args = Packer::pack(h.space(), sv); + h.space().fence("fence packing before GPU-aware `MPI_Isend`"); + mpi_isend_fn(data_handle(args.view), args.count, args.datatype, dest, tag, h.mpi_comm(), &req.mpi_request()); + req.extend_view_lifetime(args.view); + } + req.extend_view_lifetime(sv); +#else + auto host_sv = KokkosComm::Impl::stage_for(sv); + h.space().fence("fence host staging before `MPI_Isend`"); + if (is_contiguous(host_sv)) { + mpi_isend_fn(data_handle(host_sv), span(host_sv), datatype(), dest, tag, h.mpi_comm(), + &req.mpi_request()); + } else { + using Packer = typename KokkosComm::PackTraits::packer_type; + auto args = Packer::pack(h.space(), host_sv); + h.space().fence("fence packing before `MPI_Isend`"); + mpi_isend_fn(data_handle(args.view), args.count, args.datatype, dest, tag, h.mpi_comm(), &req.mpi_request()); req.extend_view_lifetime(args.view); - 
req.extend_view_lifetime(sv); } + req.extend_view_lifetime(host_sv); + // TODO: Do we need to extend the lifetime of `sv` if we are staging it on the host? + req.extend_view_lifetime(sv); +#endif return req; } // Implementation of KokkosComm::Send template struct Send { - static Req execute(Handle &h, const SendView &sv, int dest) { + static Req execute(Handle& h, const SendView& sv, int dest) { return isend_impl(h, sv, dest, POINTTOPOINT_TAG, mpi::DefaultCommMode{}); } }; @@ -65,17 +83,17 @@ struct Send { namespace mpi { template -Req isend(Handle &h, const SendView &sv, int dest, int tag, SendMode) { +Req isend(Handle& h, const SendView& sv, int dest, int tag, SendMode) { return KokkosComm::Impl::isend_impl(h, sv, dest, tag, SendMode{}); } template -Req isend(Handle &h, const SendView &sv, int dest, int tag) { +Req isend(Handle& h, const SendView& sv, int dest, int tag) { return isend(h, sv, dest, tag, DefaultCommMode{}); } template -void isend(const SendView &sv, int dest, int tag, MPI_Comm comm, MPI_Request &req) { +void isend(const SendView& sv, int dest, int tag, MPI_Comm comm, MPI_Request& req) { Kokkos::Tools::pushRegion("KokkosComm::Impl::isend"); KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(sv), "only contiguous views supported for low-level isend"); diff --git a/src/KokkosComm/mpi/recv.hpp b/src/KokkosComm/mpi/recv.hpp index 107a52aa..a7e1cfcc 100644 --- a/src/KokkosComm/mpi/recv.hpp +++ b/src/KokkosComm/mpi/recv.hpp @@ -10,6 +10,7 @@ #include #include +#include #include "impl/pack_traits.hpp" #include "impl/error_handling.hpp" @@ -30,22 +31,34 @@ void recv(const RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Status *statu template void recv(const ExecSpace &space, RecvView &rv, int src, int tag, MPI_Comm comm) { Kokkos::Tools::pushRegion("KokkosComm::mpi::recv"); + using T = typename RecvView::non_const_value_type; + using Packer = typename PackTraits::packer_type; - using KCPT = KokkosComm::PackTraits; - using Packer = typename 
KCPT::packer_type; - using Args = typename Packer::args_type; - - if (!KokkosComm::is_contiguous(rv)) { - Args args = Packer::allocate_packed_for(space, "packed", rv); - space.fence("Fence after allocation before MPI_Recv"); - MPI_Recv(KokkosComm::data_handle(args.view), args.count, args.datatype, src, tag, comm, MPI_STATUS_IGNORE); +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) + if (is_contiguous(rv)) { + space.fence("fence before GPU-aware `MPI_Recv`"); // prevent work in `space` from writing to recv buffer + MPI_Recv(data_handle(rv), span(rv), datatype(), src, tag, comm, MPI_STATUS_IGNORE); + } else { + auto args = Packer::allocate_packed_for(space, "packed `MPI_Recv`", rv); + space.fence("fence packing before GPU-aware `MPI_Recv`"); + MPI_Recv(data_handle(args.view), args.count, args.datatype, src, tag, comm, MPI_STATUS_IGNORE); Packer::unpack_into(space, rv, args.view); + } +#else + auto host_rv = KokkosComm::Impl::stage_for(rv); + space.fence("fence host staging before `MPI_Recv`"); + if (is_contiguous(host_rv)) { + MPI_Recv(data_handle(host_rv), span(host_rv), datatype(), src, tag, comm, MPI_STATUS_IGNORE); } else { - using RecvScalar = typename RecvView::value_type; - space.fence("Fence before MPI_Recv"); // prevent work in `space` from writing to recv buffer - MPI_Recv(KokkosComm::data_handle(rv), KokkosComm::span(rv), datatype(), src, tag, comm, - MPI_STATUS_IGNORE); + auto args = Packer::allocate_packed_for(space, "packed `MPI_Recv`", host_rv); + space.fence("fence packing before `MPI_Recv`"); + MPI_Recv(data_handle(args.view), args.count, args.datatype, src, tag, comm, MPI_STATUS_IGNORE); + // TODO: Can we unpack directly into `rv` instead of `host_rv`? 
+ Packer::unpack_into(space, host_rv, args.view); + KokkosComm::Impl::copy_back(space, rv, host_rv); + space.fence("fence copy back after `MPI_Recv`"); } +#endif Kokkos::Tools::popRegion(); } diff --git a/src/KokkosComm/mpi/send.hpp b/src/KokkosComm/mpi/send.hpp index 9bfbcb3f..40fa98b0 100644 --- a/src/KokkosComm/mpi/send.hpp +++ b/src/KokkosComm/mpi/send.hpp @@ -11,8 +11,8 @@ #include #include "comm_mode.hpp" +#include #include "impl/pack_traits.hpp" -#include "impl/error_handling.hpp" namespace KokkosComm::mpi { @@ -34,14 +34,26 @@ void send(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Com } }; +#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI) if (is_contiguous(sv)) { - space.fence("fence before send"); + space.fence("fence before GPU-aware `MPI_Send`"); mpi_send_fn(data_handle(sv), span(sv), datatype()); } else { auto args = Packer::pack(space, sv); - space.fence("fence before send"); + space.fence("fence packing before GPU-aware `MPI_Send`"); mpi_send_fn(data_handle(args.view), args.count, args.datatype); } +#else + auto host_sv = KokkosComm::Impl::stage_for(sv); + space.fence("fence host staging before `MPI_Send`"); + if (is_contiguous(host_sv)) { + mpi_send_fn(data_handle(host_sv), span(host_sv), datatype()); + } else { + auto args = Packer::pack(space, host_sv); + space.fence("fence packing before `MPI_Send`"); + mpi_send_fn(data_handle(args.view), args.count, args.datatype); + } +#endif Kokkos::Tools::popRegion(); } diff --git a/unit_tests/CMakeLists.txt b/unit_tests/CMakeLists.txt index 69e25584..adc12824 100644 --- a/unit_tests/CMakeLists.txt +++ b/unit_tests/CMakeLists.txt @@ -51,6 +51,7 @@ kc_add_unit_test(test.core.all_reduce CORE NUM_PES 2 FILES test_main.cpp test_al kc_add_unit_test(test.core.alltoall CORE NUM_PES 2 FILES test_main.cpp test_alltoall.cpp) kc_add_unit_test(test.core.red_op_conv CORE NUM_PES 1 FILES test_main.cpp test_red_op_conversion.cpp) +kc_add_unit_test(test.core.host_staging CORE NUM_PES 1 FILES test_main.cpp 
test_host_staging.cpp) # --- MPI backend unit tests --- # if(KOKKOSCOMM_ENABLE_MPI) diff --git a/unit_tests/test_host_staging.cpp b/unit_tests/test_host_staging.cpp new file mode 100644 index 00000000..ce374759 --- /dev/null +++ b/unit_tests/test_host_staging.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project + +#include + +#include +#include + +#include +#include +#include + +namespace { + +template +auto test_stage_for_returns_host_accessible_view() -> void { + auto space = E{}; + + constexpr int N = 100; + Kokkos::View view("v", N); + Kokkos::parallel_for( + "fill", Kokkos::RangePolicy(space, 0, N), KOKKOS_LAMBDA(int i) { view(i) = i; }); + space.fence(); + + auto staged = KokkosComm::Impl::stage_for(view); + space.fence(); + + static_assert(Kokkos::SpaceAccessibility::accessible); + + for (int i = 0; i < N; ++i) { + EXPECT_EQ(staged(i), i); + } +} + +TEST(StagingComptimeTest, NeedsStagingTraitIsCorrect) { + using HostView = Kokkos::View; + EXPECT_FALSE(KokkosComm::Impl::needs_staging_v); + +#ifdef KOKKOS_ENABLE_CUDA + using CudaView = Kokkos::View; + EXPECT_TRUE(KokkosComm::Impl::needs_staging_v); + using UVMView = Kokkos::View; + EXPECT_FALSE(KokkosComm::Impl::needs_staging_v); +#endif + +#ifdef KOKKOS_ENABLE_HIP + using HIPView = Kokkos::View; + EXPECT_TRUE(KokkosComm::Impl::needs_staging_v); + using HIPManagedView = Kokkos::View; + EXPECT_FALSE(KokkosComm::Impl::needs_staging_v); +#endif +} + +TEST(StagingTest, StageForPreservesDataPointerForHostViews) { + auto space = Kokkos::DefaultHostExecutionSpace{}; + + constexpr int N = 100; + Kokkos::View host_view("v", N); + + auto staged = KokkosComm::Impl::stage_for(host_view); + space.fence(); + EXPECT_EQ(staged.data(), host_view.data()); +} + +TEST(StagingTest, StageForCreatesIndependentCopyForDeviceViews) { + if constexpr (std::is_same_v) { + GTEST_SKIP() << "Default execution space is on host"; + } else { + 
auto space = Kokkos::DefaultExecutionSpace{}; + + constexpr int N = 100; + Kokkos::View device_view("v", N); + + auto staged = KokkosComm::Impl::stage_for(device_view); + space.fence(); + EXPECT_NE(reinterpret_cast(staged.data()), reinterpret_cast(device_view.data())); + } +} + +TEST(StagingTest, CopyBackTransfersData) { + auto space = Kokkos::DefaultExecutionSpace{}; + + constexpr int N = 100; + Kokkos::View device_view("v", N); + + space.fence(); + auto staged = KokkosComm::Impl::stage_for(device_view); + Kokkos::DefaultHostExecutionSpace().fence(); + for (int i = 0; i < N; ++i) { + staged(i) = 2 * i; + } + + KokkosComm::Impl::copy_back(space, device_view, staged); + space.fence(); + + auto check = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, device_view); + for (int i = 0; i < N; ++i) { + EXPECT_EQ(check(i), 2 * i); + } +} + +template +class StagingTypedTest : public ::testing::Test { + public: + using Space = T; +}; + +using SpaceTypes = ::testing::Types; +TYPED_TEST_SUITE(StagingTypedTest, SpaceTypes); + +TYPED_TEST(StagingTypedTest, StageForReturnsHostAccessibleView) { + test_stage_for_returns_host_accessible_view(); +} + +} // namespace