@@ -265,17 +265,45 @@ def _send(self,
265265 send_buf , dest , count , tag = tag ,
266266 engine = engine )
267267
268- def _recv (self , recv_buf = None , source = 0 , count = None , tag = 0 ):
268+ def _recv (self ,
269+ base_comm : MPI .Comm ,
270+ base_comm_nccl : NcclCommunicatorType ,
271+ recv_buf = None , source = 0 , count = None , tag = 0 ,
272+ engine : str = "numpy" ,
273+ ) -> NDArray :
269274 """Receive operation
275+
276+ Parameters
277+ ----------
278+ base_comm : :obj:`MPI.Comm`
279+ Base MPI Communicator.
280+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
281+ NCCL Communicator.
282+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
283+ The buffered array to receive data.
284+ source : :obj:`int`
285+ The rank of the sending CPU/GPU device.
286+ count : :obj:`int`
287+ Number of elements to receive.
288+ tag : :obj:`int`
289+ Tag of the message to be sent.
290+ engine : :obj:`str`, optional
291+ Engine used to store array (``numpy`` or ``cupy``)
292+
293+ Returns
294+ -------
295+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
296+ The buffer containing the received data.
297+
270298 """
271- if deps .nccl_enabled and self . base_comm_nccl :
299+ if deps .nccl_enabled and base_comm_nccl is not None :
272300 if recv_buf is None :
273301 raise ValueError ("recv_buf must be supplied when using NCCL" )
274302 if count is None :
275303 count = recv_buf .size
276- nccl_recv (self . base_comm_nccl , recv_buf , source , count )
304+ nccl_recv (base_comm_nccl , recv_buf , source , count )
277305 return recv_buf
278306 else :
279- return mpi_recv (self . base_comm ,
307+ return mpi_recv (base_comm ,
280308 recv_buf , source , count , tag = tag ,
281- engine = self . engine )
309+ engine = engine )
0 commit comments