1212#include " mpi_space.hpp"
1313#include " req.hpp"
1414
15+ #include < KokkosComm/impl/host_staging.hpp>
1516#include " impl/error_handling.hpp"
1617
1718namespace KokkosComm {
@@ -23,12 +24,20 @@ auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req
2324 Kokkos::Tools::pushRegion (" KokkosComm::mpi::ibroadcast" );
2425 fail_if (!is_contiguous (v), " KokkosComm::mpi::ibroadcast: unimplemented for non-contiguous views" );
2526
26- // Sync: Work in space may have been used to produce view data.
27- space.fence (" fence before non-blocking broadcast" );
28-
2927 Req<MpiSpace> req;
30- MPI_Ibcast (data_handle (v), span (v), datatype<MpiSpace, T>, root, comm, &req.mpi_request ());
28+ #if defined(KOKKOSCOMM_ENABLE_GPU_AWARE_MPI)
29+ // Sync: Work in space may have been used to produce view data.
30+ space.fence (" fence before GPU-aware `MPI_Ibcast`" );
31+ MPI_Ibcast (data_handle (v), span (v), datatype<MpiSpace, T>(), root, comm, &req.mpi_request ());
3132 req.extend_view_lifetime (v);
33+ #else
34+ auto host_v = KokkosComm::Impl::stage_for (v);
35+ // Sync: Ensure that the copy into `host_v` has finished on the host
36+ space.fence (" fence before non-blocking broadcast" );
37+ MPI_Ibcast (data_handle (host_v), span (host_v), datatype<MpiSpace, T>(), root, comm, &req.mpi_request ());
38+ // Implicitly extends lifetimes of `host_v` and `v` due to lambda capture
39+ req.call_after_mpi_wait ([=]() { KokkosComm::Impl::copy_back (space, v, host_v); });
40+ #endif
3241
3342 Kokkos::Tools::popRegion ();
3443 return req;
@@ -48,7 +57,7 @@ void broadcast(ExecSpace const& space, View const& v, int root, MPI_Comm comm) {
4857 auto host_v = KokkosComm::Impl::stage_for (v);
4958 // Sync: Ensure that the copy into `host_v` has finished on the host
5059 space.fence (" fence before non-blocking broadcast" );
51- MPI_Bcast (data_handle (host_v), span (host_v), datatype<MpiSpace, T>, root, comm);
60+ MPI_Bcast (data_handle (host_v), span (host_v), datatype<MpiSpace, T>() , root, comm);
5261 KokkosComm::Impl::copy_back (space, v, host_v);
5362#endif
5463
0 commit comments