Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pylops_mpi/Distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from mpi4py import MPI
from pylops.utils import NDArray
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
from pylops_mpi.utils._mpi import (mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, mpi_sendrecv,
_prepare_allgather_inputs, _unroll_allgather_recv)
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, mpi_sendrecv
from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
from pylops_mpi.utils import deps

cupy_message = pylops_deps.cupy_import("the DistributedArray module")
Expand Down Expand Up @@ -32,7 +32,6 @@ class DistributedMixIn:
MPI installation is not available).

"""

def _allreduce(self,
base_comm: MPI.Comm,
base_comm_nccl: NcclCommunicatorType,
Expand Down
2 changes: 1 addition & 1 deletion pylops_mpi/signalprocessing/Fredholm1.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
engine=y.engine)).ravel()
return y

def _rmatvec(self, x: NDArray) -> NDArray:
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
Expand Down
40 changes: 28 additions & 12 deletions pylops_mpi/utils/_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,27 @@
"mpi_sendrecv"
]

from typing import Optional
from typing import List, Optional

from mpi4py import MPI
from pylops.utils import NDArray
from pylops.utils.backend import get_module
from pylops_mpi.utils import deps
from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv


def mpi_allgather(base_comm: MPI.Comm,
send_buf: NDArray,
recv_buf: Optional[NDArray] = None,
engine: str = "numpy",
) -> NDArray:
"""MPI_Allallgather/allallgather
) -> List[NDArray]:
"""MPI_Allgather/allgather

Dispatch allgather routine based on type of input and availability of
CUDA-Aware MPI
Dispatch the appropriate allgather routine based on buffer sizes and
CUDA-aware MPI availability.

If all ranks provide buffers of equal size, the standard `Allgather`
collective is used. Otherwise, `Allgatherv` is invoked to handle
variable-sized buffers.

Parameters
----------
Expand All @@ -40,16 +43,29 @@ def mpi_allgather(base_comm: MPI.Comm,

Returns
-------
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
A buffer containing the gathered data from all ranks.
recv_buf : :obj:`list`
A list of arrays containing the gathered data from all ranks.

"""
if deps.cuda_aware_mpi_enabled or engine == "numpy":
ncp = get_module(engine)
send_shapes = base_comm.allgather(send_buf.shape)
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine)
recv_buffer_to_use = recv_buf if recv_buf else padded_recv
_mpi_calls(base_comm, "Allgather", padded_send, recv_buffer_to_use, engine=engine)
return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)
recvcounts = base_comm.allgather(send_buf.size)
recv_buf = recv_buf if recv_buf else ncp.zeros(sum(recvcounts), dtype=send_buf.dtype)
if len(set(send_shapes)) == 1:
_mpi_calls(base_comm, "Allgather", ncp.ascontiguousarray(send_buf), recv_buf, engine=engine)
return [chunk.reshape(send_shapes[0]) for chunk in ncp.split(recv_buf, base_comm.size)]
else:
# displs represent the starting offsets in recv_buf where data from each rank will be placed
displs = [0]
for i in range(1, len(recvcounts)):
displs.append(displs[i - 1] + recvcounts[i - 1])
_mpi_calls(base_comm, "Allgatherv", ncp.ascontiguousarray(send_buf),
[recv_buf, recvcounts, displs, MPI._typedict[send_buf.dtype.char]], engine=engine)
return [
recv_buf[displs[i]:displs[i] + recvcounts[i]].reshape(send_shapes[i])
for i in range(base_comm.size)
]
else:
# CuPy with non-CUDA-aware MPI
if recv_buf is None:
Expand Down
Loading