99 "_unroll_allgather_recv"
1010]
1111
12- from typing import Optional , Tuple
12+ from typing import Optional
1313
1414import numpy as np
1515from mpi4py import MPI
1616from pylops .utils .backend import get_module
1717from pylops_mpi .utils import deps
1818
19+
1920# TODO: return type annotation for both cupy and numpy
2021def _prepare_allgather_inputs (send_buf , send_buf_shapes , engine ):
2122 r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
@@ -33,7 +34,7 @@ def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
        The data buffer from the local GPU to be sent for allgather.
    send_buf_shapes : :obj:`list`
        A list of shapes for each GPU send_buf (used to calculate padding size)
-    engine : :obj:`str`
+    engine : :obj:`str`
        Engine used to store array (``numpy`` or ``cupy``)

    Returns
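Note: for context, the pad/unroll pattern that these two helpers implement can be sketched with NumPy alone. This is a minimal illustration, not the library's code; the helper names pad_to and unroll below are hypothetical and not part of pylops_mpi.

    import numpy as np

    def pad_to(buf, padded_size):
        # Flatten and zero-pad a rank's buffer to a common size so that a
        # fixed-size allgather works even when ranks hold unequal chunks.
        flat = buf.ravel()
        out = np.zeros(padded_size, dtype=buf.dtype)
        out[:flat.size] = flat
        return out

    def unroll(recv_flat, padded_size, shapes):
        # Split the concatenated receive buffer per rank, drop the padding,
        # and restore each chunk to its original shape.
        chunks = []
        for i, shape in enumerate(shapes):
            n = int(np.prod(shape))
            start = i * padded_size
            chunks.append(recv_flat[start:start + n].reshape(shape))
        return chunks

    shapes = [(3,), (2,), (4,)]                    # per-rank send_buf shapes
    padded = max(int(np.prod(s)) for s in shapes)  # common padded size
    bufs = [np.arange(int(np.prod(s))) + 10 * r for r, s in enumerate(shapes)]
    recv = np.concatenate([pad_to(b, padded) for b in bufs])  # emulated allgather
    print([c.tolist() for c in unroll(recv, padded, shapes)])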
@@ -96,20 +97,21 @@ def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) ->

    return chunks

+
def mpi_allreduce(base_comm: MPI.Comm,
-                  send_buf, recv_buf=None,
+                  send_buf, recv_buf=None,
                  engine: Optional[str] = "numpy",
                  op: MPI.Op = MPI.SUM) -> np.ndarray:
-    """MPI_Allreduce/allreduce
-
-    Dispatch allreduce routine based on type of input and availability of
+    """MPI_Allreduce/allreduce
+
+    Dispatch allreduce routine based on type of input and availability of
    CUDA-Aware MPI

    Parameters
    ----------
    base_comm : :obj:`MPI.Comm`
        Base MPI Communicator.
-    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        The data buffer from the local GPU to be reduced.
    recv_buf : :obj:`cupy.ndarray`, optional
        The buffer to store the result of the reduction. If None,
@@ -121,10 +123,10 @@ def mpi_allreduce(base_comm: MPI.Comm,

    Returns
    -------
-    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        A buffer containing the result of the reduction, broadcasted
        to all GPUs.
-
+
    """
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        ncp = get_module(engine)
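Note: a minimal mpi4py-only sketch of the two paths this helper dispatches between, the buffered MPI_Allreduce and the object-based allreduce fallback. It is an illustration under standard mpi4py semantics, not the code added in this commit.

    # Run with e.g. `mpirun -n 2 python allreduce_demo.py`
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    send_buf = np.full(4, comm.Get_rank() + 1, dtype=np.float64)

    # Fast path (NumPy arrays, or CuPy with CUDA-aware MPI): buffer-based
    # MPI_Allreduce writing into a preallocated recv_buf.
    recv_buf = np.zeros_like(send_buf)
    comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)

    # Fallback path (CuPy without CUDA-aware MPI): object-based allreduce,
    # which pickles the array and returns a new one.
    recv_obj = comm.allreduce(send_buf, op=MPI.SUM)

    assert np.allclose(recv_buf, recv_obj)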
@@ -141,9 +143,8 @@ def mpi_allreduce(base_comm: MPI.Comm,


def mpi_allgather(base_comm: MPI.Comm,
-                  send_buf, recv_buf=None,
-                  engine: Optional[str] = "numpy",
-                  ) -> np.ndarray:
+                  send_buf, recv_buf=None,
+                  engine: Optional[str] = "numpy") -> np.ndarray:

    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        send_shapes = base_comm.allgather(send_buf.shape)
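Note: the allgather dispatch follows the same pattern; a sketch with equal-sized buffers so the padding step above is not needed (unequal shapes would go through the pad/unroll helpers sketched earlier). Again an illustration, not the committed code.

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    send_buf = np.full(3, comm.Get_rank(), dtype=np.float64)

    # Buffered path (NumPy, or CuPy with CUDA-aware MPI): fixed-size MPI_Allgather.
    recv_buf = np.zeros(3 * comm.Get_size(), dtype=send_buf.dtype)
    comm.Allgather(send_buf, recv_buf)

    # Fallback path (CuPy without CUDA-aware MPI): object-based allgather,
    # which returns a list of per-rank arrays.
    chunks = comm.allgather(send_buf)

    assert np.allclose(recv_buf, np.concatenate(chunks))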
@@ -165,15 +166,15 @@ def mpi_send(base_comm: MPI.Comm,
             engine: Optional[str] = "numpy",
             ) -> None:
    """MPI_Send/send
-
-    Dispatch send routine based on type of input and availability of
+
+    Dispatch send routine based on type of input and availability of
    CUDA-Aware MPI

    Parameters
    ----------
    base_comm : :obj:`MPI.Comm`
        Base MPI Communicator.
-    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        The array containing data to send.
    dest : :obj:`int`
        The rank of the destination CPU/GPU device.
@@ -183,7 +184,6 @@ def mpi_send(base_comm: MPI.Comm,
        Tag of the message to be sent.
    engine : :obj:`str`, optional
        Engine used to store array (``numpy`` or ``cupy``)
-
    """
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        # Determine MPI type based on array dtype
@@ -195,11 +195,12 @@ def mpi_send(base_comm: MPI.Comm,
        # Uses CuPy without CUDA-aware MPI
        base_comm.send(send_buf, dest, tag)

+
def mpi_recv(base_comm: MPI.Comm,
-             recv_buf=None, source=0, count=None, tag=0,
-             engine: Optional[str] = "numpy") -> np.ndarray:
+             recv_buf=None, source=0, count=None, tag=0,
+             engine: Optional[str] = "numpy") -> np.ndarray:
    """MPI_Recv/recv
-    Dispatch receive routine based on type of input and availability of
+    Dispatch receive routine based on type of input and availability of
    CUDA-Aware MPI

    Parameters
@@ -216,7 +217,6 @@ def mpi_recv(base_comm: MPI.Comm,
        Tag of the message to be sent.
    engine : :obj:`str`, optional
        Engine used to store array (``numpy`` or ``cupy``)
-
    """
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        ncp = get_module(engine)
@@ -233,4 +233,3 @@ def mpi_recv(base_comm: MPI.Comm,
        # Uses CuPy without CUDA-aware MPI
        recv_buf = base_comm.recv(source=source, tag=tag)
    return recv_buf
-
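Note: finally, a sketch of the point-to-point dispatch mirrored by mpi_send/mpi_recv, using only plain mpi4py calls; the buffered Send/Recv path needs the count/dtype known on the receiver, while the object-based fallback carries them with the pickle.

    # Run with `mpirun -n 2 python sendrecv_demo.py`
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    if rank == 0:
        data = np.arange(5, dtype=np.float64)
        # Buffered path (NumPy, or CuPy with CUDA-aware MPI): MPI_Send on the raw buffer.
        comm.Send(data, dest=1, tag=0)
        # Fallback path (CuPy without CUDA-aware MPI): pickle-based send.
        comm.send(data, dest=1, tag=1)
    elif rank == 1:
        recv_buf = np.zeros(5, dtype=np.float64)   # count/dtype preallocated up front
        comm.Recv(recv_buf, source=0, tag=0)
        obj = comm.recv(source=0, tag=1)           # shape/dtype come with the pickle
        print(recv_buf, obj)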