Skip to content

Commit d8e9630

Browse files
committed
Add ncp.ascontiguousarray in allgather
1 parent 6c04af1 commit d8e9630

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

pylops_mpi/utils/_mpi.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"mpi_recv",
77
]
88

9-
from typing import Optional
9+
from typing import List, Optional
1010

1111
from mpi4py import MPI
1212
from pylops.utils import NDArray
@@ -18,7 +18,7 @@ def mpi_allgather(base_comm: MPI.Comm,
1818
send_buf: NDArray,
1919
recv_buf: Optional[NDArray] = None,
2020
engine: str = "numpy",
21-
) -> NDArray:
21+
) -> List[NDArray]:
2222
"""MPI_Allgather/allgather
2323
2424
Dispatch allgather routine based on type of input and availability of
@@ -38,8 +38,8 @@ def mpi_allgather(base_comm: MPI.Comm,
3838
3939
Returns
4040
-------
41-
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
42-
A buffer containing the gathered data from all ranks.
41+
recv_buf : :obj:`list`
42+
A list of arrays containing the gathered data from all ranks.
4343
4444
"""
4545
if deps.cuda_aware_mpi_enabled or engine == "numpy":
@@ -48,17 +48,19 @@ def mpi_allgather(base_comm: MPI.Comm,
4848
recvcounts = base_comm.allgather(send_buf.size)
4949
recv_buf = recv_buf if recv_buf else ncp.zeros(sum(recvcounts), dtype=send_buf.dtype)
5050
if len(set(send_shapes)) == 1:
51-
_mpi_calls(base_comm, "Allgather", send_buf.copy(), recv_buf, engine=engine)
51+
_mpi_calls(base_comm, "Allgather", ncp.ascontiguousarray(send_buf), recv_buf, engine=engine)
5252
return [chunk.reshape(send_shapes[0]) for chunk in ncp.split(recv_buf, base_comm.size)]
53-
displs = [0]
54-
for i in range(1, len(recvcounts)):
55-
displs.append(displs[i - 1] + recvcounts[i - 1])
56-
_mpi_calls(base_comm, "Allgatherv", send_buf.copy(),
57-
[recv_buf, recvcounts, displs, MPI._typedict[send_buf.dtype.char]], engine=engine)
58-
return [
59-
recv_buf[displs[i]:displs[i] + recvcounts[i]].reshape(send_shapes[i])
60-
for i in range(base_comm.size)
61-
]
53+
else:
54+
# displs represent the starting offsets in recv_buf where data from each rank will be placed
55+
displs = [0]
56+
for i in range(1, len(recvcounts)):
57+
displs.append(displs[i - 1] + recvcounts[i - 1])
58+
_mpi_calls(base_comm, "Allgatherv", ncp.ascontiguousarray(send_buf),
59+
[recv_buf, recvcounts, displs, MPI._typedict[send_buf.dtype.char]], engine=engine)
60+
return [
61+
recv_buf[displs[i]:displs[i] + recvcounts[i]].reshape(send_shapes[i])
62+
for i in range(base_comm.size)
63+
]
6264
else:
6365
# CuPy with non-CUDA-aware MPI
6466
if recv_buf is None:

0 commit comments

Comments
 (0)