Merged
20 changes: 17 additions & 3 deletions pylops_mpi/DistributedArray.py
@@ -14,7 +14,7 @@
nccl_message = deps.nccl_import("the DistributedArray module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
from cupy.cuda.nccl import NcclCommunicator
else:
NcclCommunicator = Any
@@ -500,7 +500,15 @@ def _allgather(self, send_buf, recv_buf=None):
"""Allgather operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
if hasattr(send_buf, "shape"):
send_shapes = self.base_comm.allgather(send_buf.shape)
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
# TODO: Should we ignore recv_buf completely in this case?
raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf is not None else padded_recv)
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
else:
# still works for a send_buf whose type is a tuple (as passed by _nccl_local_shapes)
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
else:
if recv_buf is None:
return self.base_comm.allgather(send_buf)
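For illustration, a minimal NumPy-only sketch of the pad → allgather → unroll round trip that this new branch performs. fake_allgather is a hypothetical stand-in for nccl_allgather, which requires every rank to contribute a buffer of the same size and returns one flat buffer ordered by rank:

import numpy as np

def fake_allgather(send_bufs):
    # Stand-in for nccl_allgather: concatenate equal-sized, flattened
    # per-rank buffers into one flat receive buffer
    return np.concatenate([b.ravel() for b in send_bufs])

# Two "ranks" holding unevenly partitioned blocks of a global array
blocks = [np.arange(6).reshape(3, 2), np.arange(4).reshape(2, 2)]
shapes = [b.shape for b in blocks]

# Pad each block up to the per-dimension maximum, as in
# _prepare_nccl_allgather_inputs
send_shape = tuple(map(max, zip(*shapes)))
padded = [np.pad(b, [(0, s - l) for s, l in zip(send_shape, b.shape)])
          for b in blocks]

recv = fake_allgather(padded)

# Slice per-rank chunks and trim the padding back off, as in
# _unroll_nccl_allgather_recv
chunk = int(np.prod(send_shape))
out = [recv[i * chunk:(i + 1) * chunk]
       .reshape(send_shape)[tuple(slice(0, e) for e in shapes[i])]
       for i in range(len(shapes))]
assert all(np.array_equal(o, b) for o, b in zip(out, blocks))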
@@ -511,7 +519,13 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
"""Allgather operation with subcommunicator
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
if hasattr(send_buf, "shape"):
send_shapes = self.base_comm.allgather(send_buf.shape)
(Review thread on this line)
Contributor: are you sure this is not self.sub_comm?
Collaborator (Author): Oh yes, you are right. Strangely enough, no test case catches that.
Contributor: Interesting, I would think here
    xmaskedloc = arr.asarray(masked=True)
the test should fail... the original code (in the current main branch) looks correct, but your previous commit should have called
    return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
in the .asarray call... anyway, now it is correct 😄
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf is not None else padded_recv)
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
else:
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
else:
if recv_buf is None:
return self.sub_comm.allgather(send_buf)
12 changes: 1 addition & 11 deletions pylops_mpi/signalprocessing/Fredholm1.py
@@ -165,15 +165,5 @@ def _rmatvec(self, x: NDArray) -> NDArray:
y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()

# gather results
recv = y._allgather(y1)
# TODO: the current _allgather will call non-buffered MPI-AllGather (sub-optimal for CuPy+MPI)
# which returns a list (not flatten) and does not require unrolling
if self.usematmul and isinstance(recv, ncp.ndarray) :
# unrolling
chunk_size = self.ny * self.nz
num_partition = (len(recv) + chunk_size - 1) // chunk_size
recv = ncp.vstack([recv[i * chunk_size: (i + 1) * chunk_size].reshape(self.nz, self.ny).T for i in range(num_partition)])
else:
recv = ncp.vstack(recv)
y[:] = recv.ravel()
y[:] = ncp.vstack(y._allgather(y1)).ravel()
return y
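For illustration, a minimal NumPy stand-in (hypothetical shapes) for why the manual unrolling above could be dropped: _allgather now returns a list of per-rank 2-D blocks, so a single vstack plus ravel rebuilds the flat result:

import numpy as np

# Hypothetical per-rank partial results, as _allgather would now return them
y1_chunks = [np.ones((2, 3)), 2 * np.ones((2, 3)), 3 * np.ones((1, 3))]

# Mirrors ncp.vstack(y._allgather(y1)).ravel()
y_flat = np.vstack(y1_chunks).ravel()
print(y_flat.shape)  # (15,)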
108 changes: 79 additions & 29 deletions pylops_mpi/utils/_nccl.py
Expand Up @@ -6,7 +6,9 @@
"nccl_bcast",
"nccl_asarray",
"nccl_send",
"nccl_recv"
"nccl_recv",
"_prepare_nccl_allgather_inputs",
"_unroll_nccl_allgather_recv"
]

from enum import IntEnum
@@ -251,61 +253,109 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None:
)


def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
"""Global view of the array
def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> tuple[cp.ndarray, cp.ndarray]:
""" Preparing the send_buf and recv_buf for the NCCL allgather (nccl_allgather)

Gather all local GPU arrays into a single global array via NCCL all-gather.
NCCL's allGather requires the sending buffer to have the same size for every device.
Therefore, the padding is required when the array is not evenly partitioned across
all the ranks. The padding is applied such that the sending buffer has the size of
each dimension corresponding to the max possible size of that dimension.

The receive buffer (recv_buf) will have size n_rank * send_buf.size

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for collective communication.
local_array : :obj:`cupy.ndarray`
The local array on the current GPU.
local_shapes : :obj:`list`
A list of shapes for each GPU local array (used to trim padding).
axis : :obj:`int`
The axis along which to concatenate the gathered arrays.
send_buf : :obj:`cupy.ndarray` or array-like
The data buffer from the local GPU to be sent for allgather.
send_buf_shapes: :obj:`list`
A list of shapes for each GPU send_buf (used to calculate padding size)

Returns
-------
final_array : :obj:`cupy.ndarray`
Global array gathered from all GPUs and concatenated along `axis`.
tuple[send_buf, recv_buf]: :obj:`tuple`
A tuple of (send_buf, recv_buf) with an appropriate size, shape and dtype for NCCL allgather

Notes
-----
NCCL's allGather requires the sending buffer to have the same size for every device.
Therefore, the padding is required when the array is not evenly partitioned across
all the ranks. The padding is applied such that the sending buffer has the size of
each dimension corresponding to the max possible size of that dimension.
"""
sizes_each_dim = list(zip(*local_shapes))

sizes_each_dim = list(zip(*send_buf_shapes))
send_shape = tuple(map(max, sizes_each_dim))
pad_size = [
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
]

send_buf = cp.pad(
local_array, pad_size, mode="constant", constant_values=0
send_buf, pad_size, mode="constant", constant_values=0
)

# NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred
ndev = len(local_shapes)
ndev = len(send_buf_shapes)
recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
nccl_allgather(nccl_comm, send_buf, recv_buf)

return (send_buf, recv_buf)
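A worked example of the sizing rules above, assuming three ranks with an unevenly split first dimension (pure Python, no GPU required):

send_buf_shapes = [(4, 3), (4, 3), (2, 3)]
send_shape = tuple(map(max, zip(*send_buf_shapes)))  # (4, 3): per-dimension maxima
ndev = len(send_buf_shapes)                          # one MPI process per GPU
recv_size = ndev * send_shape[0] * send_shape[1]     # 36 = ndev * padded send_buf.size
print(send_shape, recv_size)                         # (4, 3) 36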


def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
""" Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays

Each GPU may send an array with a different shape, so the return type has to be a list of arrays
instead of a single concatenated array.

Parameters
----------
recv_buf: :obj:`cupy.ndarray` or array-like
The data buffer returned from the nccl_allgather call
padded_send_buf_shape: :obj:`tuple` of :obj:`int`
The shape of send_buf after the padding used in nccl_allgather
send_buf_shapes: :obj:`list`
A list of original shapes for each GPU send_buf prior to padding

Returns
-------
chunks: :obj:`list`
A list of :obj:`cupy.ndarray` from each GPU with the padded elements removed
"""

ndev = len(send_buf_shapes)
# extract an individual array from each device
chunk_size = np.prod(send_shape)
chunk_size = np.prod(padded_send_buf_shape)
chunks = [
recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
]

# Remove padding from each array: padded values may appear somewhere in
# the middle of the flat array, so reshaping plus per-dimension slicing is required
for i in range(ndev):
slicing = tuple(slice(0, end) for end in local_shapes[i])
chunks[i] = chunks[i].reshape(send_shape)[slicing]
slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]

return chunks
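To see why reshape-then-slice is needed rather than simply truncating each flat chunk, consider a hypothetical block padded along its trailing dimension; the padded zeros land in the middle of the flat buffer:

import numpy as np

orig = np.array([[1, 2], [3, 4]])               # a rank's true (2, 2) block
padded = np.pad(orig, [(0, 0), (0, 1)])         # padded to (2, 3)
print(padded.ravel())                           # [1 2 0 3 4 0]: zeros mid-buffer
trimmed = padded.ravel().reshape(2, 3)[:2, :2]  # reshape, then slice per dimension
print(np.array_equal(trimmed, orig))            # True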


def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
"""Global view of the array

Gather all local GPU arrays into a single global array via NCCL all-gather.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for collective communication.
local_array : :obj:`cupy.ndarray`
The local array on the current GPU.
local_shapes : :obj:`list`
A list of shapes for each GPU local array (used to trim padding).
axis : :obj:`int`
The axis along which to concatenate the gathered arrays.

Returns
-------
final_array : :obj:`cupy.ndarray`
Global array gathered from all GPUs and concatenated along `axis`.
"""

(send_buf, recv_buf) = _prepare_nccl_allgather_inputs(local_array, local_shapes)
nccl_allgather(nccl_comm, send_buf, recv_buf)
chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)

# combine back to single global array
return cp.concatenate(chunks, axis=axis)
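And a small sketch of the final concatenation, assuming two unequal chunks gathered along axis=0:

import numpy as np

chunks = [np.zeros((3, 2)), np.ones((2, 2))]
print(np.concatenate(chunks, axis=0).shape)  # (5, 2)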
