22 "mpi_allgather" ,
33 "mpi_allreduce" ,
44 "mpi_bcast" ,
5- # "mpi_asarray",
65 "mpi_send" ,
76 "mpi_recv" ,
87]
98
10- from typing import Optional
9+ from typing import Optional , Union
1110
1211import numpy as np
1312from mpi4py import MPI
13+ from pylops .utils import NDArray
1414from pylops .utils .backend import get_module
1515from pylops_mpi .utils import deps
1616from pylops_mpi .utils ._common import _prepare_allgather_inputs , _unroll_allgather_recv
1717
1818
1919def mpi_allgather (base_comm : MPI .Comm ,
20- send_buf , recv_buf = None ,
21- engine : Optional [str ] = "numpy" ) -> np .ndarray :
20+ send_buf : NDArray ,
21+ recv_buf : Optional [NDArray ] = None ,
22+ engine : str = "numpy" ,
23+ ) -> NDArray :
24+ """MPI_Allallgather/allallgather
2225
26+ Dispatch allgather routine based on type of input and availability of
27+ CUDA-Aware MPI
28+
29+ Parameters
30+ ----------
31+ base_comm : :obj:`MPI.Comm`
32+ Base MPI Communicator.
33+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
34+ The data buffer from the local rank to be gathered.
35+ recv_buf : :obj:`cupy.ndarray`, optional
36+ The buffer to store the result of the gathering. If None,
37+ a new buffer will be allocated with the appropriate shape.
38+ engine : :obj:`str`, optional
39+ Engine used to store array (``numpy`` or ``cupy``)
40+
41+ Returns
42+ -------
43+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
44+ A buffer containing the gathered data from all ranks.
45+
46+ """
2347 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
2448 send_shapes = base_comm .allgather (send_buf .shape )
2549 (padded_send , padded_recv ) = _prepare_allgather_inputs (send_buf , send_shapes , engine = engine )
2650 recv_buffer_to_use = recv_buf if recv_buf else padded_recv
2751 base_comm .Allgather (padded_send , recv_buffer_to_use )
2852 return _unroll_allgather_recv (recv_buffer_to_use , padded_send .shape , send_shapes )
29-
3053 else :
3154 # CuPy with non-CUDA-aware MPI
3255 if recv_buf is None :
@@ -36,9 +59,11 @@ def mpi_allgather(base_comm: MPI.Comm,
3659
3760
3861def mpi_allreduce (base_comm : MPI .Comm ,
39- send_buf , recv_buf = None ,
40- engine : Optional [str ] = "numpy" ,
41- op : MPI .Op = MPI .SUM ) -> np .ndarray :
62+ send_buf : NDArray ,
63+ recv_buf : Optional [NDArray ] = None ,
64+ engine : str = "numpy" ,
65+ op : MPI .Op = MPI .SUM ,
66+ ) -> NDArray :
4267 """MPI_Allreduce/allreduce
4368
4469 Dispatch allreduce routine based on type of input and availability of
@@ -49,7 +74,7 @@ def mpi_allreduce(base_comm: MPI.Comm,
4974 base_comm : :obj:`MPI.Comm`
5075 Base MPI Communicator.
5176 send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
52- The data buffer from the local GPU to be reduced.
77+ The data buffer from the local rank to be reduced.
5378 recv_buf : :obj:`cupy.ndarray`, optional
5479 The buffer to store the result of the reduction. If None,
5580 a new buffer will be allocated with the appropriate shape.
@@ -62,7 +87,7 @@ def mpi_allreduce(base_comm: MPI.Comm,
6287 -------
6388 recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
6489 A buffer containing the result of the reduction, broadcasted
65- to all GPUs .
90+ to all ranks .
6691
6792 """
6893 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
@@ -80,8 +105,34 @@ def mpi_allreduce(base_comm: MPI.Comm,
80105
81106
82107def mpi_bcast (base_comm : MPI .Comm ,
83- rank , local_array , index , value ,
84- engine : Optional [str ] = "numpy" ) -> np .ndarray :
108+ rank : int ,
109+ local_array : NDArray ,
110+ index : int ,
111+ value : Union [int , NDArray ],
112+ engine : Optional [str ] = "numpy" ,
113+ ) -> None :
114+ """MPI_Bcast/bcast
115+
116+ Dispatch bcast routine based on type of input and availability of
117+ CUDA-Aware MPI
118+
119+ Parameters
120+ ----------
121+ base_comm : :obj:`MPI.Comm`
122+ Base MPI Communicator.
123+ rank : :obj:`int`
124+ Rank.
125+ local_array : :obj:`numpy.ndarray`
126+ Localy array to be broadcasted.
127+ index : :obj:`int` or :obj:`slice`
128+ Represents the index positions where a value needs to be assigned.
129+ value : :obj:`int` or :obj:`numpy.ndarray`
130+ Represents the value that will be assigned to the local array at
131+ the specified index positions.
132+ engine : :obj:`str`, optional
133+ Engine used to store array (``numpy`` or ``cupy``)
134+
135+ """
85136 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
86137 if rank == 0 :
87138 local_array [index ] = value
@@ -92,8 +143,11 @@ def mpi_bcast(base_comm: MPI.Comm,
92143
93144
94145def mpi_send (base_comm : MPI .Comm ,
95- send_buf , dest , count , tag = 0 ,
96- engine : Optional [str ] = "numpy" ,
146+ send_buf : NDArray ,
147+ dest : int ,
148+ count : Optional [int ] = None ,
149+ tag : int = 0 ,
150+ engine : str = "numpy" ,
97151 ) -> None :
98152 """MPI_Send/send
99153
@@ -114,6 +168,7 @@ def mpi_send(base_comm: MPI.Comm,
114168 Tag of the message to be sent.
115169 engine : :obj:`str`, optional
116170 Engine used to store array (``numpy`` or ``cupy``)
171+
117172 """
118173 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
119174 # Determine MPI type based on array dtype
@@ -127,8 +182,12 @@ def mpi_send(base_comm: MPI.Comm,
127182
128183
129184def mpi_recv (base_comm : MPI .Comm ,
130- recv_buf = None , source = 0 , count = None , tag = 0 ,
131- engine : Optional [str ] = "numpy" ) -> np .ndarray :
185+ recv_buf : Optional [NDArray ] = None ,
186+ source : int = 0 ,
187+ count : Optional [int ] = None ,
188+ tag : int = 0 ,
189+ engine : Optional [str ] = "numpy" ,
190+ ) -> NDArray :
132191 """ MPI_Recv/recv
133192 Dispatch receive routine based on type of input and availability of
134193 CUDA-Aware MPI
@@ -147,6 +206,12 @@ def mpi_recv(base_comm: MPI.Comm,
147206 Tag of the message to be sent.
148207 engine : :obj:`str`, optional
149208 Engine used to store array (``numpy`` or ``cupy``)
209+
210+ Returns
211+ -------
212+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
213+ The buffer containing the received data.
214+
150215 """
151216 if deps .cuda_aware_mpi_enabled or engine == "numpy" :
152217 ncp = get_module (engine )
0 commit comments