+from typing import Any, NewType, Optional, Union
+
 from mpi4py import MPI
+from pylops.utils import NDArray
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
 from pylops_mpi.utils import deps
     from pylops_mpi.utils._nccl import (
         nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv
     )
+    from cupy.cuda.nccl import NcclCommunicator
+else:
+    NcclCommunicator = Any
+
+NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)
 
 
 class DistributedMixIn:
@@ -23,10 +31,14 @@ class DistributedMixIn:
     MPI installation is not available).
 
     """
-    def _allreduce(self, base_comm, base_comm_nccl,
-                   send_buf, recv_buf=None,
+    def _allreduce(self,
+                   base_comm: MPI.Comm,
+                   base_comm_nccl: NcclCommunicatorType,
+                   send_buf: NDArray,
+                   recv_buf: Optional[NDArray] = None,
                    op: MPI.Op = MPI.SUM,
-                   engine="numpy"):
+                   engine: str = "numpy",
+                   ) -> NDArray:
         """Allreduce operation
 
         Parameters
@@ -58,10 +70,14 @@ def _allreduce(self, base_comm, base_comm_nccl,
             return mpi_allreduce(base_comm, send_buf,
                                  recv_buf, engine, op)
 
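With the communicators passed explicitly, the mixin methods can be exercised outside DistributedArray. A minimal caller sketch, assuming DistributedMixIn is importable from the pylops_mpi package (the import path is not shown in this diff) and taking the MPI fallback path (base_comm_nccl=None):

# Hypothetical usage sketch, not part of the diff; import location is an assumption.
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedMixIn  # assumed export location

class _Demo(DistributedMixIn):
    pass

demo = _Demo()
local = np.array([float(MPI.COMM_WORLD.Get_rank())])
# base_comm_nccl=None and engine="numpy" select the mpi_allreduce branch
total = demo._allreduce(MPI.COMM_WORLD, None, local, op=MPI.SUM, engine="numpy")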
-    def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
-                           send_buf, recv_buf=None,
+    def _allreduce_subcomm(self,
+                           sub_comm: MPI.Comm,
+                           base_comm_nccl: NcclCommunicatorType,
+                           send_buf: NDArray,
+                           recv_buf: Optional[NDArray] = None,
                            op: MPI.Op = MPI.SUM,
-                           engine="numpy"):
+                           engine: str = "numpy",
+                           ) -> NDArray:
         """Allreduce operation with subcommunicator
 
         Parameters
@@ -93,15 +109,19 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
             return mpi_allreduce(sub_comm, send_buf,
                                  recv_buf, engine, op)
 
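The subcommunicator variant takes the same arguments but reduces only within sub_comm. A sketch reusing the demo object above, with a plain MPI split standing in for the subcommunicator (an assumption, not part of this diff):

# Hypothetical: reduce within ranks of the same color only (MPI path).
sub_comm = MPI.COMM_WORLD.Split(color=MPI.COMM_WORLD.Get_rank() % 2, key=0)
partial = demo._allreduce_subcomm(sub_comm, None, local, op=MPI.SUM, engine="numpy")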
-    def _allgather(self, base_comm, base_comm_nccl,
-                   send_buf, recv_buf=None,
-                   engine="numpy"):
+    def _allgather(self,
+                   base_comm: MPI.Comm,
+                   base_comm_nccl: NcclCommunicatorType,
+                   send_buf: NDArray,
+                   recv_buf: Optional[NDArray] = None,
+                   engine: str = "numpy",
+                   ) -> NDArray:
         """Allgather operation
 
         Parameters
         ----------
-        sub_comm : :obj:`MPI.Comm`
-            MPI Subcommunicator.
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
         base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
             NCCL Communicator.
         send_buf : :obj:`numpy.ndarray` or `cupy.ndarray`
@@ -131,41 +151,119 @@ def _allgather(self, base_comm, base_comm_nccl,
                 return base_comm.allgather(send_buf)
             return mpi_allgather(base_comm, send_buf, recv_buf, engine)
 
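A common use of _allgather is collecting small Python objects such as per-rank shapes; the MPI branch above appears to forward tuples and lists straight to base_comm.allgather. A sketch under the same assumptions as the earlier demo object:

# Hypothetical: every rank contributes its local shape; all ranks get the full list.
local_shape = (MPI.COMM_WORLD.Get_rank() + 1, 3)
shapes = demo._allgather(MPI.COMM_WORLD, None, local_shape)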
-    def _allgather_subcomm(self, send_buf, recv_buf=None):
+    def _allgather_subcomm(self,
+                           sub_comm: MPI.Comm,
+                           base_comm_nccl: NcclCommunicatorType,
+                           send_buf: NDArray,
+                           recv_buf: Optional[NDArray] = None,
+                           engine: str = "numpy",
+                           ) -> NDArray:
         """Allgather operation with subcommunicator
+
+        Parameters
+        ----------
+        sub_comm : :obj:`MPI.Comm`
+            MPI Subcommunicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf : :obj:`numpy.ndarray` or `cupy.ndarray`
+            A buffer containing the data to be sent by this rank.
+        recv_buf : :obj:`numpy.ndarray` or `cupy.ndarray`, optional
+            The buffer to store the result of the gathering. If None,
+            a new buffer will be allocated with the appropriate shape.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``)
+
+        Returns
+        -------
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the gathered data from all ranks.
+
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and base_comm_nccl is not None:
             if isinstance(send_buf, (tuple, list, int)):
-                return nccl_allgather(self.sub_comm, send_buf, recv_buf)
+                return nccl_allgather(sub_comm, send_buf, recv_buf)
             else:
-                send_shapes = self._allgather_subcomm(send_buf.shape)
+                send_shapes = self._allgather_subcomm(sub_comm, base_comm_nccl, send_buf.shape)
                 (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
-                raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
+                raw_recv = nccl_allgather(sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
                 return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
         else:
-            return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)
+            return mpi_allgather(sub_comm, send_buf, recv_buf, engine)
 
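The NCCL branch pads every rank's buffer to a common shape before the collective and trims afterwards, since NCCL's allgather expects equally sized contributions from all ranks; _prepare_allgather_inputs and _unroll_allgather_recv encapsulate this. A NumPy-only illustration of the idea (not the real helpers, whose behaviour is inferred from their names and this diff):

# Hypothetical pad/gather/trim illustration with plain NumPy (no MPI or NCCL involved).
import numpy as np

chunks = [np.arange(2), np.arange(3)]                  # per-rank buffers of unequal size
pad = max(c.size for c in chunks)                      # common size required by the collective
padded = [np.pad(c, (0, pad - c.size)) for c in chunks]
gathered = np.concatenate(padded)                      # what an allgather over padded buffers returns
unrolled = [gathered[i * pad:i * pad + c.size] for i, c in enumerate(chunks)]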
-    def _bcast(self, local_array, index, value):
+    def _bcast(self,
+               base_comm: MPI.Comm,
+               base_comm_nccl: NcclCommunicatorType,
+               rank: int,
+               local_array: NDArray,
+               index: int,
+               value: Union[int, NDArray],
+               engine: str = "numpy",
+               ) -> None:
         """BCast operation
+
+        Parameters
+        ----------
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        rank : :obj:`int`
+            Rank of the calling process.
+        local_array : :obj:`numpy.ndarray`
+            Local array to be broadcast.
+        index : :obj:`int` or :obj:`slice`
+            Represents the index positions where a value needs to be assigned.
+        value : :obj:`int` or :obj:`numpy.ndarray`
+            Represents the value that will be assigned to the local array at
+            the specified index positions.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``)
+
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            nccl_bcast(self.base_comm_nccl, local_array, index, value)
+        if deps.nccl_enabled and base_comm_nccl is not None:
+            nccl_bcast(base_comm_nccl, local_array, index, value)
         else:
-            # self.local_array[index] = self.base_comm.bcast(value)
-            mpi_bcast(self.base_comm, self.rank, self.local_array, index, value,
-                      engine=self.engine)
+            mpi_bcast(base_comm, rank, local_array, index, value,
+                      engine=engine)
 
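_bcast assigns a broadcast value into a slot of each rank's local array (the removed comment shows the one-liner it replaces). A sketch under the same assumptions as the demo object above, on the MPI path:

# Hypothetical: the broadcast value ends up at index 2 of every rank's array.
arr = np.zeros(5)
demo._bcast(MPI.COMM_WORLD, None, MPI.COMM_WORLD.Get_rank(),
            arr, index=2, value=7.0, engine="numpy")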
-    def _send(self, send_buf, dest, count=None, tag=0):
+    def _send(self,
+              base_comm: MPI.Comm,
+              base_comm_nccl: NcclCommunicatorType,
+              send_buf: NDArray,
+              dest: int,
+              count: Optional[int] = None,
+              tag: int = 0,
+              engine: str = "numpy",
+              ) -> None:
         """Send operation
+
+        Parameters
+        ----------
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            The array containing data to send.
+        dest : :obj:`int`
+            The rank of the destination.
+        count : :obj:`int`, optional
+            Number of elements to send from `send_buf`.
+        tag : :obj:`int`, optional
+            Tag of the message to be sent.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``)
+
         """
-        if deps.nccl_enabled and self.base_comm_nccl:
+        if deps.nccl_enabled and base_comm_nccl is not None:
             if count is None:
                 count = send_buf.size
-            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+            nccl_send(base_comm_nccl, send_buf, dest, count)
         else:
-            mpi_send(self.base_comm,
+            mpi_send(base_comm,
                      send_buf, dest, count, tag=tag,
-                     engine=self.engine)
+                     engine=engine)
 
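_send mirrors the same dispatch for point-to-point transfers; the matching _recv refactor is cut off below, so only the send side is sketched (same demo assumptions, MPI path):

# Hypothetical: rank 0 sends a small buffer to rank 1 with tag 10.
if MPI.COMM_WORLD.Get_rank() == 0 and MPI.COMM_WORLD.Get_size() > 1:
    demo._send(MPI.COMM_WORLD, None, np.arange(4.0), dest=1, tag=10, engine="numpy")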
     def _recv(self, recv_buf=None, source=0, count=None, tag=0):
         """Receive operation