66 "mpi_recv" ,
77]
88
9- from typing import Optional
9+ from typing import List , Optional
1010
1111from mpi4py import MPI
1212from pylops .utils import NDArray
@@ -18,7 +18,7 @@ def mpi_allgather(base_comm: MPI.Comm,
1818 send_buf : NDArray ,
1919 recv_buf : Optional [NDArray ] = None ,
2020 engine : str = "numpy" ,
21- ) -> NDArray :
21+ ) -> List [ NDArray ] :
2222 """MPI_Allgather/allgather
2323
2424 Dispatch allgather routine based on type of input and availability of
@@ -38,8 +38,8 @@ def mpi_allgather(base_comm: MPI.Comm,
3838
3939 Returns
4040 -------
41- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray `
42- A buffer containing the gathered data from all ranks.
41+ recv_buf : :obj:`list `
42+ A list of arrays containing the gathered data from all ranks.
4343
4444 """
4545 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
@@ -48,17 +48,19 @@ def mpi_allgather(base_comm: MPI.Comm,
4848 recvcounts = base_comm .allgather (send_buf .size )
4949 recv_buf = recv_buf if recv_buf else ncp .zeros (sum (recvcounts ), dtype = send_buf .dtype )
5050 if len (set (send_shapes )) == 1 :
51- _mpi_calls (base_comm , "Allgather" , send_buf . copy ( ), recv_buf , engine = engine )
51+ _mpi_calls (base_comm , "Allgather" , ncp . ascontiguousarray ( send_buf ), recv_buf , engine = engine )
5252 return [chunk .reshape (send_shapes [0 ]) for chunk in ncp .split (recv_buf , base_comm .size )]
53- displs = [0 ]
54- for i in range (1 , len (recvcounts )):
55- displs .append (displs [i - 1 ] + recvcounts [i - 1 ])
56- _mpi_calls (base_comm , "Allgatherv" , send_buf .copy (),
57- [recv_buf , recvcounts , displs , MPI ._typedict [send_buf .dtype .char ]], engine = engine )
58- return [
59- recv_buf [displs [i ]:displs [i ] + recvcounts [i ]].reshape (send_shapes [i ])
60- for i in range (base_comm .size )
61- ]
53+ else :
54+ # displs represent the starting offsets in recv_buf where data from each rank will be placed
55+ displs = [0 ]
56+ for i in range (1 , len (recvcounts )):
57+ displs .append (displs [i - 1 ] + recvcounts [i - 1 ])
58+ _mpi_calls (base_comm , "Allgatherv" , ncp .ascontiguousarray (send_buf ),
59+ [recv_buf , recvcounts , displs , MPI ._typedict [send_buf .dtype .char ]], engine = engine )
60+ return [
61+ recv_buf [displs [i ]:displs [i ] + recvcounts [i ]].reshape (send_shapes [i ])
62+ for i in range (base_comm .size )
63+ ]
6264 else :
6365 # CuPy with non-CUDA-aware MPI
6466 if recv_buf is None :
0 commit comments