@@ -75,7 +75,7 @@ def mpi_send(base_comm: MPI.Comm,
7575 send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
7676 The array containing data to send.
7777 dest: :obj:`int`
78- The rank of the destination GPU device.
78+ The rank of the destination CPU/ GPU device.
7979 count : :obj:`int`
8080 Number of elements to send from `send_buf`.
8181 tag : :obj:`int`
@@ -93,3 +93,43 @@ def mpi_send(base_comm: MPI.Comm,
9393 else :
9494 # Uses CuPy without CUDA-aware MPI
9595 base_comm .send (send_buf , dest , tag )
96+
97+ def mpi_recv (base_comm : MPI .Comm ,
98+ recv_buf = None , source = 0 , count = None , tag = 0 ,
99+ engine : Optional [str ] = "numpy" ) -> np .ndarray :
100+ """ MPI_Recv/recv
101+ Dispatch receive routine based on type of input and availability of
102+ CUDA-Aware MPI
103+
104+ Parameters
105+ ----------
106+ base_comm : :obj:`MPI.Comm`
107+ Base MPI Communicator.
108+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
109+ The buffered array to receive data.
110+ source : :obj:`int`
111+ The rank of the sending CPU/GPU device.
112+ count : :obj:`int`
113+ Number of elements to receive.
114+ tag : :obj:`int`
115+ Tag of the message to be sent.
116+ engine : :obj:`str`, optional
117+ Engine used to store array (``numpy`` or ``cupy``)
118+
119+ """
120+ if deps .cuda_aware_mpi_enabled or engine == "numpy" :
121+ ncp = get_module (engine )
122+ if recv_buf is None :
123+ if count is None :
124+ raise ValueError ("Must provide either recv_buf or count for MPI receive" )
125+ # Default to int32 works currently because add_ghost_cells() is called
126+ # with recv_buf and is not affected by this branch. The int32 is for when
127+ # dimension or shape-related integers are send/recv
128+ recv_buf = ncp .zeros (count , dtype = ncp .int32 )
129+ mpi_type = MPI ._typedict [recv_buf .dtype .char ]
130+ base_comm .Recv ([recv_buf , recv_buf .size , mpi_type ], source = source , tag = tag )
131+ else :
132+ # Uses CuPy without CUDA-aware MPI
133+ recv_buf = base_comm .recv (source = source , tag = tag )
134+ return recv_buf
135+
0 commit comments