Skip to content

Commit 33121a5

Browse files
committed
Temporarily use a CPU buffer for CuPy + MPI in the inf and -inf norms
1 parent 8d458b0 commit 33121a5

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -694,14 +694,25 @@ def _compute_vector_norm(self, local_array: NDArray,
694694
recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
695695
elif ord == ncp.inf:
696696
# Calculate max followed by max reduction
697-
recv_buf = self._allreduce_subcomm(ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64),
698-
recv_buf, op=MPI.MAX)
699-
recv_buf = ncp.squeeze(recv_buf, axis=axis)
697+
# TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
698+
# with the MAX and MIN operators. Here we copy the buffers to the CPU, run the reduction there, and copy the result back to the GPU
699+
send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
700+
if self.engine=="cupy" and self.base_comm_nccl is None:
701+
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
702+
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
703+
else:
704+
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
705+
recv_buf = ncp.squeeze(recv_buf, axis=axis)
700706
elif ord == -ncp.inf:
701707
# Calculate min followed by min reduction
702-
recv_buf = self._allreduce_subcomm(ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64),
703-
recv_buf, op=MPI.MIN)
704-
recv_buf = ncp.squeeze(recv_buf, axis=axis)
708+
# TODO (tharitt): see the comment in the infinity-norm branch above
709+
send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
710+
if self.engine == "cupy" and self.base_comm_nccl is None:
711+
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
712+
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
713+
else:
714+
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
715+
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
705716

706717
else:
707718
recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))

tests/test_distributedarray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,8 @@ def test_distributed_norm(par):
205205
np.linalg.norm(par['x'], ord=1, axis=par['axis']), rtol=1e-14)
206206

207207
# TODO (tharitt): FAIL with CuPy + MPI for inf norm
208-
# assert_allclose(arr.norm(ord=np.inf, axis=par['axis']),
209-
# np.linalg.norm(par['x'], ord=np.inf, axis=par['axis']), rtol=1e-14)
208+
assert_allclose(arr.norm(ord=np.inf, axis=par['axis']),
209+
np.linalg.norm(par['x'], ord=np.inf, axis=par['axis']), rtol=1e-14)
210210
assert_allclose(arr.norm(), np.linalg.norm(par['x'].flatten()), rtol=1e-13)
211211

212212

@@ -335,7 +335,7 @@ def test_distributed_maskednorm(par):
335335
np.linalg.norm(par['x'], ord=1, axis=par['axis']) / nsub, rtol=1e-14)
336336

337337
# TODO (tharitt): Fail with CuPy + MPI
338-
# assert_allclose(arr.norm(ord=np.inf, axis=par['axis']),
339-
# np.linalg.norm(par['x'], ord=np.inf, axis=par['axis']), rtol=1e-14)
338+
assert_allclose(arr.norm(ord=np.inf, axis=par['axis']),
339+
np.linalg.norm(par['x'], ord=np.inf, axis=par['axis']), rtol=1e-14)
340340
assert_allclose(arr.norm(ord=2, axis=par['axis']),
341341
np.linalg.norm(par['x'], ord=2, axis=par['axis']) / np.sqrt(nsub), rtol=1e-13)

0 commit comments

Comments
 (0)