Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d740238
refactor(core): rename `Handle` to `Communicator`
dssgabriel Mar 2, 2026
3e637e1
chore(core): propagate file name change
dssgabriel Mar 2, 2026
498c3b7
feat(core): add `Rank` strong type
dssgabriel Mar 2, 2026
bcf9857
docs(core): start `Communicator` doc update
dssgabriel Mar 2, 2026
11eab36
style: don't bin pack arguments w/ clang-format
dssgabriel Mar 2, 2026
1b512ce
wip: communicator impl
dssgabriel Mar 2, 2026
e482389
revert(core): remove `Rank` strong type and `Color`/`Key` alias
dssgabriel Mar 5, 2026
db0153f
refactor: rework communicator impls
dssgabriel Mar 5, 2026
c2d87fe
docs(core): document communicators
dssgabriel Mar 5, 2026
c7702ae
refactor(mpi): propagate Communicator changes to MPI backend
dssgabriel Mar 6, 2026
40d9bd8
refactor(nccl): propagate Communicator changes to NCCL backend
dssgabriel Mar 6, 2026
f98e15a
tests: refactor all tests to pass with the new Communicators
dssgabriel Mar 6, 2026
fc8687f
tests(core): enable missing `broadcast` and `all-gather` tests
dssgabriel Mar 6, 2026
5f6128e
fix(perf_tests): forward args by reference in `do_iteration`
dssgabriel Mar 7, 2026
30badf1
fix(nccl): can't use member variables in factory `duplicate`
dssgabriel Mar 7, 2026
b01d939
refactor(nccl): avoid recomputing the rank in member `duplicate`
dssgabriel Mar 7, 2026
c25d889
fix(nccl): missed `using` decl with wrong type name
dssgabriel Mar 7, 2026
e1832ae
tests(nccl): fix missed P2P test
dssgabriel Mar 7, 2026
d6461fb
tests(nccl): fix missed all-gather test
dssgabriel Mar 7, 2026
ae80875
tests(core): add a set of unit tests for Communicators
dssgabriel Mar 9, 2026
a621a5b
docs(core): add missing accessors for Communicator specializatn
dssgabriel Mar 9, 2026
266c409
docs: propagate `Communicator` changes everywhere
dssgabriel Mar 9, 2026
b49867d
docs(core): fix indent to prevent Sphinx from seeing duplicates
dssgabriel Mar 9, 2026
63bfc20
refactor: default t-params for Communicators
dssgabriel Mar 17, 2026
f76ed39
fix(core): define `DefaultCommunicationSpace` once and only once
dssgabriel Mar 17, 2026
b959aa0
refactor(mpi): use more descriptive t-param names
dssgabriel Mar 17, 2026
e594754
docs: use defaulted t-params for communicators in code examples
dssgabriel Mar 17, 2026
72038ae
refactor(Comms): append `_from_raw` to static member functions
dssgabriel Mar 20, 2026
a24e308
tests(core): add missing Communicator tests
dssgabriel Mar 20, 2026
9447ef8
docs(api): fix NCCL Communicator template specialization parameter
dssgabriel Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ SortIncludes: false
AlignConsecutiveAssignments: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: true
AlignAfterOpenBracket: BlockIndent
BinPackParameters: false
ColumnLimit: 120
16 changes: 11 additions & 5 deletions docs/api/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,51 @@ Kokkos Comm-specific concepts

.. cpp:type:: T::communication_space

.. cpp:type:: T::handle_type
.. cpp:type:: T::communicator_type

.. cpp:type:: T::request_type

.. cpp:type:: T::datatype_type

.. cpp:type:: T::reduction_op_type

.. cpp:type:: T::size_type

.. cpp:type:: T::rank_type

