81 changes: 66 additions & 15 deletions pylops_mpi/DistributedArray.py
@@ -14,7 +14,7 @@
nccl_message = deps.nccl_import("the DistributedArray module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
from cupy.cuda.nccl import NcclCommunicator
else:
NcclCommunicator = Any
@@ -495,14 +495,46 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
def _allgather(self, send_buf, recv_buf=None):
"""Allgather operation
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
if deps.nccl_enabled and self.base_comm_nccl:
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
else:
if recv_buf is None:
return self.base_comm.allgather(send_buf)
self.base_comm.Allgather(send_buf, recv_buf)
return recv_buf
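
For context, a minimal standalone sketch (not taken from this diff) of the two MPI paths in _allgather above: without a receive buffer, mpi4py's lowercase allgather returns a Python list of per-rank objects, whereas a pre-allocated buffer is filled in place by the uppercase Allgather.

# Minimal sketch, assuming plain mpi4py usage outside of this PR
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
send_buf = np.full(2, comm.Get_rank(), dtype=np.float64)

# Object path: returns a Python list with one entry per rank
gathered_list = comm.allgather(send_buf)

# Buffer path: fills a pre-allocated receive buffer in place
recv_buf = np.zeros(2 * comm.Get_size(), dtype=np.float64)
comm.Allgather(send_buf, recv_buf)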

def _send(self, send_buf, dest, count=None, tag=None):
""" Send operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
if count is None:
# assume the whole array is being sent
count = send_buf.size
nccl_send(self.base_comm_nccl, send_buf, dest, count)
else:
self.base_comm.Send(send_buf, dest, tag)

def _recv(self, recv_buf=None, source=0, count=None, tag=None):
""" Receive operation
"""
# NCCL must be called with recv_buf. Size cannot be inferred from
# other arguments and thus cannot be dynamically allocated
if deps.nccl_enabled and self.base_comm_nccl:
if recv_buf is not None:
if count is None:
# assume the incoming data fills the whole buffer
count = recv_buf.size
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
return recv_buf
else:
raise ValueError("Using recv with NCCL requires a pre-allocated receiver buffer (recv_buf)")
else:
# MPI allows a receiver buffer to be optional
if recv_buf is None:
return self.base_comm.recv(source=source, tag=tag)
self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
return recv_buf
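
Below is a hypothetical usage sketch of the new point-to-point helpers (names beyond the helpers themselves are illustrative, at least two MPI ranks are assumed, and the default numpy/MPI path is exercised; the NCCL path would additionally require CuPy buffers and an NCCL communicator). The key constraint is that with NCCL the receiver must pre-allocate a buffer of the sender's size, while the MPI object path can allocate on the fly.

# Hypothetical sketch, not part of this PR: exchanging a local chunk between
# rank 0 and rank 1 via the new _send/_recv helpers (requires >= 2 ranks)
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD
x = DistributedArray(global_shape=4 * comm.Get_size(), dtype=np.float64)
x[:] = comm.Get_rank()

if comm.Get_rank() == 0:
    x._send(x.local_array, dest=1, tag=0)
elif comm.Get_rank() == 1:
    # NCCL cannot infer the message size, so the buffer is pre-allocated
    # to the sender's local size (4 elements here)
    recv_buf = np.zeros(4, dtype=x.dtype)
    chunk = x._recv(recv_buf, source=0, tag=0)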

def __neg__(self):
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
@@ -540,6 +572,7 @@ def add(self, dist_array):
self._check_mask(dist_array)
SumArray = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
dtype=self.dtype,
partition=self.partition,
local_shapes=self.local_shapes,
@@ -566,6 +599,7 @@ def multiply(self, dist_array):

ProductArray = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
dtype=self.dtype,
partition=self.partition,
local_shapes=self.local_shapes,
@@ -716,6 +750,7 @@ def ravel(self, order: Optional[str] = "C"):
"""
local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
arr = DistributedArray(global_shape=np.prod(self.global_shape),
base_comm_nccl=self.base_comm_nccl,
local_shapes=local_shapes,
mask=self.mask,
partition=self.partition,
@@ -744,41 +779,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
-------
ghosted_array : :obj:`numpy.ndarray`
Ghosted Array

"""
ghosted_array = self.local_array.copy()
ncp = get_module(self.engine)
if cells_front is not None:
total_cells_front = self._allgather(cells_front) + [0]
# cells_front is a small piece of integer metadata, so explicitly use MPI
total_cells_front = self.base_comm.allgather(cells_front) + [0]
# Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
cells_front = total_cells_front[self.rank + 1]
send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
recv_shapes = self.local_shapes
if self.rank != 0:
ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
axis=self.axis)
if self.rank != self.size - 1:
# From the receiver's perspective (rank), the recv buffer has the same shape as the sender's array (rank - 1)
# in every dimension except along axis=self.axis
recv_shape = list(recv_shapes[self.rank - 1])
recv_shape[self.axis] = total_cells_front[self.rank]
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
# The transfer of ghost cells can be skipped if len(recv_buf) == 0
# Additionally, NCCL hangs on zero-sized buffers, so this check is required rather than a mere optimization
if len(recv_buf) != 0:
ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
# The skip on the sender side mirrors the check described for the receiver
if self.rank != self.size - 1 and len(send_buf) != 0:
if cells_front > self.local_shape[self.axis]:
raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
f"should be > {cells_front}: dim({self.axis}) "
f"{self.local_shape[self.axis]} < {cells_front}; "
f"to achieve this use NUM_PROCESSES <= "
f"{max(1, self.global_shape[self.axis] // cells_front)}")
self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
dest=self.rank + 1, tag=1)
self._send(send_buf, dest=self.rank + 1, tag=1)
if cells_back is not None:
total_cells_back = self._allgather(cells_back) + [0]
total_cells_back = self.base_comm.allgather(cells_back) + [0]
# Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
cells_back = total_cells_back[self.rank - 1]
if self.rank != 0:
send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
# The same reasoning as for sending cells_front applies
recv_shapes = self.local_shapes
if self.rank != 0 and len(send_buf) != 0:
if cells_back > self.local_shape[self.axis]:
raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
f"should be > {cells_back}: dim({self.axis}) "
f"{self.local_shape[self.axis]} < {cells_back}; "
f"to achieve this use NUM_PROCESSES <= "
f"{max(1, self.global_shape[self.axis] // cells_back)}")
self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
dest=self.rank - 1, tag=0)
self._send(send_buf, dest=self.rank - 1, tag=0)
if self.rank != self.size - 1:
ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
axis=self.axis)
recv_shape = list(recv_shapes[self.rank + 1])
recv_shape[self.axis] = total_cells_back[self.rank]
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
if len(recv_buf) != 0:
ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
axis=self.axis)
return ghosted_array
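
To make the buffer bookkeeping above concrete, here is an illustrative sketch (with assumed shapes, not taken from the PR) of how the receiver derives the ghost-cell buffer for the cells_front exchange: the sender's local shape is reused in every dimension except the ghosted axis, and the trailing 0 sentinel keeps the edge ranks from requesting a transfer.

# Illustrative sketch with assumed values, not part of this PR
import numpy as np

local_shapes = [(3, 5), (4, 5), (2, 5)]    # assumed per-rank local shapes, axis=0
axis, rank, size = 0, 1, 3
cells_front = 2                            # ghost cells requested on each rank

# allgather of an int returns a list; the trailing 0 is a sentinel so that the
# last rank (and rank 0 in the cells_back case) reads "nothing to transfer"
total_cells_front = [cells_front] * size + [0]

recv_shape = list(local_shapes[rank - 1])  # sender's shape (rank - 1) ...
recv_shape[axis] = total_cells_front[rank] # ... except along the ghosted axis
recv_buf = np.zeros(recv_shape, dtype=np.float64)
assert recv_buf.shape == (2, 5)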

def __repr__(self):
2 changes: 2 additions & 0 deletions pylops_mpi/LinearOperator.py
@@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
if self.Op:
y = DistributedArray(global_shape=self.shape[0],
base_comm=self.base_comm,
base_comm_nccl=x.base_comm_nccl,
partition=x.partition,
axis=x.axis,
engine=x.engine,
@@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
if self.Op:
y = DistributedArray(global_shape=self.shape[1],
base_comm=self.base_comm,
base_comm_nccl=x.base_comm_nccl,
partition=x.partition,
axis=x.axis,
engine=x.engine,
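The remaining hunks in this file, and those in BlockDiag and FirstDerivative below, all apply the same pattern, sketched here for clarity (the helper name is hypothetical): any DistributedArray built from an input x now inherits x.base_comm_nccl, so outputs of GPU-backed operators stay on the same NCCL communicator as their inputs.

# Illustrative sketch of the common pattern, not part of this PR
from pylops_mpi import DistributedArray

def _allocate_output_like(x, global_shape, dtype):
    # mirror the input's distribution and, new in this PR, its NCCL communicator
    return DistributedArray(global_shape=global_shape,
                            base_comm=x.base_comm,
                            base_comm_nccl=x.base_comm_nccl,  # newly propagated
                            partition=x.partition,
                            axis=x.axis,
                            engine=x.engine,
                            dtype=dtype)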
4 changes: 2 additions & 2 deletions pylops_mpi/basicoperators/BlockDiag.py
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
@reshaped(forward=True, stacking=True)
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
mask=self.mask, engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
@reshaped(forward=False, stacking=True)
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
mask=self.mask, engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
20 changes: 10 additions & 10 deletions pylops_mpi/basicoperators/FirstDerivative.py
@@ -129,19 +129,19 @@ def _register_multiplications(
def _matvec(self, x: DistributedArray) -> DistributedArray:
# If Partition.BROADCAST, then convert to Partition.SCATTER
if x.partition is Partition.BROADCAST:
x = DistributedArray.to_dist(x=x.local_array)
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
return self._hmatvec(x)

def _rmatvec(self, x: DistributedArray) -> DistributedArray:
# If Partition.BROADCAST, then convert to Partition.SCATTER
if x.partition is Partition.BROADCAST:
x = DistributedArray.to_dist(x=x.local_array)
x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
return self._hrmatvec(x)

@reshaped
def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_back=1)
y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0
if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_front=1)
y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0
ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0

@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
y_centered = (
@@ -276,7 +276,7 @@ def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
@reshaped
def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes,
y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
axis=x.axis, engine=x.engine, dtype=self.dtype)
y[:] = 0
ghosted_x = x.add_ghost_cells(cells_back=4)