diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 50e54d3b..6e67e1fc 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -14,7 +14,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split + from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -495,7 +495,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): def _allgather(self, send_buf, recv_buf=None): """Allgather operation """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if deps.nccl_enabled and self.base_comm_nccl: return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) else: if recv_buf is None: @@ -503,6 +503,38 @@ def _allgather(self, send_buf, recv_buf=None): self.base_comm.Allgather(send_buf, recv_buf) return recv_buf + def _send(self, send_buf, dest, count=None, tag=None): + """ Send operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if count is None: + # assume the whole array is sent + count = send_buf.size + nccl_send(self.base_comm_nccl, send_buf, dest, count) + else: + self.base_comm.Send(send_buf, dest, tag) + + def _recv(self, recv_buf=None, source=0, count=None, tag=None): + """ Receive operation + """ + # NCCL must be called with recv_buf. Size cannot be inferred from + # other arguments and thus cannot be dynamically allocated + if deps.nccl_enabled and self.base_comm_nccl: + if recv_buf is not None: + if count is None: + # assume the data fills the whole buffer + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf + else: + raise ValueError("Using recv with NCCL requires a receiver buffer (recv_buf)") + else: + # MPI allows a receiver buffer to be optional + if recv_buf is None: + return self.base_comm.recv(source=source, tag=tag) + self.base_comm.Recv(buf=recv_buf, source=source, tag=tag) + return recv_buf + def __neg__(self): arr = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, @@ -540,6 +572,7 @@ def add(self, dist_array): self._check_mask(dist_array) SumArray = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, + base_comm_nccl=self.base_comm_nccl, dtype=self.dtype, partition=self.partition, local_shapes=self.local_shapes, @@ -566,6 +599,7 @@ def multiply(self, dist_array): ProductArray = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, + base_comm_nccl=self.base_comm_nccl, dtype=self.dtype, partition=self.partition, local_shapes=self.local_shapes, @@ -716,6 +750,8 @@ def ravel(self, order: Optional[str] = "C"): """ local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes] arr = DistributedArray(global_shape=np.prod(self.global_shape), + base_comm=self.base_comm, + base_comm_nccl=self.base_comm_nccl, local_shapes=local_shapes, mask=self.mask, partition=self.partition, @@ -744,41 +780,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, ------- ghosted_array : :obj:`numpy.ndarray` Ghosted Array - """ ghosted_array = self.local_array.copy() + ncp = get_module(self.engine) if cells_front is not None: - total_cells_front = self._allgather(cells_front) + [0] + # cells_front is a 
small array of int. Explicitly use MPI + total_cells_front = self.base_comm.allgather(cells_front) + [0] # Read cells_front which needs to be sent to rank + 1(cells_front for rank + 1) cells_front = total_cells_front[self.rank + 1] + send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis) + recv_shapes = self.local_shapes if self.rank != 0: - ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array], - axis=self.axis) - if self.rank != self.size - 1: + # from receiver's perspective (rank), the recv buffer have the same shape as the sender's array (rank-1) + # in every dimension except the shape at axis=self.axis + recv_shape = list(recv_shapes[self.rank - 1]) + recv_shape[self.axis] = total_cells_front[self.rank] + recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) + # Transfer of ghost cells can be skipped if len(recv_buf) = 0 + # Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory + if len(recv_buf) != 0: + ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis) + # The skip in sender is to match with what described in receiver + if self.rank != self.size - 1 and len(send_buf) != 0: if cells_front > self.local_shape[self.axis]: raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} " f"should be > {cells_front}: dim({self.axis}) " f"{self.local_shape[self.axis]} < {cells_front}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_front)}") - self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis), - dest=self.rank + 1, tag=1) + self._send(send_buf, dest=self.rank + 1, tag=1) if cells_back is not None: - total_cells_back = self._allgather(cells_back) + [0] + total_cells_back = self.base_comm.allgather(cells_back) + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) cells_back = total_cells_back[self.rank - 1] - if self.rank != 0: + send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis) + # Same reasoning as sending cell front applied + recv_shapes = self.local_shapes + if self.rank != 0 and len(send_buf) != 0: if cells_back > self.local_shape[self.axis]: raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} " f"should be > {cells_back}: dim({self.axis}) " f"{self.local_shape[self.axis]} < {cells_back}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_back)}") - self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis), - dest=self.rank - 1, tag=0) + self._send(send_buf, dest=self.rank - 1, tag=0) if self.rank != self.size - 1: - ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0), - axis=self.axis) + recv_shape = list(recv_shapes[self.rank + 1]) + recv_shape[self.axis] = total_cells_back[self.rank] + recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) + if len(recv_buf) != 0: + ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0), + axis=self.axis) return ghosted_array def __repr__(self): diff --git a/pylops_mpi/LinearOperator.py b/pylops_mpi/LinearOperator.py index 266e55fe..49077325 100644 --- a/pylops_mpi/LinearOperator.py +++ b/pylops_mpi/LinearOperator.py @@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if self.Op: y = DistributedArray(global_shape=self.shape[0], 
base_comm=self.base_comm, + base_comm_nccl=x.base_comm_nccl, partition=x.partition, axis=x.axis, engine=x.engine, @@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if self.Op: y = DistributedArray(global_shape=self.shape[1], base_comm=self.base_comm, + base_comm_nccl=x.base_comm_nccl, partition=x.partition, axis=x.axis, engine=x.engine, diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py index 28511105..27d9b813 100644 --- a/pylops_mpi/basicoperators/BlockDiag.py +++ b/pylops_mpi/basicoperators/BlockDiag.py @@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator], @reshaped(forward=True, stacking=True) def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n, mask=self.mask, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): @@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m, + y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m, mask=self.mask, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): diff --git a/pylops_mpi/basicoperators/FirstDerivative.py b/pylops_mpi/basicoperators/FirstDerivative.py index 5adbe284..b25f43d7 100644 --- a/pylops_mpi/basicoperators/FirstDerivative.py +++ b/pylops_mpi/basicoperators/FirstDerivative.py @@ -129,19 +129,19 @@ def _register_multiplications( def _matvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl) return self._hmatvec(x) def _rmatvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl) return self._hrmatvec(x) @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=1) y_forward = ghosted_x[1:] - ghosted_x[:-1] @@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 
1: @@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1) y_backward = ghosted_x[1:] - ghosted_x[:-1] @@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=1) @@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2]) @@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 @@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2) y_centered = ( @@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=4) diff --git a/pylops_mpi/basicoperators/SecondDerivative.py b/pylops_mpi/basicoperators/SecondDerivative.py index 6c4fb961..bfe09b78 100644 --- a/pylops_mpi/basicoperators/SecondDerivative.py +++ b/pylops_mpi/basicoperators/SecondDerivative.py @@ -112,20 +112,20 @@ def _register_multiplications( def _matvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, 
then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl) return self._hmatvec(x) def _rmatvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl) return self._hrmatvec(x) @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=2) y_forward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == self.size - 1: @@ -136,7 +136,8 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 1: y[:-2] += x[:-2] @@ -162,8 +163,8 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2) y_backward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == 0: @@ -174,8 +175,8 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=2) y_backward = ghosted_x[2:] @@ -201,8 +202,8 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = ghosted_x[2:] - 2 * 
ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == 0: @@ -221,8 +222,8 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=2) y_centered = ghosted_x[1:-1] diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index f869a9ad..58581565 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -6,6 +6,7 @@ from pylops import LinearOperator from pylops.utils import DTypeLike from pylops.utils.backend import get_module +from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi import ( MPILinearOperator, @@ -15,6 +16,13 @@ StackedDistributedArray ) from pylops_mpi.utils.decorators import reshaped +from pylops_mpi.utils import deps + +cupy_message = pylops_deps.cupy_import("the VStack module") +nccl_message = deps.nccl_import("the VStack module") + +if nccl_message is None and cupy_message is None: + from pylops_mpi.utils._nccl import nccl_allreduce class MPIVStack(MPILinearOperator): @@ -121,7 +129,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}" f"Got {x.partition} instead...") - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + # the output y should use NCCL if the operand x uses it + y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): @@ -132,13 +141,16 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, + y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) + if deps.nccl_enabled and x.base_comm_nccl: + y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM) + else: + y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) return y diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index c3b02b71..d183fc58 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -4,7 +4,9 @@ "nccl_allgather", "nccl_allreduce", "nccl_bcast", - "nccl_asarray" + "nccl_asarray", + "nccl_send", + "nccl_recv" ] from enum import IntEnum @@ -213,10 +215,6 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None: The index in the array to be broadcasted. value : :obj:`scalar` The value to broadcast (only used by the root GPU, rank 0). 
- - Returns - ------- - None """ if nccl_comm.rank_id() == 0: local_array[index] = value @@ -286,3 +284,49 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray: chunks[i] = chunks[i].reshape(send_shape)[slicing] # combine back to single global array return cp.concatenate(chunks, axis=axis) + + +def nccl_send(nccl_comm, send_buf, dest, count): + """NCCL equivalent of MPI_Send. Sends a specified number of elements + from the buffer to a destination GPU device. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for point-to-point communication. + send_buf : :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination GPU device. + count : :obj:`int` + Number of elements to send from `send_buf`. + """ + nccl_comm.send(send_buf.data.ptr, + count, + cupy_to_nccl_dtype[str(send_buf.dtype)], + dest, + cp.cuda.Stream.null.ptr + ) + + +def nccl_recv(nccl_comm, recv_buf, source, count=None): + """NCCL equivalent of MPI_Recv. Receives data from a source GPU device + into the given buffer. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for point-to-point communication. + recv_buf : :obj:`cupy.ndarray` + The array to store the received data. + source : :obj:`int` + The rank of the source GPU device. + count : :obj:`int`, optional + Number of elements to receive. + """ + nccl_comm.recv(recv_buf.data.ptr, + count, + cupy_to_nccl_dtype[str(recv_buf.dtype)], + source, + cp.cuda.Stream.null.ptr + ) diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py index 457b559b..21b16906 100644 --- a/pylops_mpi/utils/decorators.py +++ b/pylops_mpi/utils/decorators.py @@ -54,6 +54,8 @@ def wrapper(self, x: DistributedArray): local_shapes = None global_shape = getattr(self, "dims") arr = DistributedArray(global_shape=global_shape, + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, local_shapes=local_shapes, axis=0, engine=x.engine, dtype=x.dtype) arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape))) diff --git a/tests_nccl/test_blockdiag_nccl.py b/tests_nccl/test_blockdiag_nccl.py new file mode 100644 index 00000000..8c278481 --- /dev/null +++ b/tests_nccl/test_blockdiag_nccl.py @@ -0,0 +1,118 @@ +"""Test the MPIBlockDiag and MPIStackedBlockDiag classes + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec -n 10 pytest test_blockdiag_nccl.py --with-mpi + +This file employs the same test sets as test_blockdiag under NCCL environment +""" +from mpi4py import MPI +import numpy as np +import cupy as cp +from numpy.testing import assert_allclose +import pytest + +import pylops +import pylops_mpi +from pylops_mpi.utils.dottest import dottest +from pylops_mpi.utils._nccl import initialize_nccl_comm + +nccl_comm = initialize_nccl_comm() + +par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64} +# par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128} +par2 = {'ny': 301, 'nx': 101, 'dtype': np.float64} +# par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128} + +np.random.seed(42) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_blockdiag_nccl(par): + """Test the MPIBlockDiag with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + Op = pylops.MatrixMult(A=((rank + 1) * cp.ones(shape=(par['ny'], par['nx']))).astype(par['dtype'])) + BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ], ) + + x = 
pylops_mpi.DistributedArray(global_shape=size * par['nx'], + base_comm_nccl=nccl_comm, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + y = pylops_mpi.DistributedArray(global_shape=size * par['ny'], + base_comm_nccl=nccl_comm, + dtype=par['dtype'], + engine="cupy") + y[:] = cp.ones(shape=par['ny'], dtype=par['dtype']) + y_global = y.asarray() + + # Forward + x_mat = BDiag_MPI @ x + # Adjoint + y_rmat = BDiag_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.DistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + # Dot test + dottest(BDiag_MPI, x, y, size * par['ny'], size * par['nx']) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in range(size)] + BDiag = pylops.BlockDiag(ops=ops) + + x_mat_np = BDiag @ x_global.get() + y_rmat_np = BDiag.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_stacked_blockdiag_nccl(par): + """Tests for MPIStackedBlogDiag with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + Op = pylops.MatrixMult(A=((rank + 1) * cp.ones(shape=(par['ny'], par['nx']))).astype(par['dtype'])) + BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ], ) + FirstDeriv_MPI = pylops_mpi.MPIFirstDerivative(dims=(par['ny'], par['nx']), dtype=par['dtype']) + StackedBDiag_MPI = pylops_mpi.MPIStackedBlockDiag(ops=[BDiag_MPI, FirstDeriv_MPI]) + + dist1 = pylops_mpi.DistributedArray(global_shape=size * par['nx'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist2[:] = cp.ones(dist2.local_shape, dtype=par['dtype']) + x = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2]) + x_global = x.asarray() + + dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist2[:] = cp.ones(dist2.local_shape, dtype=par['dtype']) + y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2]) + y_global = y.asarray() + + # Forward + x_mat = StackedBDiag_MPI @ x + # Adjoint + y_rmat = StackedBDiag_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.StackedDistributedArray) + assert isinstance(y_rmat, pylops_mpi.StackedDistributedArray) + # Dot test + dottest(StackedBDiag_MPI, x, y, size * par['ny'] + par['nx'] * par['ny'], size * par['nx'] + par['nx'] * par['ny']) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in range(size)] + BDiag = pylops.BlockDiag(ops=ops) + FirstDeriv = pylops.FirstDerivative(dims=(par['ny'], par['nx']), axis=0, dtype=par['dtype']) + BDiag_final = pylops.BlockDiag([BDiag, FirstDeriv]) + x_mat_np = BDiag_final @ x_global.get() + y_rmat_np = BDiag_final.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + 
assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) diff --git a/tests_nccl/test_derivative_nccl.py b/tests_nccl/test_derivative_nccl.py new file mode 100644 index 00000000..e77348a9 --- /dev/null +++ b/tests_nccl/test_derivative_nccl.py @@ -0,0 +1,681 @@ +"""Test the derivative classes + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec -n 10 pytest test_derivative_nccl.py --with-mpi + +This file employs the same test sets as test_derivative under NCCL environment +""" + +import numpy as np +import cupy as cp +from mpi4py import MPI +from numpy.testing import assert_allclose +import pytest + +import pylops +import pylops_mpi +from pylops_mpi.utils.dottest import dottest +from pylops_mpi.utils._nccl import initialize_nccl_comm + +nccl_comm = initialize_nccl_comm() + +np.random.seed(42) +rank = MPI.COMM_WORLD.Get_rank() +size = MPI.COMM_WORLD.Get_size() + +par1 = { + "nz": 600, + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par1b = { + "nz": 600, + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par1j = { +# "nz": 600, +# "dz": 1.0, +# "edge": False, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par1e = { + "nz": 600, + "dz": 1.0, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par2 = { + "nz": (100, 151), + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par2b = { + "nz": (100, 151), + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par2j = { +# "nz": (100, 151), +# "dz": 1.0, +# "edge": False, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par2e = { + "nz": (100, 151), + "dz": 1.0, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par3 = { + "nz": (101, 51, 100), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par3b = { + "nz": (101, 51, 100), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par3j = { +# "nz": (101, 51, 100), +# "dz": 0.4, +# "edge": True, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par3e = { + "nz": (101, 51, 100), + "dz": 0.4, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par4 = { + "nz": (79, 101, 50), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par4b = { + "nz": (79, 101, 50), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par4j = { +# "nz": (79, 101, 50), +# "dz": 0.4, +# "edge": True, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par4e = { + "nz": (79, 101, 50), + "dz": 0.4, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par5 = { + "n": (120, 101, 60), + "axes": (0, 1, 2), + "weights": (0.7, 0.7, 0.7), + "sampling": (1, 1, 1), + "edge": False, + "dtype": np.float64, +} + +par5e = { + "n": (120, 101, 60), + "axes": (-1, -2, -3), + "weights": (0.7, 0.7, 0.7), + "sampling": (1, 1, 1), + "edge": True, + "dtype": np.float64, +} + +par6 = { + "n": (79, 60, 101), + "axes": (0, 1, 2), + "weights": (1, 1, 1), + "sampling": (0.4, 0.4, 0.4), + "edge": False, + "dtype": np.float64, +} + +par6e = { + 
"n": (79, 60, 101), + "axes": (-1, -2, -3), + "weights": (1, 1, 1), + "sampling": (0.4, 0.4, 0.4), + "edge": True, + "dtype": np.float64, +} + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def test_first_derivative_forward(par): + """MPIFirstDerivative operator (forward stencil)""" + Fop_MPI = pylops_mpi.MPIFirstDerivative( + dims=par["nz"], + sampling=par["dz"], + kind="forward", + edge=par["edge"], + dtype=par["dtype"], + ) + x = pylops_mpi.DistributedArray( + global_shape=np.prod(par["nz"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + partition=par["partition"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Fop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Fop_MPI.H @ x + y_adj = y_adj_dist.asarray() + + # Dot test + dottest(Fop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"])) + + if rank == 0: + Fop = pylops.FirstDerivative( + dims=par["nz"], + axis=0, + sampling=par["dz"], + kind="forward", + edge=par["edge"], + dtype=par["dtype"], + ) + assert Fop_MPI.shape == Fop.shape + y_np = Fop @ x_global.get() + y_adj_np = Fop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def test_first_derivative_backward(par): + """MPIFirstDerivative operator (backward stencil)""" + Fop_MPI = pylops_mpi.MPIFirstDerivative( + dims=par["nz"], + sampling=par["dz"], + kind="backward", + edge=par["edge"], + dtype=par["dtype"], + ) + x = pylops_mpi.DistributedArray( + global_shape=np.prod(par["nz"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + partition=par["partition"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Fop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Fop_MPI.H @ x + y_adj = y_adj_dist.asarray() + # Dot test + dottest(Fop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"])) + + if rank == 0: + Fop = pylops.FirstDerivative( + dims=par["nz"], + axis=0, + sampling=par["dz"], + kind="backward", + edge=par["edge"], + dtype=par["dtype"], + ) + assert Fop_MPI.shape == Fop.shape + y_np = Fop @ x_global.get() + y_adj_np = Fop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def test_first_derivative_centered(par): + """MPIFirstDerivative operator (centered stencil)""" + for order in [3, 5]: + Fop_MPI = pylops_mpi.MPIFirstDerivative( + dims=par["nz"], + sampling=par["dz"], + kind="centered", + edge=par["edge"], + order=order, + dtype=par["dtype"], + ) + x = pylops_mpi.DistributedArray( + global_shape=np.prod(par["nz"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + partition=par["partition"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Fop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Fop_MPI.H @ x + y_adj = 
y_adj_dist.asarray() + # Dot test + dottest(Fop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"])) + + if rank == 0: + Fop = pylops.FirstDerivative( + dims=par["nz"], + axis=0, + sampling=par["dz"], + kind="centered", + edge=par["edge"], + order=order, + dtype=par["dtype"], + ) + assert Fop_MPI.shape == Fop.shape + y_np = Fop @ x_global.get() + y_adj_np = Fop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def test_second_derivative_forward(par): + """MPISecondDerivative operator (forward stencil)""" + Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative( + dims=par["nz"], + sampling=par["dz"], + kind="forward", + edge=par["edge"], + dtype=par["dtype"], + ) + x = pylops_mpi.DistributedArray( + global_shape=np.prod(par["nz"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + partition=par["partition"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Sop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Sop_MPI.H @ x + y_adj = y_adj_dist.asarray() + # Dot test + dottest(Sop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"])) + + if rank == 0: + Sop = pylops.SecondDerivative( + dims=par["nz"], + axis=0, + sampling=par["dz"], + kind="forward", + edge=par["edge"], + dtype=par["dtype"], + ) + assert Sop_MPI.shape == Sop.shape + y_np = Sop @ x_global.get() + y_adj_np = Sop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def test_second_derivative_backward(par): + """MPISecondDerivative operator (backward stencil)""" + Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative( + dims=par["nz"], + sampling=par["dz"], + kind="backward", + edge=par["edge"], + dtype=par["dtype"], + ) + x = pylops_mpi.DistributedArray( + global_shape=np.prod(par["nz"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + partition=par["partition"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Sop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Sop_MPI.H @ x + y_adj = y_adj_dist.asarray() + # Dot test + dottest(Sop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"])) + + if rank == 0: + Sop = pylops.SecondDerivative( + dims=par["nz"], + axis=0, + sampling=par["dz"], + kind="backward", + edge=par["edge"], + dtype=par["dtype"], + ) + assert Sop_MPI.shape == Sop.shape + y_np = Sop @ x_global.get() + y_adj_np = Sop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def test_second_derivative_centered(par): + """MPISecondDerivative operator (centered stencil)""" + Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative( + dims=par["nz"], + sampling=par["dz"], + kind="centered", + edge=par["edge"], + dtype=par["dtype"], + ) + x 
= pylops_mpi.DistributedArray( + global_shape=np.prod(par["nz"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + partition=par["partition"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Sop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Sop_MPI.H @ x + y_adj = y_adj_dist.asarray() + # Dot test + dottest(Sop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"])) + + if rank == 0: + Sop = pylops.SecondDerivative( + dims=par["nz"], + axis=0, + sampling=par["dz"], + kind="centered", + edge=par["edge"], + dtype=par["dtype"], + ) + assert Sop_MPI.shape == Sop.shape + y_np = Sop @ x_global.get() + y_adj_np = Sop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par5), (par5e), (par6), (par6e)]) +def test_laplacian(par): + """MPILaplacian Operator""" + for kind in ["forward", "backward", "centered"]: + Lop_MPI = pylops_mpi.basicoperators.MPILaplacian( + dims=par["n"], + axes=par["axes"], + weights=par["weights"], + sampling=par["sampling"], + kind=kind, + edge=par["edge"], + dtype=par["dtype"], + ) + x = pylops_mpi.DistributedArray( + global_shape=np.prod(par["n"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + engine="cupy", + ) + x[:] = cp.random.normal(rank, 10, x.local_shape) + x_global = x.asarray() + # Forward + y_dist = Lop_MPI @ x + y = y_dist.asarray() + # Adjoint + y_adj_dist = Lop_MPI.H @ x + y_adj = y_adj_dist.asarray() + # Dot test + dottest(Lop_MPI, x, y_dist, np.prod(par["n"]), np.prod(par["n"])) + + if rank == 0: + Lop = pylops.Laplacian( + dims=par["n"], + axes=par["axes"], + weights=par["weights"], + sampling=par["sampling"], + kind=kind, + edge=par["edge"], + dtype=par["dtype"], + ) + assert Lop_MPI.shape == Lop.shape + y_np = Lop @ x_global.get() + y_adj_np = Lop.H @ x_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par5), (par5e), (par6), (par6e)]) +def test_gradient(par): + """MPIGradient Operator""" + for kind in ["forward", "backward", "centered"]: + Gop_MPI = pylops_mpi.basicoperators.MPIGradient( + dims=par["n"], + sampling=par["sampling"], + kind=kind, + edge=par["edge"], + dtype=par["dtype"], + ) + x_fwd = pylops_mpi.DistributedArray( + global_shape=np.prod(par["n"]), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + engine="cupy", + ) + x_fwd[:] = cp.random.normal(rank, 10, x_fwd.local_shape) + x_global = x_fwd.asarray() + + # Forward + y_dist = Gop_MPI @ x_fwd + assert isinstance(y_dist, pylops_mpi.StackedDistributedArray) + y = y_dist.asarray() + + # Adjoint + x_adj_dist1 = pylops_mpi.DistributedArray( + global_shape=int(np.prod(par["n"])), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + engine="cupy", + ) + x_adj_dist1[:] = cp.random.normal(rank, 10, x_adj_dist1.local_shape) + x_adj_dist2 = pylops_mpi.DistributedArray( + global_shape=int(np.prod(par["n"])), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + engine="cupy", + ) + x_adj_dist2[:] = cp.random.normal(rank, 20, x_adj_dist2.local_shape) + x_adj_dist3 = pylops_mpi.DistributedArray( + global_shape=int(np.prod(par["n"])), + base_comm_nccl=nccl_comm, + dtype=par["dtype"], + engine="cupy", + ) + x_adj_dist3[:] = cp.random.normal(rank, 30, x_adj_dist3.local_shape) + x_adj = pylops_mpi.StackedDistributedArray( + distarrays=[x_adj_dist1, 
x_adj_dist2, x_adj_dist3] + ) + x_adj_global = x_adj.asarray() + y_adj_dist = Gop_MPI.H @ x_adj + assert isinstance(y_adj_dist, pylops_mpi.DistributedArray) + y_adj = y_adj_dist.asarray() + + # Dot test + dottest( + Gop_MPI, x_fwd, y_dist, len(par["n"]) * np.prod(par["n"]), np.prod(par["n"]) + ) + + if rank == 0: + Gop = pylops.Gradient( + dims=par["n"], + sampling=par["sampling"], + kind=kind, + edge=par["edge"], + dtype=par["dtype"], + ) + assert Gop_MPI.shape == Gop.shape + y_np = Gop @ x_global.get() + y_adj_np = Gop.H @ x_adj_global.get() + assert_allclose(y.get(), y_np, rtol=1e-14) + assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14) diff --git a/tests_nccl/test_stack_nccl.py b/tests_nccl/test_stack_nccl.py new file mode 100644 index 00000000..c727b590 --- /dev/null +++ b/tests_nccl/test_stack_nccl.py @@ -0,0 +1,164 @@ +"""Test the stacking classes + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec -n 10 pytest test_stack_nccl.py --with-mpi + +This file employs the same test sets as test_stack under NCCL environment +""" +import numpy as np +import cupy as cp +from numpy.testing import assert_allclose +from mpi4py import MPI +import pytest + +import pylops +import pylops_mpi +from pylops_mpi.utils.dottest import dottest +from pylops_mpi.utils._nccl import initialize_nccl_comm + +nccl_comm = initialize_nccl_comm() + +# imag part is left to future complex-number support +par1 = {'ny': 101, 'nx': 101, 'imag': 0, 'dtype': np.float64} +par2 = {'ny': 301, 'nx': 101, 'imag': 0, 'dtype': np.float64} + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_vstack_nccl(par): + """Test the MPIVStack operator with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) + Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) + VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], ) + + # Broadcasted DistributedArray(global_shape == local_shape) + x = pylops_mpi.DistributedArray(global_shape=par['nx'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.BROADCAST, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + # Scattered DistributedArray + y = pylops_mpi.DistributedArray(global_shape=size * par['ny'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.SCATTER, + dtype=par['dtype'], + engine="cupy") + y[:] = cp.ones(shape=par['ny'], dtype=par['dtype']) + y_global = y.asarray() + + # Forward + x_mat = VStack_MPI @ x + # Adjoint + y_rmat = VStack_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.DistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + # Dot test + dottest(VStack_MPI, x, y, size * par['ny'], par['nx']) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + A = A_gpu.get() + ops = [pylops.MatrixMult(A=((i + 1) * A).astype(par['dtype'])) for i in range(size)] + VStack = pylops.VStack(ops=ops) + x_mat_np = VStack @ x_global.get() + y_rmat_np = VStack.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_stacked_vstack_nccl(par): + """Test the MPIStackedVStack operator with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + A_gpu = cp.ones(shape=(par['ny'], 
par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) + Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) + VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], ) + StackedVStack_MPI = pylops_mpi.MPIStackedVStack([VStack_MPI, VStack_MPI]) + + # Broadcasted DistributedArray(global_shape == local_shape) + x = pylops_mpi.DistributedArray(global_shape=par['nx'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.BROADCAST, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + # Stacked DistributedArray + dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + dist2 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist2[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2]) + y_global = y.asarray() + + x_mat = StackedVStack_MPI @ x + y_rmat = StackedVStack_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.StackedDistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + A = A_gpu.get() + ops = [pylops.MatrixMult(A=((i + 1) * A).astype(par['dtype'])) for i in range(size)] + VStack = pylops.VStack(ops=ops) + VStack_final = pylops.VStack(ops=[VStack, VStack]) + x_mat_np = VStack_final @ x_global.get() + y_rmat_np = VStack_final.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_hstack_nccl(par): + """Test the MPIHStack operator with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) + Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) + HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], ) + + # Scattered DistributedArray + x = pylops_mpi.DistributedArray(global_shape=size * par['nx'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.SCATTER, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + # Broadcasted DistributedArray(global_shape == local_shape) + y = pylops_mpi.DistributedArray(global_shape=par['ny'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.BROADCAST, + dtype=par['dtype'], + engine="cupy") + y[:] = cp.ones(shape=par['ny'], dtype=par['dtype']) + y_global = y.asarray() + + x_mat = HStack_MPI @ x + y_rmat = HStack_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.DistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + ops = [pylops.MatrixMult(A=((i + 1) * A_gpu.get()).astype(par['dtype'])) for i in range(size)] + HStack = pylops.HStack(ops=ops) + x_mat_np = HStack @ x_global.get() + y_rmat_np = HStack.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14)
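For context, a minimal usage sketch (not part of the patch) of the new NCCL-backed point-to-point path. It assumes a CUDA-aware build of pylops-mpi with CuPy and NCCL available and one MPI process per GPU; the script name and sizes are made up for illustration. The ghost-cell exchange goes through the new DistributedArray._send/_recv helpers, which dispatch to nccl_send/nccl_recv when base_comm_nccl is set and fall back to MPI Send/Recv otherwise.

# ghost_cells_sketch.py -- illustrative only; run as: mpiexec -n 2 python ghost_cells_sketch.py
import numpy as np
import cupy as cp
from mpi4py import MPI

import pylops_mpi
from pylops_mpi.utils._nccl import initialize_nccl_comm

nccl_comm = initialize_nccl_comm()
rank = MPI.COMM_WORLD.Get_rank()
size = MPI.COMM_WORLD.Get_size()

# Scattered 1D array with 10 samples per rank, stored on the local GPU
x = pylops_mpi.DistributedArray(global_shape=10 * size,
                                base_comm_nccl=nccl_comm,
                                dtype=np.float64,
                                engine="cupy")
x[:] = cp.arange(10, dtype=np.float64) + 10 * rank

# Every rank except rank 0 receives one ghost cell from its left neighbour;
# the transfer uses nccl_send/nccl_recv here because base_comm_nccl is set
ghosted = x.add_ghost_cells(cells_front=1)
print(rank, ghosted.shape)  # (10,) on rank 0, (11,) on the other ranks

The same pattern is what the derivative operators rely on internally, which is why the tests above only need to pass base_comm_nccl and engine="cupy" to exercise the NCCL code path.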