
Commit b567f86

support nccl in add_ghost_cells and NCCL-VStack
1 parent 3585f36 commit b567f86

3 files changed: 156 additions & 5 deletions
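
For orientation, below is a minimal usage sketch of the NCCL-enabled MPIVStack introduced by this commit. It assumes CuPy and NCCL are installed and one MPI process is launched per GPU, and it mirrors the pattern exercised in tests_nccl/test_stack_nccl.py further down; it is a sketch, not part of the commit itself.

# Sketch: NCCL-backed MPIVStack (assumes CuPy + NCCL, one MPI rank per GPU)
import cupy as cp
import numpy as np
from mpi4py import MPI

import pylops
import pylops_mpi
from pylops_mpi.utils._nccl import initialize_nccl_comm

nccl_comm = initialize_nccl_comm()   # one NCCL communicator per MPI rank/GPU
rank = MPI.COMM_WORLD.Get_rank()

# each rank contributes one block to the vertical stack
A = (rank + 1) * cp.ones((101, 101), dtype=np.float64)
Op = pylops.MatrixMult(A)
VStack = pylops_mpi.MPIVStack(ops=[Op], base_comm_nccl=nccl_comm)

# NCCL-backed, CuPy-engined model vector broadcast to all ranks
x = pylops_mpi.DistributedArray(global_shape=101, base_comm_nccl=nccl_comm,
                                partition=pylops_mpi.Partition.BROADCAST,
                                engine="cupy", dtype=np.float64)
x[:] = cp.ones(101, dtype=np.float64)

y = VStack @ x        # forward: per-rank block products, no reduction needed
xadj = VStack.H @ y   # adjoint: local rmatvec followed by NCCL allreduce (this commit)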

pylops_mpi/DistributedArray.py

Lines changed: 14 additions & 2 deletions

@@ -748,7 +748,13 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
         """
         ghosted_array = self.local_array.copy()
         if cells_front is not None:
-            total_cells_front = self._allgather(cells_front) + [0]
+            # TODO: these are metadata (small size). Under the current API this
+            # will call the NCCL allgather; should we force it to always use MPI?
+            cells_fronts = self._allgather(cells_front)
+            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+                total_cells_front = cells_fronts.tolist() + [0]
+            else:
+                total_cells_front = cells_fronts + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
             if self.rank != 0:
@@ -761,10 +767,16 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
+                # TODO: this array may be large. Currently it will always use MPI.
+                # Should we enable NCCL point-to-point here?
                 self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
                                     dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            total_cells_back = self._allgather(cells_back) + [0]
+            cells_backs = self._allgather(cells_back)
+            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+                total_cells_back = cells_backs.tolist() + [0]
+            else:
+                total_cells_back = cells_backs + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
             if self.rank != 0:
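
The branching above exists because the two allgather backends return different types: mpi4py's allgather of a scalar yields a Python list, so `+ [0]` appends an element, while the NCCL path returns a device array that must be converted with .tolist() first. A small illustrative sketch (assuming a CuPy-capable environment; the literal values are made up):

# Why the NCCL branch needs .tolist() before appending the trailing 0
import cupy as cp

mpi_style = [3, 3, 3]                  # MPI allgather of a scalar -> Python list
print(mpi_style + [0])                 # [3, 3, 3, 0], list concatenation

nccl_style = cp.asarray([3, 3, 3])     # NCCL allgather -> device (CuPy) array
print(nccl_style.tolist() + [0])       # [3, 3, 3, 0], same result after .tolist()
# without .tolist(), `nccl_style + [0]` would not append a trailing zero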

pylops_mpi/basicoperators/VStack.py

Lines changed: 20 additions & 3 deletions

@@ -6,6 +6,7 @@
 from pylops import LinearOperator
 from pylops.utils import DTypeLike
 from pylops.utils.backend import get_module
+from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils

 from pylops_mpi import (
     MPILinearOperator,
@@ -15,6 +16,14 @@
     StackedDistributedArray
 )
 from pylops_mpi.utils.decorators import reshaped
+from pylops_mpi.DistributedArray import NcclCommunicatorType
+from pylops_mpi.utils import deps
+
+cupy_message = pylops_deps.cupy_import("the DistributedArray module")
+nccl_message = deps.nccl_import("the DistributedArray module")
+
+if nccl_message is None and cupy_message is None:
+    from pylops_mpi.utils._nccl import nccl_allreduce


 class MPIVStack(MPILinearOperator):
@@ -31,6 +40,8 @@ class MPIVStack(MPILinearOperator):
         One or more :class:`pylops.LinearOperator` to be vertically stacked.
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
         Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
+    base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
+        NCCL Communicator over which operators and arrays are distributed.
     dtype : :obj:`str`, optional
         Type of elements in input array.

@@ -99,8 +110,10 @@ class MPIVStack(MPILinearOperator):

     def __init__(self, ops: Sequence[LinearOperator],
                  base_comm: MPI.Comm = MPI.COMM_WORLD,
+                 base_comm_nccl: NcclCommunicatorType = None,
                  dtype: Optional[DTypeLike] = None):
         self.ops = ops
+        self.base_comm_nccl = base_comm_nccl
         nops = np.zeros(len(self.ops), dtype=np.int64)
         for iop, oper in enumerate(self.ops):
             nops[iop] = oper.shape[0]
@@ -121,7 +134,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
             raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
                              f"Got {x.partition} instead...")
-        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
+        # the output y should use NCCL if the operand x uses it
+        y = DistributedArray(global_shape=self.shape[0], base_comm_nccl=x.base_comm_nccl, local_shapes=self.local_shapes_n,
                              engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
@@ -132,13 +146,16 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
+        y = DistributedArray(global_shape=self.shape[1], base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
                              engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
         y1 = ncp.sum(ncp.vstack(y1), axis=0)
-        y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
+        if deps.nccl_enabled and self.base_comm_nccl:
+            y[:] = nccl_allreduce(self.base_comm_nccl, y1, op=MPI.SUM)
+        else:
+            y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
         return y
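
The adjoint now chooses the reduction backend at run time: NCCL when the operator carries an NCCL communicator (the partial sums are then CuPy arrays on the GPU), plain mpi4py allreduce otherwise. A standalone sketch of that dispatch, where the helper name reduce_partials and its arguments are illustrative rather than library API:

# Sketch of the reduction dispatch used in _rmatvec above (hypothetical helper)
from mpi4py import MPI

def reduce_partials(partial_sum, nccl_comm=None, mpi_comm=MPI.COMM_WORLD):
    """Sum per-rank partial adjoints across all ranks."""
    if nccl_comm is not None:
        # NCCL path: partial_sum is a CuPy array living on this rank's GPU
        from pylops_mpi.utils._nccl import nccl_allreduce
        return nccl_allreduce(nccl_comm, partial_sum, op=MPI.SUM)
    # MPI fallback: the behaviour prior to this commit
    return mpi_comm.allreduce(partial_sum, op=MPI.SUM)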

tests_nccl/test_stack_nccl.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
"""Test the stacking classes
    Designed to run with n GPUs (with 1 MPI process per GPU)
    $ mpiexec -n 10 pytest test_stack_nccl.py --with-mpi

    This file employs the same test sets as test_stack under NCCL environment
"""
import numpy as np
import cupy as cp
from numpy.testing import assert_allclose
from mpi4py import MPI
import pytest

import pylops
import pylops_mpi
from pylops_mpi.utils.dottest import dottest
from pylops_mpi.utils._nccl import initialize_nccl_comm

nccl_comm = initialize_nccl_comm()

# imag part is left to future complex-number support
par1 = {'ny': 101, 'nx': 101, 'imag': 0, 'dtype': np.float64}
par2 = {'ny': 301, 'nx': 101, 'imag': 0, 'dtype': np.float64}


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_vstack_nccl(par):
    """Test the MPIVStack operator with NCCL"""
    size = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx']))
    Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype']))
    VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm)

    # Broadcasted DistributedArray(global_shape == local_shape)
    x = pylops_mpi.DistributedArray(global_shape=par['nx'],
                                    base_comm_nccl=nccl_comm,
                                    partition=pylops_mpi.Partition.BROADCAST,
                                    dtype=par['dtype'],
                                    engine="cupy")
    x[:] = cp.ones(shape=par['nx'], dtype=par['dtype'])
    x_global = x.asarray()

    # Scattered DistributedArray
    y = pylops_mpi.DistributedArray(global_shape=size * par['ny'],
                                    base_comm_nccl=nccl_comm,
                                    partition=pylops_mpi.Partition.SCATTER,
                                    dtype=par['dtype'],
                                    engine="cupy")
    y[:] = cp.ones(shape=par['ny'], dtype=par['dtype'])
    y_global = y.asarray()

    # Forward
    x_mat = VStack_MPI @ x
    # Adjoint
    y_rmat = VStack_MPI.H @ y
    assert isinstance(x_mat, pylops_mpi.DistributedArray)
    assert isinstance(y_rmat, pylops_mpi.DistributedArray)
    # Dot test
    dottest(VStack_MPI, x, y, size * par['ny'], par['nx'])

    x_mat_mpi = x_mat.asarray()
    y_rmat_mpi = y_rmat.asarray()

    if rank == 0:
        A = A_gpu.get()
        ops = [pylops.MatrixMult(A=((i + 1) * A).astype(par['dtype'])) for i in range(size)]
        VStack = pylops.VStack(ops=ops)
        x_mat_np = VStack @ x_global.get()
        y_rmat_np = VStack.H @ y_global.get()
        assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14)
        assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14)


@pytest.mark.mpi(min_size=2)
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_stacked_vstack_nccl(par):
    """Test the MPIStackedVStack operator with NCCL"""
    size = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx']))
    Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype']))
    VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm)
    StackedVStack_MPI = pylops_mpi.MPIStackedVStack([VStack_MPI, VStack_MPI])

    # Broadcasted DistributedArray(global_shape == local_shape)
    x = pylops_mpi.DistributedArray(global_shape=par['nx'],
                                    base_comm_nccl=nccl_comm,
                                    partition=pylops_mpi.Partition.BROADCAST,
                                    dtype=par['dtype'],
                                    engine="cupy")
    x[:] = cp.ones(shape=par['nx'], dtype=par['dtype'])
    x_global = x.asarray()

    # Stacked DistributedArray
    dist1 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy")
    dist1[:] = cp.ones(dist1.local_shape, dtype=par['dtype'])
    dist2 = pylops_mpi.DistributedArray(global_shape=size * par['ny'], base_comm_nccl=nccl_comm, dtype=par['dtype'], engine="cupy")
    dist2[:] = cp.ones(dist1.local_shape, dtype=par['dtype'])
    y = pylops_mpi.StackedDistributedArray(distarrays=[dist1, dist2])
    y_global = y.asarray()

    x_mat = StackedVStack_MPI @ x
    y_rmat = StackedVStack_MPI.H @ y
    assert isinstance(x_mat, pylops_mpi.StackedDistributedArray)
    assert isinstance(y_rmat, pylops_mpi.DistributedArray)

    x_mat_mpi = x_mat.asarray()
    y_rmat_mpi = y_rmat.asarray()

    if rank == 0:
        A = A_gpu.get()
        ops = [pylops.MatrixMult(A=((i + 1) * A).astype(par['dtype'])) for i in range(size)]
        VStack = pylops.VStack(ops=ops)
        VStack_final = pylops.VStack(ops=[VStack, VStack])
        x_mat_np = VStack_final @ x_global.get()
        y_rmat_np = VStack_final.H @ y_global.get()
        assert_allclose(x_mat_mpi.get(), x_mat_np, rtol=1e-14)
        assert_allclose(y_rmat_mpi.get(), y_rmat_np, rtol=1e-14)


# TODO: Test of HStack
