Skip to content

Commit 693f078

Browse files
committed
feat: adapted all comm calls in VStack to new method signatures
1 parent f362436 commit 693f078

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

pylops_mpi/basicoperators/VStack.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pylops import LinearOperator
77
from pylops.utils import DTypeLike
88
from pylops.utils.backend import get_module
9-
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
109

1110
from pylops_mpi import (
1211
MPILinearOperator,
@@ -17,13 +16,6 @@
1716
)
1817
from pylops_mpi.Distributed import DistributedMixIn
1918
from pylops_mpi.utils.decorators import reshaped
20-
from pylops_mpi.utils import deps
21-
22-
cupy_message = pylops_deps.cupy_import("the VStack module")
23-
nccl_message = deps.nccl_import("the VStack module")
24-
25-
if nccl_message is None and cupy_message is None:
26-
from pylops_mpi.utils._nccl import nccl_allreduce
2719

2820

2921
class MPIVStack(DistributedMixIn, MPILinearOperator):
@@ -142,19 +134,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
142134
@reshaped(forward=False, stacking=True)
143135
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
144136
ncp = get_module(x.engine)
145-
# TODO: consider adding base_comm, base_comm_nccl, engine to the
146-
# input parameters of _allreduce instead of relying on self
147-
self.base_comm, self.base_comm_nccl, self.engine = \
148-
x.base_comm, x.base_comm_nccl, x.engine
149-
y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm,
137+
y = DistributedArray(global_shape=self.shape[1],
138+
base_comm=x.base_comm,
150139
base_comm_nccl=x.base_comm_nccl,
151140
partition=Partition.BROADCAST,
152-
engine=x.engine, dtype=self.dtype)
141+
engine=x.engine,
142+
dtype=self.dtype)
153143
y1 = []
154144
for iop, oper in enumerate(self.ops):
155145
y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
156146
y1 = ncp.sum(ncp.vstack(y1), axis=0)
157-
y[:] = self._allreduce(y1, op=MPI.SUM)
147+
y[:] = self._allreduce(x.base_comm, x.base_comm_nccl,
148+
y1, op=MPI.SUM, engine=x.engine)
158149
return y
159150

160151

0 commit comments

Comments
 (0)