Types implementing the ``CommunicationSpace`` concept
-----------------------------------------------------
"""""""""""""""""""""""""""""""""""""""""""""""""""""

.. cpp:class:: MpiSpace

.. cpp:type:: communication_space = MpiSpace

.. cpp:type:: handle_type = MPI_Comm
.. cpp:type:: communicator_type = MPI_Comm

.. cpp:type:: request_type = MPI_Request

.. cpp:type:: datatype_type = MPI_Datatype

.. cpp:type:: reduction_op_type = MPI_Op

.. cpp:type:: size_type = int

.. cpp:type:: rank_type = int

.. cpp:class:: Experimental::NcclSpace

.. cpp:type:: communication_space = NcclSpace

.. cpp:type:: handle_type = ncclComm_t
.. cpp:type:: communicator_type = ncclComm_t

.. cpp:type:: request_type = cudaEvent_t

.. cpp:type:: datatype_type = ncclDataType_t

.. cpp:type:: reduction_op_type = ncclRedOp_t

.. cpp:type:: size_type = int

.. cpp:type:: rank_type = int


Expand All @@ -62,7 +68,7 @@ Types implementing the ``CommunicationSpace`` concept
Specifies that a type ``T`` is a Kokkos Comm reduction operator.

Types implementing the ``ReductionOperator`` concept
----------------------------------------------------
""""""""""""""""""""""""""""""""""""""""""""""""""""

.. cpp:class:: BAnd

Expand Down
315 changes: 245 additions & 70 deletions docs/api/core.rst

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions docs/api/core_recv.cpp
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
#include "KokkosComm/KokkosComm.hpp"
#include <Kokkos_Core.hpp>
#include <KokkosComm/KokkosComm.hpp>

// Define the execution space and transport
using ExecSpace = Kokkos::DefaultExecutionSpace;
using CommSpace = DefaultCommunicationSpace;
// Define the communication and execution spaces
using Co = KokkosComm::DefaultCommunicationSpace;
using Ex = Kokkos::DefaultExecutionSpace;

// Source rank
int src = 1;
int src_rank = 1;

// Create a handle
KokkosComm::Handle<> handle; // Same as Handle<Execspace, CommSpace>
// Create a communicator
auto comm = KokkosComm::Communicator<Co, Ex>::duplicate(raw_comm_handle, exec_space);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For later, we should deduce the correct template parameters from the arguments. Or should we provide a free function that duplicates a Communicator instead?

Copy link
Copy Markdown
Collaborator Author

@dssgabriel dssgabriel Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, we cannot deduce without taking an MpiSpace tag as argument, since we cannot define CTAD-like rules for static member functions.
An alternative could be the free function you mention, but I don't know about that. The name would need to reflect that it operates on a communicator, which isn't as clean as having a static member function.

One easy improvement we can make is to default the template parameters of the Communicator:

template <
  CommunicationSpace Comm = DefaultCommunicationSpace,
  KokkosExecutionSpace Exec = Kokkos::DefaultExecutionSpace>
struct Communicator;

which allows for simpler "default" code:

auto comm = KokkosComm::Communicator<>::duplicate(raw_comm_handle, exec);


// Allocate a view to receive the data
Kokkos::View<double*> data("recv_view", 100);

// Initiate a non-blocking receive with a communicator
auto req1 = recv(handle, data, src);
auto req1 = KokkosComm::recv(comm, data, src_rank);

// Initiate a non-blocking receive with a default handle
auto req2 = recv(data, src);
// Simulate a blocking receive by waiting immediately
KokkosComm::recv(comm, data, src_rank).wait();

// Wait for the requests to complete (assuming a wait function exists)
// Wait for the request to complete
KokkosComm::wait(req1);
KokkosComm::wait(req2);
31 changes: 17 additions & 14 deletions docs/api/core_send.cpp
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
#include "KokkosComm/KokkosComm.hpp"
#include <Kokkos_Core.hpp>
#include <KokkosComm/KokkosComm.hpp>

// Define the execution space and transport
using ExecSpace = Kokkos::DefaultExecutionSpace;
using CommSpace = DefaultCommunicationSpace;
// Define the communication and execution spaces
using Co = KokkosComm::DefaultCommunicationSpace;
using Ex = Kokkos::DefaultExecutionSpace;

