Skip to content

Commit 40d9bd8

Browse files
committed
refactor(nccl): propagate Communicator changes to NCCL backend
Signed-off-by: Gabriel Dos Santos <gabriel.dossantos@cea.fr>
1 parent c7702ae commit 40d9bd8

File tree

8 files changed

+58
-41
lines changed

8 files changed

+58
-41
lines changed

src/KokkosComm/nccl/allgather.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@ template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvVi
2424
auto allgather(const ExecSpace& space, const SendView& sv, const RecvView& rv, ncclComm_t comm) -> Request<NcclSpace> {
2525
using ST = typename SendView::non_const_value_type;
2626
using RT = typename RecvView::non_const_value_type;
27-
static_assert(std::is_same_v<ST, RT>,
28-
"KokkosComm::Experimental::nccl::allgather: View value types must be identical");
27+
static_assert(
28+
std::is_same_v<ST, RT>, "KokkosComm::Experimental::nccl::allgather: View value types must be identical"
29+
);
2930
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::allgather");
3031

3132
Request<NcclSpace> req;
3233
if (KC::is_contiguous(sv) and KC::is_contiguous(rv)) {
33-
ncclAllGather(KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), datatype<NcclSpace, ST>(), comm,
34-
space.cuda_stream());
34+
ncclAllGather(
35+
KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), datatype<NcclSpace, ST>(), comm, space.cuda_stream()
36+
);
3537
req.capture_stream_state(space.cuda_stream());
3638
} else {
3739
Kokkos::abort("KokkosComm::Experimental::nccl::allgather: unimplemented for non-contiguous views");
@@ -48,8 +50,8 @@ namespace Impl {
4850

4951
template <KokkosView SendView, KokkosView RecvView>
5052
struct AllGather<SendView, RecvView, Kokkos::Cuda, NcclSpace> {
51-
static auto execute(Handle<Kokkos::Cuda, NcclSpace>& h, const SendView sv, RecvView rv) -> Request<NcclSpace> {
52-
return nccl::allgather(h.space(), sv, rv, h.comm());
53+
static auto execute(Communicator<NcclSpace, Kokkos::Cuda>& h, const SendView sv, RecvView rv) -> Request<NcclSpace> {
54+
return nccl::allgather(h.exec(), sv, rv, h.comm());
5355
}
5456
};
5557

src/KokkosComm/nccl/allreduce.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,16 @@ auto allreduce(const ExecSpace& space, const SendView& sv, const RecvView& rv, n
2626
-> Request<NcclSpace> {
2727
using ST = typename SendView::non_const_value_type;
2828
using RT = typename RecvView::non_const_value_type;
29-
static_assert(std::is_same_v<ST, RT>,
30-
"KokkosComm::Experimental::nccl::allreduce: View value types must be identical");
29+
static_assert(
30+
std::is_same_v<ST, RT>, "KokkosComm::Experimental::nccl::allreduce: View value types must be identical"
31+
);
3132
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::allreduce");
3233

3334
Request<NcclSpace> req;
3435
if (KC::is_contiguous(sv) and KC::is_contiguous(rv)) {
35-
ncclAllReduce(KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), datatype<NcclSpace, ST>(), op, comm,
36-
space.cuda_stream());
36+
ncclAllReduce(
37+
KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), datatype<NcclSpace, ST>(), op, comm, space.cuda_stream()
38+
);
3739
req.capture_stream_state(space.cuda_stream());
3840
} else {
3941
Kokkos::abort("KokkosComm::Experimental::nccl::allreduce: unimplemented for non-contiguous Views");
@@ -50,8 +52,8 @@ namespace Impl {
5052

5153
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
5254
struct AllReduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
53-
static auto execute(Handle<Kokkos::Cuda, NcclSpace>& h, const SendView sv, RecvView rv) -> Request<NcclSpace> {
54-
return nccl::allreduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), h.comm());
55+
static auto execute(Communicator<NcclSpace, Kokkos::Cuda>& h, const SendView sv, RecvView rv) -> Request<NcclSpace> {
56+
return nccl::allreduce(h.exec(), sv, rv, reduction_op<NcclSpace, RedOp>(), h.comm());
5557
}
5658
};
5759

src/KokkosComm/nccl/alltoall.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ namespace Impl {
5858

5959
template <KokkosView SendView, KokkosView RecvView>
6060
struct AllToAll<SendView, RecvView, Kokkos::Cuda, NcclSpace> {
61-
static auto execute(Handle<Kokkos::Cuda, NcclSpace>& h, const SendView sv, RecvView rv, int count)
61+
static auto execute(Communicator<NcclSpace, Kokkos::Cuda>& h, const SendView sv, RecvView rv, int count)
6262
-> Request<NcclSpace> {
63-
return nccl::alltoall(h.space(), sv, rv, count, h.comm());
63+
return nccl::alltoall(h.exec(), sv, rv, count, h.comm());
6464
}
6565
};
6666

src/KokkosComm/nccl/broadcast.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ namespace KC = KokkosComm;
2323
template <KokkosView View>
2424
auto broadcast(const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm) -> Request<NcclSpace> {
2525
using T = typename View::non_const_value_type;
26-
static_assert(KC::rank<View>() <= 1,
27-
"KokkosComm::Experimental::nccl::broadcast: Views with rank higher than 1 are not supported");
26+
static_assert(
27+
KC::rank<View>() <= 1,
28+
"KokkosComm::Experimental::nccl::broadcast: Views with rank higher than 1 are not supported"
29+
);
2830
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::broadcast");
2931

3032
Request<NcclSpace> req;
@@ -45,8 +47,8 @@ namespace Impl {
4547

4648
template <KokkosView View>
4749
struct Broadcast<View, Kokkos::Cuda, NcclSpace> {
48-
static auto execute(Handle<Kokkos::Cuda, NcclSpace>& h, View v, int root) -> Request<NcclSpace> {
49-
return nccl::broadcast(h.space(), v, root, h.comm());
50+
static auto execute(Communicator<NcclSpace, Kokkos::Cuda>& h, View v, int root) -> Request<NcclSpace> {
51+
return nccl::broadcast(h.exec(), v, root, h.comm());
5052
}
5153
};
5254

src/KokkosComm/nccl/communicator.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Communicator<Experimental::NcclSpace, Kokkos::Cuda> {
2828
/// Defaults `exec` to `Kokkos::Cuda`.
2929
///
3030
/// The returned communicator does not own the underlying handle, and the user is responsible for destroying it.
31-
[[nodiscard]] static auto from_raw(communicator_type comm, const execution_space& exec = execution_space{}, ) noexcept
31+
[[nodiscard]] static auto from_raw(communicator_type comm, const execution_space& exec = execution_space{}) noexcept
3232
-> std::optional<Communicator<communication_space, execution_space>> {
3333
if (comm == nullptr) {
3434
return std::nullopt;
@@ -42,16 +42,17 @@ class Communicator<Experimental::NcclSpace, Kokkos::Cuda> {
4242
/// value of `key`. All processes with the same value of `color` join the same communicator.
4343
/// A process that passes `NCCL_SPLIT_NOCOLOR` as `color` will not join a new communicator and `nullopt` is returned.
4444
[[nodiscard]] static auto split(
45-
const communicator_type comm, int color, int key, const execution_space& exec = execution_space{},
46-
) -> std::optional<Communicator<communication_space, execution_space>> {
45+
const communicator_type comm, int color, int key, const execution_space& exec = execution_space{}
46+
) noexcept -> std::optional<Communicator<communication_space, execution_space>> {
4747
communicator_type new_comm;
4848
ncclCommSplit(comm, color, key, &new_comm, nullptr);
4949
if (new_comm == nullptr) {
5050
return std::nullopt;
5151
}
5252
return Communicator<communication_space, execution_space>(new_comm, exec, true);
5353
}
54-
[[nodiscard]] auto split(int color, int key) -> std::optional<Communicator<execution_space, communication_space>> {
54+
[[nodiscard]] auto split(int color, int key) noexcept
55+
-> std::optional<Communicator<communication_space, execution_space>> {
5556
return Communicator::split(comm_, color, key, exec_);
5657
}
5758

@@ -64,7 +65,7 @@ class Communicator<Experimental::NcclSpace, Kokkos::Cuda> {
6465
ncclCommUserRank(comm, &rank);
6566
return Communicator::split(comm, 0, rank, exec_);
6667
}
67-
[[nodiscard]] auto duplicate() -> std::optional<Communicator<execution_space, communication_space>> {
68+
[[nodiscard]] auto duplicate() noexcept -> std::optional<Communicator<communication_space, execution_space>> {
6869
return Communicator::duplicate(comm_, exec_);
6970
}
7071

src/KokkosComm/nccl/recv.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ auto recv(const ExecSpace& space, RecvView& rv, int peer, ncclComm_t comm) -> Re
3131
using Packer = typename Impl::PackTraits<RecvView>::packer_type;
3232
auto pckd_rv = Packer::allocate_packed_for(space, "pckd_rv", rv);
3333
KC_NCCL_CHECK(
34-
ncclRecv(data_handle(pckd_rv.view_), pckd_rv.count_, pckd_rv.datatype_, peer, comm, space.cuda_stream()));
34+
ncclRecv(data_handle(pckd_rv.view_), pckd_rv.count_, pckd_rv.datatype_, peer, comm, space.cuda_stream())
35+
);
3536
req.capture_stream_state(space.cuda_stream());
3637
req.add_callback([space, rv, pckd_rv]() {
3738
Packer::unpack_into(space, rv, pckd_rv.view_);
@@ -49,9 +50,9 @@ namespace Impl {
4950

5051
template <KokkosView RecvView>
5152
struct Recv<RecvView, Kokkos::Cuda, Experimental::NcclSpace> {
52-
static auto execute(Handle<Kokkos::Cuda, Experimental::NcclSpace>& h, RecvView sv, int peer)
53+
static auto execute(Communicator<Experimental::NcclSpace, Kokkos::Cuda>& h, RecvView sv, int peer)
5354
-> Request<Experimental::NcclSpace> {
54-
return Experimental::nccl::recv(h.space(), sv, peer, h.comm());
55+
return Experimental::nccl::recv(h.exec(), sv, peer, h.comm());
5556
}
5657
};
5758

src/KokkosComm/nccl/reduce.hpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ namespace KokkosComm::Experimental {
2222
namespace nccl {
2323

2424
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
25-
auto reduce(const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_t op, int root, int rank,
26-
ncclComm_t comm) -> Request<NcclSpace> {
25+
auto reduce(
26+
const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_t op, int root, int rank, ncclComm_t comm
27+
) -> Request<NcclSpace> {
2728
using ST = typename SendView::non_const_value_type;
2829
using RT = typename RecvView::non_const_value_type;
2930
using SendPacker = typename Impl::PackTraits<SendView>::packer_type;
@@ -34,13 +35,16 @@ auto reduce(const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_
3435
Request<NcclSpace> req;
3536
if (is_contiguous(sv)) {
3637
if (rank != root and is_contiguous(rv)) {
37-
ncclReduce(data_handle(sv), data_handle(rv), span(sv), datatype<NcclSpace, ST>(), op, root, comm,
38-
space.cuda_stream());
38+
ncclReduce(
39+
data_handle(sv), data_handle(rv), span(sv), datatype<NcclSpace, ST>(), op, root, comm, space.cuda_stream()
40+
);
3941
req.capture_stream_state(space.cuda_stream());
4042
} else {
4143
auto pckd_rv = RecvPacker::allocate_packed_for(space, "pckd_rv", rv);
42-
ncclReduce(data_handle(sv), data_handle(pckd_rv.view_), span(sv), datatype<NcclSpace, ST>(), op, root, comm,
43-
space.cuda_stream());
44+
ncclReduce(
45+
data_handle(sv), data_handle(pckd_rv.view_), span(sv), datatype<NcclSpace, ST>(), op, root, comm,
46+
space.cuda_stream()
47+
);
4448
req.capture_stream_state(space.cuda_stream());
4549
req.add_callback([space, rv, pckd_rv]() {
4650
RecvPacker::unpack_into(space, rv, pckd_rv.view_);
@@ -50,13 +54,17 @@ auto reduce(const ExecSpace& space, const SendView& sv, RecvView& rv, ncclRedOp_
5054
} else {
5155
auto pckd_sv = SendPacker::pack(space, "pckd_sv", sv);
5256
if (rank != root and is_contiguous(rv)) {
53-
ncclReduce(data_handle(pckd_sv.view_), data_handle(rv), pckd_sv.count_, pckd_sv.datatype_, op, root, comm,
54-
space.cuda_stream());
57+
ncclReduce(
58+
data_handle(pckd_sv.view_), data_handle(rv), pckd_sv.count_, pckd_sv.datatype_, op, root, comm,
59+
space.cuda_stream()
60+
);
5561
req.capture_stream_state(space.cuda_stream());
5662
} else {
5763
auto pckd_rv = RecvPacker::allocate_packed_for(space, "pckd_rv", rv);
58-
ncclReduce(data_handle(pckd_sv.view_), data_handle(pckd_rv.view_), pckd_sv.count_, pckd_sv.datatype_, op, root,
59-
comm, space.cuda_stream());
64+
ncclReduce(
65+
data_handle(pckd_sv.view_), data_handle(pckd_rv.view_), pckd_sv.count_, pckd_sv.datatype_, op, root, comm,
66+
space.cuda_stream()
67+
);
6068
req.capture_stream_state(space.cuda_stream());
6169
req.add_callback([space, rv, pckd_rv]() {
6270
RecvPacker::unpack_into(space, rv, pckd_rv.view_);
@@ -77,9 +85,9 @@ namespace Impl {
7785

7886
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp>
7987
struct Reduce<SendView, RecvView, RedOp, Kokkos::Cuda, NcclSpace> {
80-
static auto execute(Handle<Kokkos::Cuda, NcclSpace>& h, const SendView sv, RecvView rv, int root)
88+
static auto execute(Communicator<NcclSpace, Kokkos::Cuda>& h, const SendView sv, RecvView rv, int root)
8189
-> Request<NcclSpace> {
82-
return nccl::reduce(h.space(), sv, rv, reduction_op<NcclSpace, RedOp>(), root, h.rank(), h.comm());
90+
return nccl::reduce(h.exec(), sv, rv, reduction_op<NcclSpace, RedOp>(), root, h.rank(), h.comm());
8391
}
8492
};
8593

src/KokkosComm/nccl/send.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ auto send(const ExecSpace& space, const SendView& sv, int peer, ncclComm_t comm)
3131
using Packer = typename Impl::PackTraits<SendView>::packer_type;
3232
auto pckd_sv = Packer::pack(space, "pckd_sv", sv);
3333
KC_NCCL_CHECK(
34-
ncclSend(data_handle(pckd_sv.view_), pckd_sv.count_, pckd_sv.datatype_, peer, comm, space.cuda_stream()));
34+
ncclSend(data_handle(pckd_sv.view_), pckd_sv.count_, pckd_sv.datatype_, peer, comm, space.cuda_stream())
35+
);
3536
req.capture_stream_state(space.cuda_stream());
3637
req.extend_view_lifetime(pckd_sv.view_);
3738
}
@@ -46,9 +47,9 @@ namespace Impl {
4647

4748
template <KokkosView SendView>
4849
struct Send<SendView, Kokkos::Cuda, Experimental::NcclSpace> {
49-
static auto execute(Handle<Kokkos::Cuda, Experimental::NcclSpace>& h, SendView sv, int peer)
50+
static auto execute(Communicator<Experimental::NcclSpace, Kokkos::Cuda>& h, SendView sv, int peer)
5051
-> Request<Experimental::NcclSpace> {
51-
return Experimental::nccl::send(h.space(), sv, peer, h.comm());
52+
return Experimental::nccl::send(h.exec(), sv, peer, h.comm());
5253
}
5354
};
5455

0 commit comments

Comments
 (0)