@@ -22,8 +22,9 @@ namespace KokkosComm::Experimental {
2222namespace nccl {
2323
2424template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
25- auto reduce (const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_t op, int root, int rank,
26- ncclComm_t comm) -> Request<NcclSpace> {
25+ auto reduce (
26+ const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_t op, int root, int rank, ncclComm_t comm
27+ ) -> Request<NcclSpace> {
2728 using ST = typename SendView::non_const_value_type;
2829 using RT = typename RecvView::non_const_value_type;
2930 using SendPacker = typename Impl::PackTraits<SendView>::packer_type;
@@ -34,13 +35,16 @@ auto reduce(const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_
3435 Request<NcclSpace> req;
3536 if (is_contiguous (sv)) {
3637 if (rank != root and is_contiguous (rv)) {
37- ncclReduce (data_handle (sv), data_handle (rv), span (sv), datatype<NcclSpace, ST>(), op, root, comm,
38- space.cuda_stream ());
38+ ncclReduce (
39+ data_handle (sv), data_handle (rv), span (sv), datatype<NcclSpace, ST>(), op, root, comm, space.cuda_stream ()
40+ );
3941 req.capture_stream_state (space.cuda_stream ());
4042 } else {
4143 auto pckd_rv = RecvPacker::allocate_packed_for (space, " pckd_rv" , rv);
42- ncclReduce (data_handle (sv), data_handle (pckd_rv.view_ ), span (sv), datatype<NcclSpace, ST>(), op, root, comm,
43- space.cuda_stream ());
44+ ncclReduce (
45+ data_handle (sv), data_handle (pckd_rv.view_ ), span (sv), datatype<NcclSpace, ST>(), op, root, comm,
46+ space.cuda_stream ()
47+ );
4448 req.capture_stream_state (space.cuda_stream ());
4549 req.add_callback ([space, rv, pckd_rv]() {
4650 RecvPacker::unpack_into (space, rv, pckd_rv.view_ );
@@ -50,13 +54,17 @@ auto reduce(const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_
5054 } else {
5155 auto pckd_sv = SendPacker::pack (space, " pckd_sv" , sv);
5256 if (rank != root and is_contiguous (rv)) {
53- ncclReduce (data_handle (pckd_sv.view_ ), data_handle (rv), pckd_sv.count_ , pckd_sv.datatype_ , op, root, comm,
54- space.cuda_stream ());
57+ ncclReduce (
58+ data_handle (pckd_sv.view_ ), data_handle (rv), pckd_sv.count_ , pckd_sv.datatype_ , op, root, comm,
59+ space.cuda_stream ()
60+ );
5561 req.capture_stream_state (space.cuda_stream ());
5662 } else {
5763 auto pckd_rv = RecvPacker::allocate_packed_for (space, " pckd_rv" , rv);
58- ncclReduce (data_handle (pckd_sv.view_ ), data_handle (pckd_rv.view_ ), pckd_sv.count_ , pckd_sv.datatype_ , op, root,
59- comm, space.cuda_stream ());
64+ ncclReduce (
65+ data_handle (pckd_sv.view_ ), data_handle (pckd_rv.view_ ), pckd_sv.count_ , pckd_sv.datatype_ , op, root, comm,
66+ space.cuda_stream ()
67+ );
6068 req.capture_stream_state (space.cuda_stream ());
6169 req.add_callback ([space, rv, pckd_rv]() {
6270 RecvPacker::unpack_into (space, rv, pckd_rv.view_ );
@@ -77,9 +85,9 @@ namespace Impl {
7785
7886template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
7987struct Reduce <SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
80- static auto execute (Handle< Kokkos::Cuda, NcclSpace >& h, const SendView sv, RecvView rv, int root)
88+ static auto execute (Communicator<NcclSpace, Kokkos::Cuda>& h, const SendView sv, RecvView rv, int root)
8189 -> Request<NcclSpace> {
82- return nccl::reduce (h.space (), sv, rv, reduction_op<NcclSpace, RedOp>(), root, h.rank (), h.comm ());
90+ return nccl::reduce (h.exec (), sv, rv, reduction_op<NcclSpace, RedOp>(), root, h.rank (), h.comm ());
8391 }
8492};
8593
0 commit comments