// Create an execution space instance
auto exec = Ex();
// Create a communicator
auto comm = KokkosComm::Communicator<Co, Ex>::duplicate(raw_comm_handle, exec);

// Create a Kokkos view
Kokkos::View<double*> data("data", 100);

// Fill the view with some data
Kokkos::parallel_for("fill_data", Kokkos::RangePolicy<ExecSpace>(0, 100), KOKKOS_LAMBDA(int i) {
Kokkos::parallel_for("fill_data", Kokkos::RangePolicy(exec, 0, 100), KOKKOS_LAMBDA(int i) {
data(i) = static_cast<double>(i);
});
exec.fence();

// Destination rank
int dest = 1;

// Create a handle
KokkosComm::Handle<> handle; // Same as Handle<Execspace, CommSpace>
int dst_rank = 1;

// Initiate a non-blocking send with a communicator
auto req1 = send(handle, data, dest);
auto req1 = KokkosComm::send(comm, data, dst_rank);

// Initiate a non-blocking send with a default handle
auto req2 = send(data, dest);
// Simulate a blocking send by waiting immediately
KokkosComm::send(comm, data, dst_rank).wait();

// Wait for the requests to complete (assuming a wait function exists)
// Wait for the request to complete
KokkosComm::wait(req1);
KokkosComm::wait(req2);
2 changes: 1 addition & 1 deletion docs/api/mpi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Point-to-point


.. cpp:function:: template <CommMode SendMode, KokkosExecutionSpace ExecSpace, KokkosView SendView> \
auto isend(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag) -> Request<MpiSpace>
auto isend(Communicator<MpiSpace, ExecSpace> &h, const SendView &sv, int dest, int tag) -> Request<MpiSpace>

Initiates a non-blocking send operation.

Expand Down
38 changes: 8 additions & 30 deletions docs/dev/impl_comm_space.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ For example, for the MPI communication space, we define the following:
To let core API functions know that your communication space is something KokkosComm can use to dispatch messages, you also need to declare the ``Impl::is_communication_space`` specialization using the ``CommunicationSpace`` concept.


Partial specialization of ``Handle``
=====================================
Partial specialization of ``Communicator``
==========================================

.. attention:: Section in construction...

Expand All @@ -49,13 +49,13 @@ For example, for the MPI communication space handle, we define the following:
namespace KokkosComm {

template <KokkosExecutionSpace ExecSpace>
class Handle<ExecSpace, Mpi> { /* ... */ };
class Communicator<MpiSpace, ExecSpace> { /* ... */ };

} // end KokkosComm


Partial specialization of ``Req``
===================================
Partial specialization of ``Request``
=====================================

.. attention:: Section in construction...

Expand Down Expand Up @@ -92,7 +92,7 @@ The core API functions are actually implemented by partial specializations of st


In the above, ``CommSpace`` is a type that represents the communication space implementation.
For example, for the MPI communication space, we create a partial specialization of that struct template (notice fewer template parameters and the use of the ``Mpi`` "tag" struct):
For example, for the MPI communication space, we create a partial specialization of that struct template (notice fewer template parameters and the use of the ``MpiSpace`` "tag" struct):

.. code-block:: cpp

Expand Down Expand Up @@ -126,7 +126,7 @@ An asynchronous/non-blocking message send:

template <KokkosView SendView, KokkosExecutionSpace ExecSpace>
struct Send<SendView, ExecSpace, MyCommSpace> {
static auto execute(Handle<ExecSpace, MyCommSpace> &h, const SendView &sv, int dest) -> Request<MyCommSpace> {
static auto execute(Communicator<MyCommSpace, ExecSpace> &h, const SendView &sv, int dest) -> Request<MyCommSpace> {
// actual implementation of `send` with your communication backend
}
};
Expand All @@ -148,31 +148,9 @@ An asynchronous/non-blocking message receive.

template <KokkosView RecvView, KokkosExecutionSpace ExecSpace>
struct Recv<RecvView, ExecSpace, MyCommSpace> {
static auto execute(Handle<ExecSpace, MyCommSpace> &h, const RecvView &sv, int src) -> Request<MyCommSpace> {
static auto execute(Communicator<MyCommSpace, ExecSpace> &h, const RecvView &sv, int src) -> Request<MyCommSpace> {
// actual implementation of `recv` with your communication backend
}
};

} // end KokkosComm::Impl


``Barrier`` concept
^^^^^^^^^^^^^^^^^^^

A global barrier.

.. code-block:: cpp

#include "KokkosComm/concepts.hpp"
#include "my_comm_space.hpp"

namespace KokkosComm::Impl {

template <KokkosExecutionSpace ExecSpace>
struct Recv<ExecSpace, MyCommSpace> {
static auto execute(Handle<ExecSpace, MyCommSpace> &&h) -> Request<MyCommSpace> {
// actual implementation of `barrier` with your communication backend
}
};

} // end KokkosComm::Impl
14 changes: 9 additions & 5 deletions perf_tests/mpi/test_2d_halo.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#include "KokkosComm/fwd.hpp"
#include "KokkosComm/mpi/mpi_space.hpp"
#include "test_utils.hpp"

#include <iostream>
Expand All @@ -10,9 +12,10 @@
void noop(benchmark::State, MPI_Comm) {}

template <typename Space, typename View>
void send_recv(benchmark::State&, MPI_Comm comm, const Space& space, int nx, int ny, int rx, int ry, int rs,
const View& v) {
KokkosComm::Handle<> h{space, comm};
void send_recv(
benchmark::State&, MPI_Comm comm, const Space& space, int nx, int ny, int rx, int ry, int rs, const View& v
) {
auto h = KokkosComm::Communicator<KokkosComm::MpiSpace, Space>::from_raw(comm, space).value();

// 2D index of nbrs in minus and plus direction (periodic)
const int xm1 = (rx + rs - 1) % rs;
Expand Down Expand Up @@ -73,8 +76,9 @@ void benchmark_2dhalo(benchmark::State& state) {
// grid of elements, each with 3 properties, and a radius-1 halo
grid_type grid("", nx + 2, ny + 2, nprops);
while (state.KeepRunning()) {
do_iteration(state, MPI_COMM_WORLD, send_recv<Kokkos::DefaultExecutionSpace, grid_type>, space, nx, ny, rx, ry,
rs, grid);
do_iteration(
state, MPI_COMM_WORLD, send_recv<Kokkos::DefaultExecutionSpace, grid_type>, space, nx, ny, rx, ry, rs, grid
);
}
} else {
while (state.KeepRunning()) {
Expand Down
41 changes: 28 additions & 13 deletions perf_tests/mpi/test_osu_latency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
// Copyright (c) 2002-2024 the Network-Based Computing Laboratory
// (NBCL), The Ohio State University.

#include "KokkosComm/mpi/mpi_space.hpp"
#include "test_utils.hpp"

#include <KokkosComm/KokkosComm.hpp>

template <typename Space, typename View>
void osu_latency_Kokkos_Comm_sendrecv(benchmark::State &, MPI_Comm, KokkosComm::Handle<> &h, const View &v) {
void osu_latency_Kokkos_Comm_sendrecv(
benchmark::State &, MPI_Comm, KokkosComm::Communicator<KokkosComm::MpiSpace, Space> &h, const View &v
) {
if (h.rank() == 0) {
KokkosComm::wait(KokkosComm::send(h, v, 1));
} else if (h.rank() == 1) {
Expand All @@ -19,7 +22,10 @@ void osu_latency_Kokkos_Comm_sendrecv(benchmark::State &, MPI_Comm, KokkosComm::
}

void benchmark_osu_latency_KokkosComm_sendrecv(benchmark::State &state) {
KokkosComm::Handle<> h;
auto h = KokkosComm::Communicator<KokkosComm::MpiSpace, Kokkos::DefaultExecutionSpace>::from_raw(
MPI_COMM_WORLD, Kokkos::DefaultExecutionSpace()
)
.value();
if (h.size() != 2) {
state.SkipWithError("benchmark_osu_latency_KokkosComm needs exactly 2 ranks");
}
Expand All @@ -28,14 +34,15 @@ void benchmark_osu_latency_KokkosComm_sendrecv(benchmark::State &state) {
view_type a("A", state.range(0));

while (state.KeepRunning()) {
do_iteration(state, h.mpi_comm(), osu_latency_Kokkos_Comm_sendrecv<Kokkos::DefaultExecutionSpace, view_type>, h, a);
do_iteration(state, h.comm(), osu_latency_Kokkos_Comm_sendrecv<Kokkos::DefaultExecutionSpace, view_type>, h, a);
}
state.counters["bytes"] = a.size() * 2;
}

template <typename Space, typename View>
void osu_latency_Kokkos_Comm_mpi_sendrecv(benchmark::State &, MPI_Comm comm, const Space &space, int rank,
const View &v) {
void osu_latency_Kokkos_Comm_mpi_sendrecv(
benchmark::State &, MPI_Comm comm, const Space &space, int rank, const View &v
) {
if (rank == 0) {
KokkosComm::mpi::send(space, v, 1, 0, comm);
} else if (rank == 1) {
Expand All @@ -56,8 +63,10 @@ void benchmark_osu_latency_Kokkos_Comm_mpi_sendrecv(benchmark::State &state) {
view_type a("A", state.range(0));

while (state.KeepRunning()) {
do_iteration(state, MPI_COMM_WORLD, osu_latency_Kokkos_Comm_mpi_sendrecv<Kokkos::DefaultExecutionSpace, view_type>,
space, rank, a);
do_iteration(
state, MPI_COMM_WORLD, osu_latency_Kokkos_Comm_mpi_sendrecv<Kokkos::DefaultExecutionSpace, view_type>, space,
rank, a
);
}
state.counters["bytes"] = a.size() * 2;
}
Expand All @@ -66,12 +75,16 @@ template <typename View>
void osu_latency_MPI_isendirecv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
MPI_Request sendreq, recvreq;
if (rank == 0) {
MPI_Irecv(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
&recvreq);
MPI_Irecv(
v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
&recvreq
);
MPI_Wait(&recvreq, MPI_STATUS_IGNORE);
} else if (rank == 1) {
MPI_Isend(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm,
&sendreq);
MPI_Isend(
v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm,
&sendreq
);
MPI_Wait(&sendreq, MPI_STATUS_IGNORE);
}
}
Expand All @@ -96,8 +109,10 @@ void benchmark_osu_latency_MPI_isendirecv(benchmark::State &state) {
template <typename View>
void osu_latency_MPI_sendrecv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
if (rank == 0) {
MPI_Recv(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
MPI_STATUS_IGNORE);
MPI_Recv(
v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 1, 0, comm,
MPI_STATUS_IGNORE
);
} else if (rank == 1) {
MPI_Send(v.data(), v.size(), KokkosComm::datatype<KokkosComm::MpiSpace, typename View::value_type>(), 0, 0, comm);
}
Expand Down
4 changes: 2 additions & 2 deletions perf_tests/mpi/test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

// F is a function that takes (state, MPI_Comm, args...)
template <typename F, typename... Args>
void do_iteration(benchmark::State &state, MPI_Comm comm, F &&func, Args... args) {
void do_iteration(benchmark::State& state, MPI_Comm comm, F&& func, Args&&... args) {
using Clock = std::chrono::steady_clock;
using Duration = std::chrono::duration<double>;

auto start = Clock::now();
func(state, comm, args...);
func(state, comm, std::forward<Args>(args)...);
Duration elapsed = Clock::now() - start;

double max_elapsed_second;
Expand Down
Loading
Loading