Skip to content

Commit 05ffb54

Browse files
committed
feat(mpi): add host staging for ibroadcast
Signed-off-by: Gabriel Dos Santos <gabriel.dossantos@cea.fr>
1 parent e5f4026 commit 05ffb54

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/KokkosComm/mpi/broadcast.hpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mpi_space.hpp"
1313
#include "req.hpp"
1414

15+
#include <KokkosComm/impl/host_staging.hpp>
1516
#include "impl/error_handling.hpp"
1617

1718
namespace KokkosComm {
@@ -23,12 +24,20 @@ auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req
2324
Kokkos::Tools::pushRegion("KokkosComm::mpi::ibroadcast");
2425
fail_if(!is_contiguous(v), "KokkosComm::mpi::ibroadcast: unimplemented for non-contiguous views");
2526

26-
// Sync: Work in space may have been used to produce view data.
27-
space.fence("fence before non-blocking broadcast");
28-
2927
Req<MpiSpace> req;
30-
MPI_Ibcast(data_handle(v), span(v), datatype<MpiSpace, T>, root, comm, &req.mpi_request());
28+
#if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI)
29+
// Sync: Work in space may have been used to produce view data.
30+
space.fence("fence before GPU-aware `MPI_Ibcast`");
31+
MPI_Ibcast(data_handle(v), span(v), datatype<MpiSpace, T>(), root, comm, &req.mpi_request());
3132
req.extend_view_lifetime(v);
33+
#else
34+
auto host_v = KokkosComm::Impl::stage_for(v);
35+
// Sync: Ensure the staging copy into `host_v` has completed on the host
36+
space.fence("fence before non-blocking broadcast");
37+
MPI_Ibcast(data_handle(host_v), span(host_v), datatype<MpiSpace, T>(), root, comm, &req.mpi_request());
38+
// Implicitly extends lifetimes of `host_v` and `v` due to lambda capture
39+
req.call_after_mpi_wait([=]() { KokkosComm::Impl::copy_back(space, v, host_v); });
40+
#endif
3241

3342
Kokkos::Tools::popRegion();
3443
return req;
@@ -48,7 +57,7 @@ void broadcast(ExecSpace const& space, View const& v, int root, MPI_Comm comm) {
4857
auto host_v = KokkosComm::Impl::stage_for(v);
4958
// Sync: Ensure the staging copy into `host_v` has completed on the host
5059
space.fence("fence before non-blocking broadcast");
51-
MPI_Bcast(data_handle(host_v), span(host_v), datatype<MpiSpace, T>, root, comm);
60+
MPI_Bcast(data_handle(host_v), span(host_v), datatype<MpiSpace, T>(), root, comm);
5261
KokkosComm::Impl::copy_back(space, v, host_v);
5362
#endif
5463

0 commit comments

Comments
 (0)