72 changes: 59 additions & 13 deletions pylops_mpi/DistributedArray.py
@@ -14,7 +14,7 @@
nccl_message = deps.nccl_import("the DistributedArray module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
from cupy.cuda.nccl import NcclCommunicator
else:
NcclCommunicator = Any
@@ -503,6 +503,35 @@ def _allgather(self, send_buf, recv_buf=None):
self.base_comm.Allgather(send_buf, recv_buf)
return recv_buf

def _send(self, send_buf, dest, count=None, tag=None):
""" Send operation
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
if count is None:
# assume the whole array is sent
count = send_buf.size
nccl_send(self.base_comm_nccl, send_buf, dest, count)
else:
self.base_comm.Send(send_buf, dest, tag)

def _recv(self, recv_buf=None, source=0, count=None, tag=None):
""" Receive operation
"""
# NCCL requires recv_buf to be provided: its size cannot be inferred from
# the other arguments and thus cannot be allocated dynamically
if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None:
if count is None:
# assume the data fills the whole buffer
count = recv_buf.size
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
return recv_buf
else:
# MPI allows a receiver buffer to be optional
if recv_buf is None:
return self.base_comm.recv(source=source, tag=tag)
self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
return recv_buf
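
As a quick illustration of how these two helpers pair up, here is a minimal sketch (not part of the class): it assumes a 1-D DistributedArray `arr` partitioned along axis 0 and mirrors the recv-then-send chaining used in add_ghost_cells below.

def exchange_last_element(arr):
    # Sketch only: push each rank's trailing element to the next rank.
    ncp = get_module(arr.engine)
    recv_buf = None
    if arr.rank != 0:
        # ranks > 0 first receive one element from the previous rank
        recv_buf = ncp.zeros(1, dtype=arr.local_array.dtype)
        arr._recv(recv_buf, source=arr.rank - 1, tag=0)
    if arr.rank != arr.size - 1:
        # then forward the trailing element to the next rank
        arr._send(arr.local_array[-1:], dest=arr.rank + 1, tag=0)
    return recv_buf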

def __neg__(self):
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
@@ -747,38 +776,55 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,

"""
ghosted_array = self.local_array.copy()
ncp = get_module(getattr(self, "engine"))
if cells_front is not None:
total_cells_front = self._allgather(cells_front) + [0]
# cells_front is a small array of int, so explicitly use MPI
total_cells_front = self.base_comm.allgather(cells_front) + [0]
# Read the cells_front that needs to be sent to rank + 1 (i.e., cells_front for rank + 1)
cells_front = total_cells_front[self.rank + 1]
send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
recv_shapes = self.local_shapes
if self.rank != 0:
ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
axis=self.axis)
if self.rank != self.size - 1:
# From the receiver's perspective (rank), the recv buffer has the same shape as the sender's array (rank - 1)
# in every dimension except along axis=self.axis
recv_shape = list(recv_shapes[self.rank - 1])
recv_shape[self.axis] = total_cells_front[self.rank]
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
# Some communication can be skipped when len(recv_buf) == 0
# Additionally, NCCL hangs if the buffer size is 0, so this optimization is effectively mandatory
if len(recv_buf) != 0:
ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
# The sender-side skip matches the receiver-side skip described above
if self.rank != self.size - 1 and len(send_buf) != 0:
if cells_front > self.local_shape[self.axis]:
raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
f"should be > {cells_front}: dim({self.axis}) "
f"{self.local_shape[self.axis]} < {cells_front}; "
f"to achieve this use NUM_PROCESSES <= "
f"{max(1, self.global_shape[self.axis] // cells_front)}")
self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
dest=self.rank + 1, tag=1)
self._send(send_buf, dest=self.rank + 1, tag=1)
if cells_back is not None:
total_cells_back = self._allgather(cells_back) + [0]
total_cells_back = self.base_comm.allgather(cells_back) + [0]
# Read the cells_back that needs to be sent to rank - 1 (i.e., cells_back for rank - 1)
cells_back = total_cells_back[self.rank - 1]
if self.rank != 0:
send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
# Same reasoning as for sending cells_front applies
recv_shapes = self.local_shapes
if self.rank != 0 and len(send_buf) != 0:
if cells_back > self.local_shape[self.axis]:
raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
f"should be > {cells_back}: dim({self.axis}) "
f"{self.local_shape[self.axis]} < {cells_back}; "
f"to achieve this use NUM_PROCESSES <= "
f"{max(1, self.global_shape[self.axis] // cells_back)}")
self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
dest=self.rank - 1, tag=0)
self._send(send_buf, dest=self.rank - 1, tag=0)
if self.rank != self.size - 1:
ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
axis=self.axis)
recv_shape = list(recv_shapes[self.rank + 1])
recv_shape[self.axis] = total_cells_back[self.rank]
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
if len(recv_buf) != 0:
ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
axis=self.axis)
return ghosted_array
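
To make the data movement above concrete, here is a plain-NumPy, single-process sketch of what the ghost-cell exchange produces (no MPI/NCCL involved; `local_blocks` simply stands in for the per-rank local arrays):

import numpy as np

# Conceptual only: "rank" r owns local_blocks[r]; front ghost cells come from
# the previous block's tail, back ghost cells from the next block's head.
local_blocks = [np.arange(0, 4), np.arange(4, 8), np.arange(8, 12)]

def ghosted(rank, cells_front=0, cells_back=0):
    out = local_blocks[rank].copy()
    if cells_front and rank != 0:
        out = np.concatenate([local_blocks[rank - 1][-cells_front:], out])
    if cells_back and rank != len(local_blocks) - 1:
        out = np.append(out, local_blocks[rank + 1][:cells_back])
    return out

# ghosted(1, cells_front=1, cells_back=1) -> array([3, 4, 5, 6, 7, 8])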

def __repr__(self):
2 changes: 2 additions & 0 deletions pylops_mpi/LinearOperator.py
@@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
if self.Op:
y = DistributedArray(global_shape=self.shape[0],
base_comm=self.base_comm,
base_comm_nccl=x.base_comm_nccl,
partition=x.partition,
axis=x.axis,
engine=x.engine,
@@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
if self.Op:
y = DistributedArray(global_shape=self.shape[1],
base_comm=self.base_comm,
base_comm_nccl=x.base_comm_nccl,
partition=x.partition,
axis=x.axis,
engine=x.engine,
4 changes: 2 additions & 2 deletions pylops_mpi/basicoperators/BlockDiag.py
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
@reshaped(forward=True, stacking=True)
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
mask=self.mask, engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
@reshaped(forward=False, stacking=True)
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
mask=self.mask, engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
20 changes: 10 additions & 10 deletions pylops_mpi/basicoperators/FirstDerivative.py
@@ -129,19 +129,19 @@ def _register_multiplications(
def _matvec(self, x: DistributedArray) -> DistributedArray:
# If Partition.BROADCAST, then convert to Partition.SCATTER
if x.partition is Partition.BROADCAST:
x = DistributedArray.to_dist(x=x.local_array)
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
return self._hmatvec(x)

def _rmatvec(self, x: DistributedArray) -> DistributedArray:
# If Partition.BROADCAST, then convert to Partition.SCATTER
if x.partition is Partition.BROADCAST:
x = DistributedArray.to_dist(x=x.local_array)
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
return self._hrmatvec(x)

@reshaped
def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_back=1)
y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0
if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_front=1)
y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0
ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0

@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0
ghosted_x = x.add_ghost_cells(cells_back=4)
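
For intuition, a plain-NumPy, single-process sketch of the forward-difference stencil implemented above: with one back ghost cell, an n-point local block yields n forward differences, and the globally last rank zeroes its final entry since it has no right neighbour (values below are illustrative only).

import numpy as np

local_block = np.array([1.0, 3.0, 6.0, 10.0])   # this rank's samples
ghost_back = np.array([15.0])                   # first sample of the next rank
ghosted_x = np.concatenate([local_block, ghost_back])
y_forward = ghosted_x[1:] - ghosted_x[:-1]      # -> [2., 3., 4., 5.]
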
18 changes: 15 additions & 3 deletions pylops_mpi/basicoperators/VStack.py
@@ -6,6 +6,7 @@
from pylops import LinearOperator
from pylops.utils import DTypeLike
from pylops.utils.backend import get_module
from pylops.utils import deps as pylops_deps  # avoid namespace clashes with pylops_mpi.utils

from pylops_mpi import (
MPILinearOperator,
@@ -15,6 +16,13 @@
StackedDistributedArray
)
from pylops_mpi.utils.decorators import reshaped
from pylops_mpi.utils import deps

cupy_message = pylops_deps.cupy_import("the VStack module")
nccl_message = deps.nccl_import("the VStack module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import nccl_allreduce


class MPIVStack(MPILinearOperator):
@@ -121,7 +129,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
f"Got {x.partition} instead...")
y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
# The output y should use NCCL if the operand x uses it
y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
@@ -132,13 +141,16 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
@reshaped(forward=False, stacking=True)
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
y1 = ncp.sum(ncp.vstack(y1), axis=0)
y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
if deps.nccl_enabled and x.base_comm_nccl:
y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM)
else:
y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
return y
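
The branch above follows a small dispatch pattern that may be worth spelling out once; the helper below is a sketch only (not a function in the codebase) and assumes the imports at the top of this module.

# Sketch: sum-reduce a local partial result with NCCL when the operand carries
# a NCCL communicator, otherwise fall back to the mpi4py communicator.
def _sum_allreduce(base_comm, base_comm_nccl, local_partial):
    if deps.nccl_enabled and base_comm_nccl is not None:
        return nccl_allreduce(base_comm_nccl, local_partial, op=MPI.SUM)
    return base_comm.allreduce(local_partial, op=MPI.SUM)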


58 changes: 57 additions & 1 deletion pylops_mpi/utils/_nccl.py
@@ -4,7 +4,9 @@
"nccl_allgather",
Contributor

I think it may be good to add all nccl_* methods to the Utils section of https://github.com/PyLops/pylops-mpi/blob/main/docs/source/api/index.rst

Collaborator Author

Alright, I can do that. Maybe in the other PR?

Contributor

Sounds good. If it is very small like this we can go for the same PR; if it is something a bit more consistent like the changes you made previously, it is good practice to have a separate documentation-only PR 😄

"nccl_allreduce",
"nccl_bcast",
"nccl_asarray"
"nccl_asarray",
"nccl_send",
"nccl_recv"
]

from enum import IntEnum
@@ -286,3 +288,57 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
chunks[i] = chunks[i].reshape(send_shape)[slicing]
# combine back to single global array
return cp.concatenate(chunks, axis=axis)


def nccl_send(nccl_comm, send_buf, dest, count):
"""NCCL equivalent of MPI_Send. Sends a specified number of elements
from the buffer to a destination GPU device.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for point-to-point communication.
send_buf : :obj:`cupy.ndarray`
The array containing data to send.
dest : :obj:`int`
The rank of the destination GPU device.
count : :obj:`int`
Number of elements to send from `send_buf`.

Returns
-------
None
"""
nccl_comm.send(send_buf.data.ptr,
count,
cupy_to_nccl_dtype[str(send_buf.dtype)],
dest,
cp.cuda.Stream.null.ptr
)


def nccl_recv(nccl_comm, recv_buf, source, count=None):
"""NCCL equivalent of MPI_Recv. Receives data from a source GPU device
into the given buffer.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for point-to-point communication.
recv_buf : :obj:`cupy.ndarray`
The array to store the received data.
source : :obj:`int`
The rank of the source GPU device.
count : :obj:`int`, optional
Number of elements to receive.

Returns
-------
None
"""
nccl_comm.recv(recv_buf.data.ptr,
count,
cupy_to_nccl_dtype[str(recv_buf.dtype)],
source,
cp.cuda.Stream.null.ptr
)
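
A usage sketch for the two primitives above (assumes one GPU per process and an already-initialized NcclCommunicator `nccl_comm`; communicator setup is omitted and `shift_right` is purely illustrative, not part of this module):

import cupy as cp

def shift_right(nccl_comm, rank, size, local_chunk):
    # Pass each rank's chunk to rank + 1, alternating the send/recv order by
    # parity so the blocking point-to-point calls always find a matching peer.
    recv_buf = cp.empty_like(local_chunk)
    if rank % 2 == 0:
        if rank + 1 < size:
            nccl_send(nccl_comm, local_chunk, rank + 1, local_chunk.size)
        if rank > 0:
            nccl_recv(nccl_comm, recv_buf, rank - 1, recv_buf.size)
    else:
        if rank > 0:
            nccl_recv(nccl_comm, recv_buf, rank - 1, recv_buf.size)
        if rank + 1 < size:
            nccl_send(nccl_comm, local_chunk, rank + 1, local_chunk.size)
    return recv_buf if rank > 0 else None
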
1 change: 1 addition & 0 deletions pylops_mpi/utils/decorators.py
@@ -54,6 +54,7 @@ def wrapper(self, x: DistributedArray):
local_shapes = None
global_shape = getattr(self, "dims")
arr = DistributedArray(global_shape=global_shape,
base_comm_nccl=x.base_comm_nccl,
Contributor

Since we are changing this, I think it would be safe to also pass base_comm=x.base_comm... I think in the past this never led to any issue as we probably always used MPI.COMM_WORLD, but it's good not to assume this to be always the case 😄 (@rohanbabbar04, agree?)

Collaborator

Suggested change
base_comm_nccl=x.base_comm_nccl,
base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl,

local_shapes=local_shapes, axis=0,
engine=x.engine, dtype=x.dtype)
arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape)))