Skip to content

Commit 3848408

Browse files
committed
remove base_comm_nccl from Op constructors and take base_comm_nccl from operand x instead
1 parent f80417d commit 3848408

File tree

4 files changed

+8
-14
lines changed

4 files changed

+8
-14
lines changed

pylops_mpi/basicoperators/HStack.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from pylops.utils import DTypeLike
66

77
from pylops_mpi import DistributedArray, MPILinearOperator
8-
from pylops_mpi.DistributedArray import NcclCommunicatorType
98
from .VStack import MPIVStack
109

1110

@@ -90,15 +89,14 @@ class MPIHStack(MPILinearOperator):
9089

9190
def __init__(self, ops: Sequence[LinearOperator],
9291
base_comm: MPI.Comm = MPI.COMM_WORLD,
93-
base_comm_nccl: NcclCommunicatorType = None,
9492
dtype: Optional[DTypeLike] = None):
9593
self.ops = ops
9694
nops = [oper.shape[0] for oper in self.ops]
9795
nops = np.concatenate(base_comm.allgather(nops), axis=0)
9896
if len(set(nops)) > 1:
9997
raise ValueError("Operators have different number of rows")
10098
hops = [oper.H for oper in self.ops]
101-
self.HStack = MPIVStack(ops=hops, base_comm=base_comm, base_comm_nccl=base_comm_nccl, dtype=dtype).H
99+
self.HStack = MPIVStack(ops=hops, base_comm=base_comm, dtype=dtype).H
102100
super().__init__(shape=self.HStack.shape, dtype=self.HStack.dtype, base_comm=base_comm)
103101

104102
def _matvec(self, x: DistributedArray) -> DistributedArray:

pylops_mpi/basicoperators/VStack.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
StackedDistributedArray
1717
)
1818
from pylops_mpi.utils.decorators import reshaped
19-
from pylops_mpi.DistributedArray import NcclCommunicatorType
2019
from pylops_mpi.utils import deps
2120

2221
cupy_message = pylops_deps.cupy_import("the VStack module")
@@ -40,8 +39,6 @@ class MPIVStack(MPILinearOperator):
4039
One or more :class:`pylops.LinearOperator` to be vertically stacked.
4140
base_comm : :obj:`mpi4py.MPI.Comm`, optional
4241
Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
43-
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
44-
NCCL Communicator over which operators and arrays are distributed.
4542
dtype : :obj:`str`, optional
4643
Type of elements in input array.
4744
@@ -110,10 +107,8 @@ class MPIVStack(MPILinearOperator):
110107

111108
def __init__(self, ops: Sequence[LinearOperator],
112109
base_comm: MPI.Comm = MPI.COMM_WORLD,
113-
base_comm_nccl: NcclCommunicatorType = None,
114110
dtype: Optional[DTypeLike] = None):
115111
self.ops = ops
116-
self.base_comm_nccl = base_comm_nccl
117112
nops = np.zeros(len(self.ops), dtype=np.int64)
118113
for iop, oper in enumerate(self.ops):
119114
nops[iop] = oper.shape[0]
@@ -152,8 +147,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
152147
for iop, oper in enumerate(self.ops):
153148
y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
154149
y1 = ncp.sum(ncp.vstack(y1), axis=0)
155-
if deps.nccl_enabled and self.base_comm_nccl:
156-
y[:] = nccl_allreduce(self.base_comm_nccl, y1, op=MPI.SUM)
150+
if deps.nccl_enabled and x.base_comm_nccl:
151+
y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM)
157152
else:
158153
y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
159154
return y

pylops_mpi/utils/decorators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def wrapper(self, x: DistributedArray):
5454
local_shapes = None
5555
global_shape = getattr(self, "dims")
5656
arr = DistributedArray(global_shape=global_shape,
57+
base_comm_nccl=x.base_comm_nccl,
5758
local_shapes=local_shapes, axis=0,
5859
engine=x.engine, dtype=x.dtype)
5960
arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape)))

tests_nccl/test_stack_nccl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_vstack_nccl(par):
3030
rank = MPI.COMM_WORLD.Get_rank()
3131
A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx']))
3232
Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype']))
33-
VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm)
33+
VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], )
3434

3535
# Broadcasted DistributedArray(global_shape == local_shape)
3636
x = pylops_mpi.DistributedArray(global_shape=par['nx'],
@@ -80,7 +80,7 @@ def test_stacked_vstack_nccl(par):
8080
rank = MPI.COMM_WORLD.Get_rank()
8181
A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx']))
8282
Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype']))
83-
VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], base_comm_nccl=nccl_comm)
83+
VStack_MPI = pylops_mpi.MPIVStack(ops=[Op, ], )
8484
StackedVStack_MPI = pylops_mpi.MPIStackedVStack([VStack_MPI, VStack_MPI])
8585

8686
# Broadcasted DistributedArray(global_shape == local_shape)
@@ -121,13 +121,13 @@ def test_stacked_vstack_nccl(par):
121121

122122
@pytest.mark.mpi(min_size=2)
123123
@pytest.mark.parametrize("par", [(par1), (par2)])
124-
def test_hstack(par):
124+
def test_hstack_nccl(par):
125125
"""Test the MPIHStack operator with NCCL"""
126126
size = MPI.COMM_WORLD.Get_size()
127127
rank = MPI.COMM_WORLD.Get_rank()
128128
A_gpu = cp.ones(shape=(par['ny'], par['nx'])) + par['imag'] * cp.ones(shape=(par['ny'], par['nx']))
129129
Op = pylops.MatrixMult(A=((rank + 1) * A_gpu).astype(par['dtype']))
130-
HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], base_comm_nccl=nccl_comm)
130+
HStack_MPI = pylops_mpi.MPIHStack(ops=[Op, ], )
131131

132132
# Scattered DistributedArray
133133
x = pylops_mpi.DistributedArray(global_shape=size * par['nx'],

0 commit comments

Comments
 (0)