Skip to content

Commit 0c7136f

Browse files
committed
Fix Fredholm bug, Pass all 306 tests in NumPy+MPI and CuPy+NCCL
1 parent 3609220 commit 0c7136f

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
130130
# gather results
131131
# TODO: _allgather is supposed to be private to DistributedArray
132132
# but so far, we do not take base_comm_nccl as an argument to Op.
133-
# For consistency, y._allgather has to be call here.
134-
# we can do if else for x.base_comm_nccl, but that means
133+
# For consistency, y._allgather has to be called here.
134+
# Alternatively, we can also do if-else checking x.base_comm_nccl, but that means
135135
# we have to call a function from _nccl.py
136-
# y[:] = np.vstack(y._allgather(y1)).ravel()
137-
recv = y._allgather(y1)
138-
y[:] = recv.ravel()
136+
y[:] = ncp.vstack(y._allgather(y1)).ravel()
139137
return y
140138

141139
def _rmatvec(self, x: NDArray) -> NDArray:
@@ -172,11 +170,15 @@ def _rmatvec(self, x: NDArray) -> NDArray:
172170
y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()
173171

174172
# gather results
175-
recv = y._allgather(y1)
176-
if self.usematmul:
177-
# unrolling like DistributedArray asarray()
173+
recv = y._allgather(y1)
174+
# TODO: the current implementation of _allgather calls the non-buffered MPI-AllGather (sub-optimal for CuPy+MPI)
175+
# which returns a list (not flattened) and does not require unrolling
176+
if self.usematmul and isinstance(recv, ncp.ndarray) :
177+
# unrolling
178178
chunk_size = self.ny * self.nz
179-
recv = ncp.vstack([recv[i*chunk_size: (i+1)*chunk_size].reshape(self.nz, self.ny).T for i in range((len(recv)+chunk_size-1)//chunk_size)])
180-
179+
num_partition = (len(recv)+chunk_size-1)//chunk_size
180+
recv = ncp.vstack([recv[i*chunk_size: (i+1)*chunk_size].reshape(self.nz, self.ny).T for i in range(num_partition)])
181+
else:
182+
recv = ncp.vstack(recv)
181183
y[:] = recv.ravel()
182184
return y

0 commit comments

Comments
 (0)