From b567f86905b65213ec9cfcffc33d441726c69b0e Mon Sep 17 00:00:00 2001 From: tharittk Date: Tue, 3 Jun 2025 20:55:18 +0700 Subject: [PATCH 1/8] support nccl in add_ghost_cells and NCCL-VStack --- pylops_mpi/DistributedArray.py | 16 +++- pylops_mpi/basicoperators/VStack.py | 23 +++++- tests_nccl/test_stack_nccl.py | 122 ++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 tests_nccl/test_stack_nccl.py diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 50e54d3b..49178e41 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -748,7 +748,13 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, """ ghosted_array = self.local_array.copy() if cells_front is not None: - total_cells_front = self._allgather(cells_front) + [0] + # TODO: these are metadata (small size). Under current API, it will + # call nccl allgather, should we force it to always use MPI? + cells_fronts = self._allgather(cells_front) + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + total_cells_front = cells_fronts.tolist() + [0] + else: + total_cells_front = cells_fronts + [0] # Read cells_front which needs to be sent to rank + 1(cells_front for rank + 1) cells_front = total_cells_front[self.rank + 1] if self.rank != 0: @@ -761,10 +767,16 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, f"{self.local_shape[self.axis]} < {cells_front}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_front)}") + # TODO: this array maybe large. Currently it will always use MPI. + # Should we enable NCCL point-point here ? self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis), dest=self.rank + 1, tag=1) if cells_back is not None: - total_cells_back = self._allgather(cells_back) + [0] + cells_backs = self._allgather(cells_back) + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + total_cells_back = cells_backs.tolist() + [0] + else: + total_cells_back = cells_backs + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) cells_back = total_cells_back[self.rank - 1] if self.rank != 0: diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index f869a9ad..26f5acbf 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -6,6 +6,7 @@ from pylops import LinearOperator from pylops.utils import DTypeLike from pylops.utils.backend import get_module +from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi import ( MPILinearOperator, @@ -15,6 +16,14 @@ StackedDistributedArray ) from pylops_mpi.utils.decorators import reshaped +from pylops_mpi.DistributedArray import NcclCommunicatorType +from pylops_mpi.utils import deps + +cupy_message = pylops_deps.cupy_import("the DistributedArray module") +nccl_message = deps.nccl_import("the DistributedArray module") + +if nccl_message is None and cupy_message is None: + from pylops_mpi.utils._nccl import nccl_allreduce class MPIVStack(MPILinearOperator): @@ -31,6 +40,8 @@ class MPIVStack(MPILinearOperator): One or more :class:`pylops.LinearOperator` to be vertically stacked. base_comm : :obj:`mpi4py.MPI.Comm`, optional Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional + NCCL Communicator over which operators and arrays are distributed. 
dtype : :obj:`str`, optional Type of elements in input array. @@ -99,8 +110,10 @@ class MPIVStack(MPILinearOperator): def __init__(self, ops: Sequence[LinearOperator], base_comm: MPI.Comm = MPI.COMM_WORLD, + base_comm_nccl: NcclCommunicatorType = None, dtype: Optional[DTypeLike] = None): self.ops = ops + self.base_comm_nccl = base_comm_nccl nops = np.zeros(len(self.ops), dtype=np.int64) for iop, oper in enumerate(self.ops): nops[iop] = oper.shape[0] @@ -121,7 +134,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}" f"Got {x.partition} instead...") - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + # the output y should use NCCL if the operand x uses it + y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): @@ -132,13 +146,16 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, + y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) + if deps.nccl_enabled and self.base_comm_nccl: + y[:] = nccl_allreduce(self.base_comm_nccl, y1, op=MPI.SUM) + else: + y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) return y diff --git a/tests_nccl/test_stack_nccl.py b/tests_nccl/test_stack_nccl.py new file mode 100644 index 00000000..60d1b86e --- /dev/null +++ b/tests_nccl/test_stack_nccl.py @@ -0,0 +1,122 @@ +"""Test the stacking classes + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec -n 10 pytest test_stack_nccl.py --with-mpi + +This file employs the same test sets as test_stack under NCCL environment +""" +import numpy as np +import cupy as cp +from numpy.testing import assert_allclose +from mpi4py import MPI +import pytest + +import pylops +import pylops_mpi +from pylops_mpi.utils.dottest import dottest +from pylops_mpi.utils._nccl import initialize_nccl_comm + +nccl_comm = initialize_nccl_comm() + +# imag part is left to future complex-number support +par1 = {'ny': 101, 'nx': 101, 'imag': 0, 'dtype': np.float64} +par2 = {'ny': 301, 'nx': 101, 'imag': 0, 'dtype': np.float64} + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_vstack_nccl(par): + """Test the MPIVStack operator with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) + Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) + VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm) + + # Broadcasted DistributedArray(global_shape == local_shape) + x = pylops_mpi.DistributedArray(global_shape=par['nx'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.BROADCAST, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], 
dtype=par['dtype']) + x_global = x.asarray() + + # Scattered DistributedArray + y = pylops_mpi.DistributedArray(global_shape=size * par['ny'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.SCATTER, + dtype=par['dtype'], + engine="cupy") + y[:] = cp.ones(shape=par['ny'], dtype=par['dtype']) + y_global = y.asarray() + + # Forward + x_mat = VStack_MPI @ x + # Adjoint + y_rmat = VStack_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.DistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + # Dot test + dottest(VStack_MPI, x, y, size * par['ny'], par['nx']) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + A = A_gpu.get() + ops = [pylops.MatrixMult(A=((i + 1) * A).astype(par['dtype'])) for i in range(size)] + VStack = pylops.VStack(ops=ops) + x_mat_np = VStack @ x_global.get() + y_rmat_np = VStack.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_stacked_vstack_nccl(par): + """Test the MPIStackedVStack operator with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) + Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) + VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm) + StackedVStack_MPI = pylops_mpi.MPIStackedVStack([VStack_MPI, VStack_MPI]) + + # Broadcasted DistributedArray(global_shape == local_shape) + x = pylops_mpi.DistributedArray(global_shape=par['nx'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.BROADCAST, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + # Stacked DistributedArray + dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + dist2 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist2[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2]) + y_global = y.asarray() + + x_mat = StackedVStack_MPI @ x + y_rmat = StackedVStack_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.StackedDistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + A = A_gpu.get() + ops = [pylops.MatrixMult(A=((i + 1) * A).astype(par['dtype'])) for i in range(size)] + VStack = pylops.VStack(ops=ops) + VStack_final = pylops.VStack(ops=[VStack, VStack]) + x_mat_np = VStack_final @ x_global.get() + y_rmat_np = VStack_final.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) + + +# TODO: Test of HStack From 7ae78a9d55ca32eb1ed6ec88d249924cbfc208b6 Mon Sep 17 00:00:00 2001 From: tharittk Date: Tue, 3 Jun 2025 21:20:31 +0700 Subject: [PATCH 2/8] minor import msg fix --- pylops_mpi/basicoperators/VStack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 26f5acbf..5a70b087 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ 
-19,8 +19,8 @@ from pylops_mpi.DistributedArray import NcclCommunicatorType from pylops_mpi.utils import deps -cupy_message = pylops_deps.cupy_import("the DistributedArray module") -nccl_message = deps.nccl_import("the DistributedArray module") +cupy_message = pylops_deps.cupy_import("the VStack module") +nccl_message = deps.nccl_import("the VStack module") if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import nccl_allreduce From f80417dcc21f9a58f1ad6c1f7862905ce0e96a50 Mon Sep 17 00:00:00 2001 From: tharittk Date: Wed, 4 Jun 2025 22:03:58 +0700 Subject: [PATCH 3/8] nccl for HStack Op --- pylops_mpi/basicoperators/HStack.py | 4 ++- tests_nccl/test_stack_nccl.py | 44 ++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/basicoperators/HStack.py b/pylops_mpi/basicoperators/HStack.py index c50c0e6a..a8f1d71a 100644 --- a/pylops_mpi/basicoperators/HStack.py +++ b/pylops_mpi/basicoperators/HStack.py @@ -5,6 +5,7 @@ from pylops.utils import DTypeLike from pylops_mpi import DistributedArray, MPILinearOperator +from pylops_mpi.DistributedArray import NcclCommunicatorType from .VStack import MPIVStack @@ -89,6 +90,7 @@ class MPIHStack(MPILinearOperator): def __init__(self, ops: Sequence[LinearOperator], base_comm: MPI.Comm = MPI.COMM_WORLD, + base_comm_nccl: NcclCommunicatorType = None, dtype: Optional[DTypeLike] = None): self.ops = ops nops = [oper.shape[0] for oper in self.ops] @@ -96,7 +98,7 @@ def __init__(self, ops: Sequence[LinearOperator], if len(set(nops)) > 1: raise ValueError("Operators have different number of rows") hops = [oper.H for oper in self.ops] - self.HStack = MPIVStack(ops=hops, base_comm=base_comm, dtype=dtype).H + self.HStack = MPIVStack(ops=hops, base_comm=base_comm, base_comm_nccl=base_comm_nccl, dtype=dtype).H super().__init__(shape=self.HStack.shape, dtype=self.HStack.dtype, base_comm=base_comm) def _matvec(self, x: DistributedArray) -> DistributedArray: diff --git a/tests_nccl/test_stack_nccl.py b/tests_nccl/test_stack_nccl.py index 60d1b86e..bc06fcdb 100644 --- a/tests_nccl/test_stack_nccl.py +++ b/tests_nccl/test_stack_nccl.py @@ -119,4 +119,46 @@ def test_stacked_vstack_nccl(par): assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) -# TODO: Test of HStack +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_hstack(par): + """Test the MPIHStack operator with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) + Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) + HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], base_comm_nccl=nccl_comm) + + # Scattered DistributedArray + x = pylops_mpi.DistributedArray(global_shape=size * par['nx'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.SCATTER, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + # Broadcasted DistributedArray(global_shape == local_shape) + y = pylops_mpi.DistributedArray(global_shape=par['ny'], + base_comm_nccl=nccl_comm, + partition=pylops_mpi.Partition.BROADCAST, + dtype=par['dtype'], + engine="cupy") + y[:] = cp.ones(shape=par['ny'], dtype=par['dtype']) + y_global = y.asarray() + + x_mat = HStack_MPI @ x + y_rmat = HStack_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.DistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + + x_mat_mpi = 
x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + ops = [pylops.MatrixMult(A=((i + 1) * A_gpu.get()).astype(par['dtype'])) for i in range(size)] + HStack = pylops.HStack(ops=ops) + x_mat_np = HStack @ x_global.get() + y_rmat_np = HStack.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) From 3848408d45842a560f202018167ec9fee58b69d5 Mon Sep 17 00:00:00 2001 From: tharittk Date: Sun, 8 Jun 2025 10:13:04 +0700 Subject: [PATCH 4/8] remove nccl_comm from Op constructor and take base_nccl_comm from operand x instead --- pylops_mpi/basicoperators/HStack.py | 4 +--- pylops_mpi/basicoperators/VStack.py | 9 ++------- pylops_mpi/utils/decorators.py | 1 + tests_nccl/test_stack_nccl.py | 8 ++++---- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pylops_mpi/basicoperators/HStack.py b/pylops_mpi/basicoperators/HStack.py index a8f1d71a..c50c0e6a 100644 --- a/pylops_mpi/basicoperators/HStack.py +++ b/pylops_mpi/basicoperators/HStack.py @@ -5,7 +5,6 @@ from pylops.utils import DTypeLike from pylops_mpi import DistributedArray, MPILinearOperator -from pylops_mpi.DistributedArray import NcclCommunicatorType from .VStack import MPIVStack @@ -90,7 +89,6 @@ class MPIHStack(MPILinearOperator): def __init__(self, ops: Sequence[LinearOperator], base_comm: MPI.Comm = MPI.COMM_WORLD, - base_comm_nccl: NcclCommunicatorType = None, dtype: Optional[DTypeLike] = None): self.ops = ops nops = [oper.shape[0] for oper in self.ops] @@ -98,7 +96,7 @@ def __init__(self, ops: Sequence[LinearOperator], if len(set(nops)) > 1: raise ValueError("Operators have different number of rows") hops = [oper.H for oper in self.ops] - self.HStack = MPIVStack(ops=hops, base_comm=base_comm, base_comm_nccl=base_comm_nccl, dtype=dtype).H + self.HStack = MPIVStack(ops=hops, base_comm=base_comm, dtype=dtype).H super().__init__(shape=self.HStack.shape, dtype=self.HStack.dtype, base_comm=base_comm) def _matvec(self, x: DistributedArray) -> DistributedArray: diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 5a70b087..c4a623c4 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -16,7 +16,6 @@ StackedDistributedArray ) from pylops_mpi.utils.decorators import reshaped -from pylops_mpi.DistributedArray import NcclCommunicatorType from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the VStack module") @@ -40,8 +39,6 @@ class MPIVStack(MPILinearOperator): One or more :class:`pylops.LinearOperator` to be vertically stacked. base_comm : :obj:`mpi4py.MPI.Comm`, optional Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. - base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional - NCCL Communicator over which operators and arrays are distributed. dtype : :obj:`str`, optional Type of elements in input array. 
@@ -110,10 +107,8 @@ class MPIVStack(MPILinearOperator): def __init__(self, ops: Sequence[LinearOperator], base_comm: MPI.Comm = MPI.COMM_WORLD, - base_comm_nccl: NcclCommunicatorType = None, dtype: Optional[DTypeLike] = None): self.ops = ops - self.base_comm_nccl = base_comm_nccl nops = np.zeros(len(self.ops), dtype=np.int64) for iop, oper in enumerate(self.ops): nops[iop] = oper.shape[0] @@ -152,8 +147,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - if deps.nccl_enabled and self.base_comm_nccl: - y[:] = nccl_allreduce(self.base_comm_nccl, y1, op=MPI.SUM) + if deps.nccl_enabled and x.base_comm_nccl: + y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM) else: y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) return y diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py index 457b559b..2273ff52 100644 --- a/pylops_mpi/utils/decorators.py +++ b/pylops_mpi/utils/decorators.py @@ -54,6 +54,7 @@ def wrapper(self, x: DistributedArray): local_shapes = None global_shape = getattr(self, "dims") arr = DistributedArray(global_shape=global_shape, + base_comm_nccl=x.base_comm_nccl, local_shapes=local_shapes, axis=0, engine=x.engine, dtype=x.dtype) arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape))) diff --git a/tests_nccl/test_stack_nccl.py b/tests_nccl/test_stack_nccl.py index bc06fcdb..c727b590 100644 --- a/tests_nccl/test_stack_nccl.py +++ b/tests_nccl/test_stack_nccl.py @@ -30,7 +30,7 @@ def test_vstack_nccl(par): rank = MPI.COMM_WORLD.Get_rank() A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) - VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm) + VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], ) # Broadcasted DistributedArray(global_shape == local_shape) x = pylops_mpi.DistributedArray(global_shape=par['nx'], @@ -80,7 +80,7 @@ def test_stacked_vstack_nccl(par): rank = MPI.COMM_WORLD.Get_rank() A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) - VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm) + VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], ) StackedVStack_MPI = pylops_mpi.MPIStackedVStack([VStack_MPI, VStack_MPI]) # Broadcasted DistributedArray(global_shape == local_shape) @@ -121,13 +121,13 @@ def test_stacked_vstack_nccl(par): @pytest.mark.mpi(min_size=2) @pytest.mark.parametrize("par", [(par1), (par2)]) -def test_hstack(par): +def test_hstack_nccl(par): """Test the MPIHStack operator with NCCL""" size = MPI.COMM_WORLD.Get_size() rank = MPI.COMM_WORLD.Get_rank() A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx'])) Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype'])) - HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], base_comm_nccl=nccl_comm) + HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], ) # Scattered DistributedArray x = pylops_mpi.DistributedArray(global_shape=size * par['nx'], From 58f1305aeb61aea58b05223a3fabecdbc459152d Mon Sep 17 00:00:00 2001 From: tharittk Date: Sun, 8 Jun 2025 11:02:11 +0700 Subject: [PATCH 5/8] point-to-point (send/recv) using NCCL. 
Testsed with BlockDiag & FirstDerivative --- pylops_mpi/DistributedArray.py | 84 +++++++++---- pylops_mpi/LinearOperator.py | 2 + pylops_mpi/basicoperators/BlockDiag.py | 4 +- pylops_mpi/basicoperators/FirstDerivative.py | 20 ++-- pylops_mpi/utils/_nccl.py | 58 ++++++++- tests_nccl/test_blockdiag_nccl.py | 118 +++++++++++++++++++ 6 files changed, 248 insertions(+), 38 deletions(-) create mode 100644 tests_nccl/test_blockdiag_nccl.py diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 49178e41..3a6f3574 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -14,7 +14,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split + from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -503,6 +503,35 @@ def _allgather(self, send_buf, recv_buf=None): self.base_comm.Allgather(send_buf, recv_buf) return recv_buf + def _send(self, send_buf, dest, count=None, tag=None): + """ Send operation + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if count is None: + # assuming sending the whole array + count = send_buf.size + nccl_send(self.base_comm_nccl, send_buf, dest, count) + else: + self.base_comm.Send(send_buf, dest, tag) + + def _recv(self, recv_buf=None, source=0, count=None, tag=None): + """ Receive operation + """ + # NCCL must be called with recv_buf. Size cannot be inferred from + # other arguments and thus cannot dynamically allocated + if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None: + if count is None: + # assuming data will take a space of the whole buffer + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf + else: + # MPI allows a receiver buffer to be optional + if recv_buf is None: + return self.base_comm.recv(source=source, tag=tag) + self.base_comm.Recv(buf=recv_buf, source=source, tag=tag) + return recv_buf + def __neg__(self): arr = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, @@ -747,50 +776,55 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, """ ghosted_array = self.local_array.copy() + ncp = get_module(getattr(self, "engine")) if cells_front is not None: - # TODO: these are metadata (small size). Under current API, it will - # call nccl allgather, should we force it to always use MPI? - cells_fronts = self._allgather(cells_front) - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - total_cells_front = cells_fronts.tolist() + [0] - else: - total_cells_front = cells_fronts + [0] + # cells_front is small array of int. 
Explicitly use MPI + total_cells_front = self.base_comm.allgather(cells_front) + [0] # Read cells_front which needs to be sent to rank + 1(cells_front for rank + 1) cells_front = total_cells_front[self.rank + 1] + send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis) + recv_shapes = self.local_shapes if self.rank != 0: - ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array], - axis=self.axis) - if self.rank != self.size - 1: + # from receiver's perspective (rank), the recv buffer have the same shape as the sender's array (rank-1) + # in every dimension except the shape at axis=self.axis + recv_shape = list(recv_shapes[self.rank - 1]) + recv_shape[self.axis] = total_cells_front[self.rank] + recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) + # Some communication can skip if len(recv_buf) = 0 + # Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory + if len(recv_buf) != 0: + ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis) + # The skip in sender is to match with what described in receiver + if self.rank != self.size - 1 and len(send_buf) != 0: if cells_front > self.local_shape[self.axis]: raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} " f"should be > {cells_front}: dim({self.axis}) " f"{self.local_shape[self.axis]} < {cells_front}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_front)}") - # TODO: this array maybe large. Currently it will always use MPI. - # Should we enable NCCL point-point here ? - self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis), - dest=self.rank + 1, tag=1) + self._send(send_buf, dest=self.rank + 1, tag=1) if cells_back is not None: - cells_backs = self._allgather(cells_back) - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - total_cells_back = cells_backs.tolist() + [0] - else: - total_cells_back = cells_backs + [0] + total_cells_back = self.base_comm.allgather(cells_back) + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) cells_back = total_cells_back[self.rank - 1] - if self.rank != 0: + send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis) + # Same reasoning as sending cell front applied + recv_shapes = self.local_shapes + if self.rank != 0 and len(send_buf) != 0: if cells_back > self.local_shape[self.axis]: raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} " f"should be > {cells_back}: dim({self.axis}) " f"{self.local_shape[self.axis]} < {cells_back}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_back)}") - self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis), - dest=self.rank - 1, tag=0) + self._send(send_buf, dest=self.rank - 1, tag=0) if self.rank != self.size - 1: - ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0), - axis=self.axis) + recv_shape = list(recv_shapes[self.rank + 1]) + recv_shape[self.axis] = total_cells_back[self.rank] + recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) + if len(recv_buf) != 0: + ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0), + axis=self.axis) return ghosted_array def __repr__(self): diff --git a/pylops_mpi/LinearOperator.py b/pylops_mpi/LinearOperator.py index 
266e55fe..49077325 100644 --- a/pylops_mpi/LinearOperator.py +++ b/pylops_mpi/LinearOperator.py @@ -86,6 +86,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if self.Op: y = DistributedArray(global_shape=self.shape[0], base_comm=self.base_comm, + base_comm_nccl=x.base_comm_nccl, partition=x.partition, axis=x.axis, engine=x.engine, @@ -123,6 +124,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if self.Op: y = DistributedArray(global_shape=self.shape[1], base_comm=self.base_comm, + base_comm_nccl=x.base_comm_nccl, partition=x.partition, axis=x.axis, engine=x.engine, diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py index 28511105..21f4b328 100644 --- a/pylops_mpi/basicoperators/BlockDiag.py +++ b/pylops_mpi/basicoperators/BlockDiag.py @@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator], @reshaped(forward=True, stacking=True) def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n, mask=self.mask, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): @@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m, + y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m, mask=self.mask, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): diff --git a/pylops_mpi/basicoperators/FirstDerivative.py b/pylops_mpi/basicoperators/FirstDerivative.py index 5adbe284..11f9d85f 100644 --- a/pylops_mpi/basicoperators/FirstDerivative.py +++ b/pylops_mpi/basicoperators/FirstDerivative.py @@ -129,19 +129,19 @@ def _register_multiplications( def _matvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl) return self._hmatvec(x) def _rmatvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl) return self._hrmatvec(x) @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=1) y_forward = ghosted_x[1:] - ghosted_x[:-1] @@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, 
local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 1: @@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1) y_backward = ghosted_x[1:] - ghosted_x[:-1] @@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=1) @@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2]) @@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 @@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2) y_centered = ( @@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=4) diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index c3b02b71..da44726d 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -4,7 +4,9 @@ "nccl_allgather", "nccl_allreduce", "nccl_bcast", - "nccl_asarray" + "nccl_asarray", + "nccl_send", + "nccl_recv" ] from enum import IntEnum @@ -286,3 +288,57 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> 
cp.ndarray: chunks[i] = chunks[i].reshape(send_shape)[slicing] # combine back to single global array return cp.concatenate(chunks, axis=axis) + + +def nccl_send(nccl_comm, send_buf, dest, count): + """NCCL equivalent of MPI_Send. Sends a specified number of elements + from the buffer to a destination GPU device. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for point-to-point communication. + send_buf : :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination GPU device. + count : :obj:`int` + Number of elements to send from `send_buf`. + + Returns + ------- + None + """ + nccl_comm.send(send_buf.data.ptr, + count, + cupy_to_nccl_dtype[str(send_buf.dtype)], + dest, + cp.cuda.Stream.null.ptr + ) + + +def nccl_recv(nccl_comm, recv_buf, source, count=None): + """NCCL equivalent of MPI_Recv. Receives data from a source GPU device + into the given buffer. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for point-to-point communication. + recv_buf : :obj:`cupy.ndarray` + The array to store the received data. + source : :obj:`int` + The rank of the source GPU device. + count : :obj:`int`, optional + Number of elements to receive. + + Returns + ------- + None + """ + nccl_comm.recv(recv_buf.data.ptr, + count, + cupy_to_nccl_dtype[str(recv_buf.dtype)], + source, + cp.cuda.Stream.null.ptr + ) diff --git a/tests_nccl/test_blockdiag_nccl.py b/tests_nccl/test_blockdiag_nccl.py new file mode 100644 index 00000000..8c278481 --- /dev/null +++ b/tests_nccl/test_blockdiag_nccl.py @@ -0,0 +1,118 @@ +"""Test the MPIBlockDiag and MPIStackedBlockDiag classes + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec -n 10 pytest test_blockdiag_nccl.py --with-mpi + +This file employs the same test sets as test_blockdiag under NCCL environment +""" +from mpi4py import MPI +import numpy as np +import cupy as cp +from numpy.testing import assert_allclose +import pytest + +import pylops +import pylops_mpi +from pylops_mpi.utils.dottest import dottest +from pylops_mpi.utils._nccl import initialize_nccl_comm + +nccl_comm = initialize_nccl_comm() + +par1 = {'ny': 101, 'nx': 101, 'dtype': np.float64} +# par1j = {'ny': 101, 'nx': 101, 'dtype': np.complex128} +par2 = {'ny': 301, 'nx': 101, 'dtype': np.float64} +# par2j = {'ny': 301, 'nx': 101, 'dtype': np.complex128} + +np.random.seed(42) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_blockdiag_nccl(par): + """Test the MPIBlockDiag with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + Op = pylops.MatrixMult(A=((rank + 1) * cp.ones(shape=(par['ny'], par['nx']))).astype(par['dtype'])) + BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ], ) + + x = pylops_mpi.DistributedArray(global_shape=size * par['nx'], + base_comm_nccl=nccl_comm, + dtype=par['dtype'], + engine="cupy") + x[:] = cp.ones(shape=par['nx'], dtype=par['dtype']) + x_global = x.asarray() + + y = pylops_mpi.DistributedArray(global_shape=size * par['ny'], + base_comm_nccl=nccl_comm, + dtype=par['dtype'], + engine="cupy") + y[:] = cp.ones(shape=par['ny'], dtype=par['dtype']) + y_global = y.asarray() + + # Forward + x_mat = BDiag_MPI @ x + # Adjoint + y_rmat = BDiag_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.DistributedArray) + assert isinstance(y_rmat, pylops_mpi.DistributedArray) + # Dot test + dottest(BDiag_MPI, x, y, size * par['ny'], size * 
par['nx']) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in range(size)] + BDiag = pylops.BlockDiag(ops=ops) + + x_mat_np = BDiag @ x_global.get() + y_rmat_np = BDiag.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_stacked_blockdiag_nccl(par): + """Tests for MPIStackedBlogDiag with NCCL""" + size = MPI.COMM_WORLD.Get_size() + rank = MPI.COMM_WORLD.Get_rank() + Op = pylops.MatrixMult(A=((rank + 1) * cp.ones(shape=(par['ny'], par['nx']))).astype(par['dtype'])) + BDiag_MPI = pylops_mpi.MPIBlockDiag(ops=[Op, ], ) + FirstDeriv_MPI = pylops_mpi.MPIFirstDerivative(dims=(par['ny'], par['nx']), dtype=par['dtype']) + StackedBDiag_MPI = pylops_mpi.MPIStackedBlockDiag(ops=[BDiag_MPI, FirstDeriv_MPI]) + + dist1 = pylops_mpi.DistributedArray(global_shape=size * par['nx'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist2[:] = cp.ones(dist2.local_shape, dtype=par['dtype']) + x = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2]) + x_global = x.asarray() + + dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype']) + dist2 = pylops_mpi.DistributedArray(global_shape=par['nx'] * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy") + dist2[:] = cp.ones(dist2.local_shape, dtype=par['dtype']) + y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2]) + y_global = y.asarray() + + # Forward + x_mat = StackedBDiag_MPI @ x + # Adjoint + y_rmat = StackedBDiag_MPI.H @ y + assert isinstance(x_mat, pylops_mpi.StackedDistributedArray) + assert isinstance(y_rmat, pylops_mpi.StackedDistributedArray) + # Dot test + dottest(StackedBDiag_MPI, x, y, size * par['ny'] + par['nx'] * par['ny'], size * par['nx'] + par['nx'] * par['ny']) + + x_mat_mpi = x_mat.asarray() + y_rmat_mpi = y_rmat.asarray() + + if rank == 0: + ops = [pylops.MatrixMult((i + 1) * np.ones(shape=(par['ny'], par['nx'])).astype(par['dtype'])) for i in range(size)] + BDiag = pylops.BlockDiag(ops=ops) + FirstDeriv = pylops.FirstDerivative(dims=(par['ny'], par['nx']), axis=0, dtype=par['dtype']) + BDiag_final = pylops.BlockDiag([BDiag, FirstDeriv]) + x_mat_np = BDiag_final @ x_global.get() + y_rmat_np = BDiag_final.H @ y_global.get() + assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14) + assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14) From ae7190c7e0d73d55d53dca6ef9d01ad73b9ac9c6 Mon Sep 17 00:00:00 2001 From: tharittk Date: Mon, 9 Jun 2025 20:36:12 +0700 Subject: [PATCH 6/8] small fixes based on some of PR comments --- pylops_mpi/DistributedArray.py | 22 ++++++++++++---------- pylops_mpi/utils/_nccl.py | 12 ------------ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 3a6f3574..71e045b5 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -495,7 +495,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): def 
_allgather(self, send_buf, recv_buf=None): """Allgather operation """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if deps.nccl_enabled and self.base_comm_nccl: return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) else: if recv_buf is None: @@ -518,13 +518,16 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None): """ Receive operation """ # NCCL must be called with recv_buf. Size cannot be inferred from - # other arguments and thus cannot dynamically allocated + # other arguments and thus cannot be dynamically allocated if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None: - if count is None: - # assuming data will take a space of the whole buffer - count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) - return recv_buf + if recv_buf is not None: + if count is None: + # assuming data will take a space of the whole buffer + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf + else: + raise ValueError("Using recv with NCCL must also supply receiver buffer ") else: # MPI allows a receiver buffer to be optional if recv_buf is None: @@ -773,10 +776,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, ------- ghosted_array : :obj:`numpy.ndarray` Ghosted Array - """ ghosted_array = self.local_array.copy() - ncp = get_module(getattr(self, "engine")) + ncp = get_module(self.engine) if cells_front is not None: # cells_front is small array of int. Explicitly use MPI total_cells_front = self.base_comm.allgather(cells_front) + [0] @@ -790,7 +792,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, recv_shape = list(recv_shapes[self.rank - 1]) recv_shape[self.axis] = total_cells_front[self.rank] recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) - # Some communication can skip if len(recv_buf) = 0 + # Transfer of ghost cells can be skipped if len(recv_buf) = 0 # Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory if len(recv_buf) != 0: ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis) diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index da44726d..d183fc58 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -215,10 +215,6 @@ def nccl_bcast(nccl_comm, local_array, index, value) -> None: The index in the array to be broadcasted. value : :obj:`scalar` The value to broadcast (only used by the root GPU, rank 0). - - Returns - ------- - None """ if nccl_comm.rank_id() == 0: local_array[index] = value @@ -304,10 +300,6 @@ def nccl_send(nccl_comm, send_buf, dest, count): The rank of the destination GPU device. count : :obj:`int` Number of elements to send from `send_buf`. - - Returns - ------- - None """ nccl_comm.send(send_buf.data.ptr, count, @@ -331,10 +323,6 @@ def nccl_recv(nccl_comm, recv_buf, source, count=None): The rank of the source GPU device. count : :obj:`int`, optional Number of elements to receive. 
- - Returns - ------- - None """ nccl_comm.recv(recv_buf.data.ptr, count, From fb14c868b5fb5af4a037e2d04a0ff905ef6278e2 Mon Sep 17 00:00:00 2001 From: tharittk Date: Tue, 10 Jun 2025 21:47:24 +0700 Subject: [PATCH 7/8] nccl support for SecondDerivative and test_derivative_nccl for first and second order --- pylops_mpi/DistributedArray.py | 7 +- pylops_mpi/basicoperators/SecondDerivative.py | 27 +- tests_nccl/test_derivative_nccl.py | 681 ++++++++++++++++++ 3 files changed, 700 insertions(+), 15 deletions(-) create mode 100644 tests_nccl/test_derivative_nccl.py diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 71e045b5..ec74e0b2 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -506,7 +506,7 @@ def _allgather(self, send_buf, recv_buf=None): def _send(self, send_buf, dest, count=None, tag=None): """ Send operation """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if deps.nccl_enabled and self.base_comm_nccl: if count is None: # assuming sending the whole array count = send_buf.size @@ -519,7 +519,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None): """ # NCCL must be called with recv_buf. Size cannot be inferred from # other arguments and thus cannot be dynamically allocated - if deps.nccl_enabled and getattr(self, "base_comm_nccl") and recv_buf is not None: + if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None: if recv_buf is not None: if count is None: # assuming data will take a space of the whole buffer @@ -572,6 +572,7 @@ def add(self, dist_array): self._check_mask(dist_array) SumArray = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, + base_comm_nccl=self.base_comm_nccl, dtype=self.dtype, partition=self.partition, local_shapes=self.local_shapes, @@ -598,6 +599,7 @@ def multiply(self, dist_array): ProductArray = DistributedArray(global_shape=self.global_shape, base_comm=self.base_comm, + base_comm_nccl=self.base_comm_nccl, dtype=self.dtype, partition=self.partition, local_shapes=self.local_shapes, @@ -748,6 +750,7 @@ def ravel(self, order: Optional[str] = "C"): """ local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes] arr = DistributedArray(global_shape=np.prod(self.global_shape), + base_comm_nccl=self.base_comm_nccl, local_shapes=local_shapes, mask=self.mask, partition=self.partition, diff --git a/pylops_mpi/basicoperators/SecondDerivative.py b/pylops_mpi/basicoperators/SecondDerivative.py index 6c4fb961..ab5a1fbc 100644 --- a/pylops_mpi/basicoperators/SecondDerivative.py +++ b/pylops_mpi/basicoperators/SecondDerivative.py @@ -112,20 +112,20 @@ def _register_multiplications( def _matvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl) return self._hmatvec(x) def _rmatvec(self, x: DistributedArray) -> DistributedArray: # If Partition.BROADCAST, then convert to Partition.SCATTER if x.partition is Partition.BROADCAST: - x = DistributedArray.to_dist(x=x.local_array) + x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl) return self._hrmatvec(x) @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, 
dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=2) y_forward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == self.size - 1: @@ -136,7 +136,8 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 1: y[:-2] += x[:-2] @@ -162,8 +163,8 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2) y_backward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == 0: @@ -174,8 +175,8 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=2) y_backward = ghosted_x[2:] @@ -201,8 +202,8 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == 0: @@ -221,8 +222,8 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, + local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=2) y_centered = ghosted_x[1:-1] diff --git a/tests_nccl/test_derivative_nccl.py b/tests_nccl/test_derivative_nccl.py new file mode 100644 index 00000000..e77348a9 --- /dev/null +++ b/tests_nccl/test_derivative_nccl.py @@ -0,0 +1,681 @@ +"""Test the derivative classes + Designed to run with n GPUs (with 1 MPI process per GPU) + $ mpiexec 
-n 10 pytest test_derivative_nccl.py --with-mpi + +This file employs the same test sets as test_derivative under NCCL environment +""" + +import numpy as np +import cupy as cp +from mpi4py import MPI +from numpy.testing import assert_allclose +import pytest + +import pylops +import pylops_mpi +from pylops_mpi.utils.dottest import dottest +from pylops_mpi.utils._nccl import initialize_nccl_comm + +nccl_comm = initialize_nccl_comm() + +np.random.seed(42) +rank = MPI.COMM_WORLD.Get_rank() +size = MPI.COMM_WORLD.Get_size() + +par1 = { + "nz": 600, + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par1b = { + "nz": 600, + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par1j = { +# "nz": 600, +# "dz": 1.0, +# "edge": False, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par1e = { + "nz": 600, + "dz": 1.0, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par2 = { + "nz": (100, 151), + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par2b = { + "nz": (100, 151), + "dz": 1.0, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par2j = { +# "nz": (100, 151), +# "dz": 1.0, +# "edge": False, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par2e = { + "nz": (100, 151), + "dz": 1.0, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par3 = { + "nz": (101, 51, 100), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par3b = { + "nz": (101, 51, 100), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par3j = { +# "nz": (101, 51, 100), +# "dz": 0.4, +# "edge": True, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par3e = { + "nz": (101, 51, 100), + "dz": 0.4, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par4 = { + "nz": (79, 101, 50), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par4b = { + "nz": (79, 101, 50), + "dz": 0.4, + "edge": False, + "dtype": np.float64, + "partition": pylops_mpi.Partition.BROADCAST, +} + +# par4j = { +# "nz": (79, 101, 50), +# "dz": 0.4, +# "edge": True, +# "dtype": np.complex128, +# "partition": pylops_mpi.Partition.SCATTER +# } + +par4e = { + "nz": (79, 101, 50), + "dz": 0.4, + "edge": True, + "dtype": np.float64, + "partition": pylops_mpi.Partition.SCATTER, +} + +par5 = { + "n": (120, 101, 60), + "axes": (0, 1, 2), + "weights": (0.7, 0.7, 0.7), + "sampling": (1, 1, 1), + "edge": False, + "dtype": np.float64, +} + +par5e = { + "n": (120, 101, 60), + "axes": (-1, -2, -3), + "weights": (0.7, 0.7, 0.7), + "sampling": (1, 1, 1), + "edge": True, + "dtype": np.float64, +} + +par6 = { + "n": (79, 60, 101), + "axes": (0, 1, 2), + "weights": (1, 1, 1), + "sampling": (0.4, 0.4, 0.4), + "edge": False, + "dtype": np.float64, +} + +par6e = { + "n": (79, 60, 101), + "axes": (-1, -2, -3), + "weights": (1, 1, 1), + "sampling": (0.4, 0.4, 0.4), + "edge": True, + "dtype": np.float64, +} + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par1), + (par1b), + (par1e), + (par2), + (par2b), + (par2e), + (par3), + (par3b), + (par3e), + (par4), + (par4b), + (par4e), + ], +) +def 
test_first_derivative_forward(par):
+    """MPIFirstDerivative operator (forward stencil)"""
+    Fop_MPI = pylops_mpi.MPIFirstDerivative(
+        dims=par["nz"],
+        sampling=par["dz"],
+        kind="forward",
+        edge=par["edge"],
+        dtype=par["dtype"],
+    )
+    x = pylops_mpi.DistributedArray(
+        global_shape=np.prod(par["nz"]),
+        base_comm_nccl=nccl_comm,
+        dtype=par["dtype"],
+        partition=par["partition"],
+        engine="cupy",
+    )
+    x[:] = cp.random.normal(rank, 10, x.local_shape)
+    x_global = x.asarray()
+    # Forward
+    y_dist = Fop_MPI @ x
+    y = y_dist.asarray()
+    # Adjoint
+    y_adj_dist = Fop_MPI.H @ x
+    y_adj = y_adj_dist.asarray()
+
+    # Dot test
+    dottest(Fop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"]))
+
+    if rank == 0:
+        Fop = pylops.FirstDerivative(
+            dims=par["nz"],
+            axis=0,
+            sampling=par["dz"],
+            kind="forward",
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        assert Fop_MPI.shape == Fop.shape
+        y_np = Fop @ x_global.get()
+        y_adj_np = Fop.H @ x_global.get()
+        assert_allclose(y.get(), y_np, rtol=1e-14)
+        assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize(
+    "par",
+    [
+        (par1),
+        (par1b),
+        (par1e),
+        (par2),
+        (par2b),
+        (par2e),
+        (par3),
+        (par3b),
+        (par3e),
+        (par4),
+        (par4b),
+        (par4e),
+    ],
+)
+def test_first_derivative_backward(par):
+    """MPIFirstDerivative operator (backward stencil)"""
+    Fop_MPI = pylops_mpi.MPIFirstDerivative(
+        dims=par["nz"],
+        sampling=par["dz"],
+        kind="backward",
+        edge=par["edge"],
+        dtype=par["dtype"],
+    )
+    x = pylops_mpi.DistributedArray(
+        global_shape=np.prod(par["nz"]),
+        base_comm_nccl=nccl_comm,
+        dtype=par["dtype"],
+        partition=par["partition"],
+        engine="cupy",
+    )
+    x[:] = cp.random.normal(rank, 10, x.local_shape)
+    x_global = x.asarray()
+    # Forward
+    y_dist = Fop_MPI @ x
+    y = y_dist.asarray()
+    # Adjoint
+    y_adj_dist = Fop_MPI.H @ x
+    y_adj = y_adj_dist.asarray()
+    # Dot test
+    dottest(Fop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"]))
+
+    if rank == 0:
+        Fop = pylops.FirstDerivative(
+            dims=par["nz"],
+            axis=0,
+            sampling=par["dz"],
+            kind="backward",
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        assert Fop_MPI.shape == Fop.shape
+        y_np = Fop @ x_global.get()
+        y_adj_np = Fop.H @ x_global.get()
+        assert_allclose(y.get(), y_np, rtol=1e-14)
+        assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize(
+    "par",
+    [
+        (par1),
+        (par1b),
+        (par1e),
+        (par2),
+        (par2b),
+        (par2e),
+        (par3),
+        (par3b),
+        (par3e),
+        (par4),
+        (par4b),
+        (par4e),
+    ],
+)
+def test_first_derivative_centered(par):
+    """MPIFirstDerivative operator (centered stencil)"""
+    for order in [3, 5]:
+        Fop_MPI = pylops_mpi.MPIFirstDerivative(
+            dims=par["nz"],
+            sampling=par["dz"],
+            kind="centered",
+            edge=par["edge"],
+            order=order,
+            dtype=par["dtype"],
+        )
+        x = pylops_mpi.DistributedArray(
+            global_shape=np.prod(par["nz"]),
+            base_comm_nccl=nccl_comm,
+            dtype=par["dtype"],
+            partition=par["partition"],
+            engine="cupy",
+        )
+        x[:] = cp.random.normal(rank, 10, x.local_shape)
+        x_global = x.asarray()
+        # Forward
+        y_dist = Fop_MPI @ x
+        y = y_dist.asarray()
+        # Adjoint
+        y_adj_dist = Fop_MPI.H @ x
+        y_adj = y_adj_dist.asarray()
+        # Dot test
+        dottest(Fop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"]))
+
+        if rank == 0:
+            Fop = pylops.FirstDerivative(
+                dims=par["nz"],
+                axis=0,
+                sampling=par["dz"],
+                kind="centered",
+                edge=par["edge"],
+                order=order,
+                dtype=par["dtype"],
+            )
+            assert Fop_MPI.shape == Fop.shape
+            y_np = Fop @ x_global.get()
+            y_adj_np = Fop.H @ x_global.get()
+            assert_allclose(y.get(), y_np, rtol=1e-14)
+            assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize(
+    "par",
+    [
+        (par1),
+        (par1b),
+        (par1e),
+        (par2),
+        (par2b),
+        (par2e),
+        (par3),
+        (par3b),
+        (par3e),
+        (par4),
+        (par4b),
+        (par4e),
+    ],
+)
+def test_second_derivative_forward(par):
+    """MPISecondDerivative operator (forward stencil)"""
+    Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative(
+        dims=par["nz"],
+        sampling=par["dz"],
+        kind="forward",
+        edge=par["edge"],
+        dtype=par["dtype"],
+    )
+    x = pylops_mpi.DistributedArray(
+        global_shape=np.prod(par["nz"]),
+        base_comm_nccl=nccl_comm,
+        dtype=par["dtype"],
+        partition=par["partition"],
+        engine="cupy",
+    )
+    x[:] = cp.random.normal(rank, 10, x.local_shape)
+    x_global = x.asarray()
+    # Forward
+    y_dist = Sop_MPI @ x
+    y = y_dist.asarray()
+    # Adjoint
+    y_adj_dist = Sop_MPI.H @ x
+    y_adj = y_adj_dist.asarray()
+    # Dot test
+    dottest(Sop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"]))
+
+    if rank == 0:
+        Sop = pylops.SecondDerivative(
+            dims=par["nz"],
+            axis=0,
+            sampling=par["dz"],
+            kind="forward",
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        assert Sop_MPI.shape == Sop.shape
+        y_np = Sop @ x_global.get()
+        y_adj_np = Sop.H @ x_global.get()
+        assert_allclose(y.get(), y_np, rtol=1e-14)
+        assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize(
+    "par",
+    [
+        (par1),
+        (par1b),
+        (par1e),
+        (par2),
+        (par2b),
+        (par2e),
+        (par3),
+        (par3b),
+        (par3e),
+        (par4),
+        (par4b),
+        (par4e),
+    ],
+)
+def test_second_derivative_backward(par):
+    """MPISecondDerivative operator (backward stencil)"""
+    Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative(
+        dims=par["nz"],
+        sampling=par["dz"],
+        kind="backward",
+        edge=par["edge"],
+        dtype=par["dtype"],
+    )
+    x = pylops_mpi.DistributedArray(
+        global_shape=np.prod(par["nz"]),
+        base_comm_nccl=nccl_comm,
+        dtype=par["dtype"],
+        partition=par["partition"],
+        engine="cupy",
+    )
+    x[:] = cp.random.normal(rank, 10, x.local_shape)
+    x_global = x.asarray()
+    # Forward
+    y_dist = Sop_MPI @ x
+    y = y_dist.asarray()
+    # Adjoint
+    y_adj_dist = Sop_MPI.H @ x
+    y_adj = y_adj_dist.asarray()
+    # Dot test
+    dottest(Sop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"]))
+
+    if rank == 0:
+        Sop = pylops.SecondDerivative(
+            dims=par["nz"],
+            axis=0,
+            sampling=par["dz"],
+            kind="backward",
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        assert Sop_MPI.shape == Sop.shape
+        y_np = Sop @ x_global.get()
+        y_adj_np = Sop.H @ x_global.get()
+        assert_allclose(y.get(), y_np, rtol=1e-14)
+        assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize(
+    "par",
+    [
+        (par1),
+        (par1b),
+        (par1e),
+        (par2),
+        (par2b),
+        (par2e),
+        (par3),
+        (par3b),
+        (par3e),
+        (par4),
+        (par4b),
+        (par4e),
+    ],
+)
+def test_second_derivative_centered(par):
+    """MPISecondDerivative operator (centered stencil)"""
+    Sop_MPI = pylops_mpi.basicoperators.MPISecondDerivative(
+        dims=par["nz"],
+        sampling=par["dz"],
+        kind="centered",
+        edge=par["edge"],
+        dtype=par["dtype"],
+    )
+    x = pylops_mpi.DistributedArray(
+        global_shape=np.prod(par["nz"]),
+        base_comm_nccl=nccl_comm,
+        dtype=par["dtype"],
+        partition=par["partition"],
+        engine="cupy",
+    )
+    x[:] = cp.random.normal(rank, 10, x.local_shape)
+    x_global = x.asarray()
+    # Forward
+    y_dist = Sop_MPI @ x
+    y = y_dist.asarray()
+    # Adjoint
+    y_adj_dist = Sop_MPI.H @ x
+    y_adj = y_adj_dist.asarray()
+    # Dot test
+    dottest(Sop_MPI, x, y_dist, np.prod(par["nz"]), np.prod(par["nz"]))
+
+    if rank == 0:
+        Sop = pylops.SecondDerivative(
+            dims=par["nz"],
+            axis=0,
+            sampling=par["dz"],
+            kind="centered",
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        assert Sop_MPI.shape == Sop.shape
+        y_np = Sop @ x_global.get()
+        y_adj_np = Sop.H @ x_global.get()
+        assert_allclose(y.get(), y_np, rtol=1e-14)
+        assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par5), (par5e), (par6), (par6e)])
+def test_laplacian(par):
+    """MPILaplacian Operator"""
+    for kind in ["forward", "backward", "centered"]:
+        Lop_MPI = pylops_mpi.basicoperators.MPILaplacian(
+            dims=par["n"],
+            axes=par["axes"],
+            weights=par["weights"],
+            sampling=par["sampling"],
+            kind=kind,
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        x = pylops_mpi.DistributedArray(
+            global_shape=np.prod(par["n"]),
+            base_comm_nccl=nccl_comm,
+            dtype=par["dtype"],
+            engine="cupy",
+        )
+        x[:] = cp.random.normal(rank, 10, x.local_shape)
+        x_global = x.asarray()
+        # Forward
+        y_dist = Lop_MPI @ x
+        y = y_dist.asarray()
+        # Adjoint
+        y_adj_dist = Lop_MPI.H @ x
+        y_adj = y_adj_dist.asarray()
+        # Dot test
+        dottest(Lop_MPI, x, y_dist, np.prod(par["n"]), np.prod(par["n"]))
+
+        if rank == 0:
+            Lop = pylops.Laplacian(
+                dims=par["n"],
+                axes=par["axes"],
+                weights=par["weights"],
+                sampling=par["sampling"],
+                kind=kind,
+                edge=par["edge"],
+                dtype=par["dtype"],
+            )
+            assert Lop_MPI.shape == Lop.shape
+            y_np = Lop @ x_global.get()
+            y_adj_np = Lop.H @ x_global.get()
+            assert_allclose(y.get(), y_np, rtol=1e-14)
+            assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)
+
+
+@pytest.mark.mpi(min_size=2)
+@pytest.mark.parametrize("par", [(par5), (par5e), (par6), (par6e)])
+def test_gradient(par):
+    """MPIGradient Operator"""
+    for kind in ["forward", "backward", "centered"]:
+        Gop_MPI = pylops_mpi.basicoperators.MPIGradient(
+            dims=par["n"],
+            sampling=par["sampling"],
+            kind=kind,
+            edge=par["edge"],
+            dtype=par["dtype"],
+        )
+        x_fwd = pylops_mpi.DistributedArray(
+            global_shape=np.prod(par["n"]),
+            base_comm_nccl=nccl_comm,
+            dtype=par["dtype"],
+            engine="cupy",
+        )
+        x_fwd[:] = cp.random.normal(rank, 10, x_fwd.local_shape)
+        x_global = x_fwd.asarray()
+
+        # Forward
+        y_dist = Gop_MPI @ x_fwd
+        assert isinstance(y_dist, pylops_mpi.StackedDistributedArray)
+        y = y_dist.asarray()
+
+        # Adjoint
+        x_adj_dist1 = pylops_mpi.DistributedArray(
+            global_shape=int(np.prod(par["n"])),
+            base_comm_nccl=nccl_comm,
+            dtype=par["dtype"],
+            engine="cupy",
+        )
+        x_adj_dist1[:] = cp.random.normal(rank, 10, x_adj_dist1.local_shape)
+        x_adj_dist2 = pylops_mpi.DistributedArray(
+            global_shape=int(np.prod(par["n"])),
+            base_comm_nccl=nccl_comm,
+            dtype=par["dtype"],
+            engine="cupy",
+        )
+        x_adj_dist2[:] = cp.random.normal(rank, 20, x_adj_dist2.local_shape)
+        x_adj_dist3 = pylops_mpi.DistributedArray(
+            global_shape=int(np.prod(par["n"])),
+            base_comm_nccl=nccl_comm,
+            dtype=par["dtype"],
+            engine="cupy",
+        )
+        x_adj_dist3[:] = cp.random.normal(rank, 30, x_adj_dist3.local_shape)
+        x_adj = pylops_mpi.StackedDistributedArray(
+            distarrays=[x_adj_dist1, x_adj_dist2, x_adj_dist3]
+        )
+        x_adj_global = x_adj.asarray()
+        y_adj_dist = Gop_MPI.H @ x_adj
+        assert isinstance(y_adj_dist, pylops_mpi.DistributedArray)
+        y_adj = y_adj_dist.asarray()
+
+        # Dot test
+        dottest(
+            Gop_MPI, x_fwd, y_dist, len(par["n"]) * np.prod(par["n"]), np.prod(par["n"])
+        )
+
+        if rank == 0:
+            Gop = pylops.Gradient(
+                dims=par["n"],
+                sampling=par["sampling"],
+                kind=kind,
+                edge=par["edge"],
+                dtype=par["dtype"],
+            )
+            assert Gop_MPI.shape == Gop.shape
+            y_np = Gop @ x_global.get()
+            y_adj_np = Gop.H @ x_adj_global.get()
+            assert_allclose(y.get(), y_np, rtol=1e-14)
+            assert_allclose(y_adj.get(), y_adj_np, rtol=1e-14)

From d7d07ab21ebe3bd08cf221774883a245bf4c8315 Mon Sep 17 00:00:00 2001
From: tharittk
Date: Tue, 17 Jun 2025 21:41:54 +0700
Subject: [PATCH 8/8] explicitly pass x.base_comm to DistributedArray as suggested in PR

---
 pylops_mpi/DistributedArray.py                |  1 +
 pylops_mpi/basicoperators/BlockDiag.py        |  4 ++--
 pylops_mpi/basicoperators/FirstDerivative.py  | 20 +++++++++----------
 pylops_mpi/basicoperators/SecondDerivative.py | 16 +++++++--------
 pylops_mpi/basicoperators/VStack.py           |  4 ++--
 pylops_mpi/utils/decorators.py                |  1 +
 6 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index ec74e0b2..6e67e1fc 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -750,6 +750,7 @@ def ravel(self, order: Optional[str] = "C"):
         """
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
+                               base_comm=self.base_comm,
                                base_comm_nccl=self.base_comm_nccl,
                                local_shapes=local_shapes,
                                mask=self.mask,
diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py
index 21f4b328..27d9b813 100644
--- a/pylops_mpi/basicoperators/BlockDiag.py
+++ b/pylops_mpi/basicoperators/BlockDiag.py
@@ -121,7 +121,7 @@ def __init__(self, ops: Sequence[LinearOperator],
     @reshaped(forward=True, stacking=True)
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
+        y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
@@ -133,7 +133,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
+        y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_m,
                              mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
diff --git a/pylops_mpi/basicoperators/FirstDerivative.py b/pylops_mpi/basicoperators/FirstDerivative.py
index 11f9d85f..b25f43d7 100644
--- a/pylops_mpi/basicoperators/FirstDerivative.py
+++ b/pylops_mpi/basicoperators/FirstDerivative.py
@@ -129,19 +129,19 @@ def _register_multiplications(
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hmatvec(x)
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hrmatvec(x)
 
     @reshaped
     def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_back=1)
         y_forward = ghosted_x[1:] - ghosted_x[:-1]
@@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         if self.rank == self.size - 1:
@@ -171,7 +171,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1)
         y_backward = ghosted_x[1:] - ghosted_x[:-1]
@@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=1)
@@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
         y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2])
@@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
 
@@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2)
         y_centered = (
@@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=x.local_shapes,
                              axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=4)
diff --git a/pylops_mpi/basicoperators/SecondDerivative.py b/pylops_mpi/basicoperators/SecondDerivative.py
index ab5a1fbc..bfe09b78 100644
--- a/pylops_mpi/basicoperators/SecondDerivative.py
+++ b/pylops_mpi/basicoperators/SecondDerivative.py
@@ -112,19 +112,19 @@ def _register_multiplications(
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hmatvec(x)
 
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         # If Partition.BROADCAST, then convert to Partition.SCATTER
         if x.partition is Partition.BROADCAST:
-            x = DistributedArray.to_dist(x=x.local_array, base_comm_nccl=x.base_comm_nccl)
+            x = DistributedArray.to_dist(x=x.local_array, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl)
         return self._hrmatvec(x)
 
     @reshaped
     def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                              local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_back=2)
         y_forward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
@@ -136,7 +136,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                              local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         if self.rank == self.size - 1:
@@ -163,7 +163,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                              local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=2)
         y_backward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
@@ -175,7 +175,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                              local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=2)
@@ -202,7 +202,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                              local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
         ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1)
         y_centered = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2]
@@ -222,7 +222,7 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray:
     @reshaped
     def _rmatvec_centered(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=x.global_shape, base_comm_nccl=x.base_comm_nccl,
+        y = DistributedArray(global_shape=x.global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl,
                              local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype)
         y[:] = 0
         ghosted_x = x.add_ghost_cells(cells_back=2)
diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py
index c4a623c4..58581565 100644
--- a/pylops_mpi/basicoperators/VStack.py
+++ b/pylops_mpi/basicoperators/VStack.py
@@ -130,7 +130,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
             raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
                              f"Got {x.partition} instead...")
         # the output y should use NCCL if the operand x uses it
-        y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
+        y = DistributedArray(global_shape=self.shape[0], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
                              engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
@@ -141,7 +141,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
+        y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
                              engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py
index 2273ff52..21b16906 100644
--- a/pylops_mpi/utils/decorators.py
+++ b/pylops_mpi/utils/decorators.py
@@ -54,6 +54,7 @@ def wrapper(self, x: DistributedArray):
                 local_shapes = None
                 global_shape = getattr(self, "dims")
             arr = DistributedArray(global_shape=global_shape,
+                                   base_comm=x.base_comm,
                                    base_comm_nccl=x.base_comm_nccl,
                                    local_shapes=local_shapes, axis=0,
                                    engine=x.engine, dtype=x.dtype)
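
[Editor's note, not part of the patch series: the diffs in PATCH 8/8 show that both the MPI communicator (base_comm) and the NCCL communicator (base_comm_nccl) of the input x are now forwarded to every DistributedArray an operator creates. A minimal driver along the lines of the tests above is sketched here for orientation only; it assumes one MPI process per GPU with CuPy and NCCL available and the initialize_nccl_comm helper from pylops_mpi.utils._nccl, and the script name, vector length and sampling value are illustrative rather than taken from the patches.]

    # nccl_first_derivative_example.py (hypothetical name)
    # Run with: mpiexec -n <n_gpus> python nccl_first_derivative_example.py
    import numpy as np
    import cupy as cp
    from mpi4py import MPI

    import pylops_mpi
    from pylops_mpi.utils._nccl import initialize_nccl_comm

    nccl_comm = initialize_nccl_comm()        # per-rank NCCL communicator (one GPU per MPI rank)
    rank = MPI.COMM_WORLD.Get_rank()

    n = 100                                   # illustrative global length
    Fop = pylops_mpi.MPIFirstDerivative(dims=n, sampling=1.0,
                                        kind="centered", dtype=np.float64)

    # CuPy-backed distributed vector tied to the NCCL communicator; per the
    # diffs above, operators forward x.base_comm / x.base_comm_nccl to the
    # DistributedArray they return, so the NCCL setting follows the operand.
    x = pylops_mpi.DistributedArray(global_shape=n,
                                    base_comm_nccl=nccl_comm,
                                    dtype=np.float64, engine="cupy")
    x[:] = cp.random.normal(rank, 1.0, x.local_shape)

    y = Fop @ x                               # forward
    xadj = Fop.H @ y                          # adjoint
    y_global = y.asarray()                    # collective gather: call on every rank
    xadj_global = xadj.asarray()
    if rank == 0:
        print(y_global.shape, xadj_global.shape)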