|
6 | 6 | from pylops import LinearOperator |
7 | 7 | from pylops.utils import DTypeLike |
8 | 8 | from pylops.utils.backend import get_module |
9 | | -from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils |
10 | 9 |
|
11 | 10 | from pylops_mpi import ( |
12 | 11 | MPILinearOperator, |
|
17 | 16 | ) |
18 | 17 | from pylops_mpi.Distributed import DistributedMixIn |
19 | 18 | from pylops_mpi.utils.decorators import reshaped |
20 | | -from pylops_mpi.utils import deps |
21 | | - |
22 | | -cupy_message = pylops_deps.cupy_import("the VStack module") |
23 | | -nccl_message = deps.nccl_import("the VStack module") |
24 | | - |
25 | | -if nccl_message is None and cupy_message is None: |
26 | | - from pylops_mpi.utils._nccl import nccl_allreduce |
27 | 19 |
|
28 | 20 |
|
29 | 21 | class MPIVStack(DistributedMixIn, MPILinearOperator): |
@@ -142,19 +134,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: |
142 | 134 | @reshaped(forward=False, stacking=True) |
143 | 135 | def _rmatvec(self, x: DistributedArray) -> DistributedArray: |
144 | 136 | ncp = get_module(x.engine) |
145 | | - # TODO: consider adding base_comm, base_comm_nccl, engine to the |
146 | | - # input parameters of _allreduce instead of relying on self |
147 | | - self.base_comm, self.base_comm_nccl, self.engine = \ |
148 | | - x.base_comm, x.base_comm_nccl, x.engine |
149 | | - y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, |
| 137 | + y = DistributedArray(global_shape=self.shape[1], |
| 138 | + base_comm=x.base_comm, |
150 | 139 | base_comm_nccl=x.base_comm_nccl, |
151 | 140 | partition=Partition.BROADCAST, |
152 | | - engine=x.engine, dtype=self.dtype) |
| 141 | + engine=x.engine, |
| 142 | + dtype=self.dtype) |
153 | 143 | y1 = [] |
154 | 144 | for iop, oper in enumerate(self.ops): |
155 | 145 | y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) |
156 | 146 | y1 = ncp.sum(ncp.vstack(y1), axis=0) |
157 | | - y[:] = self._allreduce(y1, op=MPI.SUM) |
| 147 | + y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, |
| 148 | + y1, op=MPI.SUM, engine=x.engine) |
158 | 149 | return y |
159 | 150 |
|
160 | 151 |
|
|
0 commit comments