Skip to content

Commit ab97e3d

Browse files
committed
mpi_recv for MixIn
1 parent 838ed0b commit ab97e3d

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

pylops_mpi/utils/_mpi.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def mpi_send(base_comm: MPI.Comm,
7575
send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
7676
The array containing data to send.
7777
dest: :obj:`int`
78-
The rank of the destination GPU device.
78+
The rank of the destination CPU/GPU device.
7979
count : :obj:`int`
8080
Number of elements to send from `send_buf`.
8181
tag : :obj:`int`
@@ -93,3 +93,43 @@ def mpi_send(base_comm: MPI.Comm,
9393
else:
9494
# Uses CuPy without CUDA-aware MPI
9595
base_comm.send(send_buf, dest, tag)
96+
97+
def mpi_recv(base_comm: MPI.Comm,
98+
recv_buf=None, source=0, count=None, tag=0,
99+
engine: Optional[str] = "numpy") -> np.ndarray:
100+
""" MPI_Recv/recv
101+
Dispatch receive routine based on type of input and availability of
102+
CUDA-Aware MPI
103+
104+
Parameters
105+
----------
106+
base_comm : :obj:`MPI.Comm`
107+
Base MPI Communicator.
108+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
109+
The buffered array to receive data.
110+
source : :obj:`int`
111+
The rank of the sending CPU/GPU device.
112+
count : :obj:`int`
113+
Number of elements to receive.
114+
tag : :obj:`int`
115+
Tag of the message to be sent.
116+
engine : :obj:`str`, optional
117+
Engine used to store array (``numpy`` or ``cupy``)
118+
119+
"""
120+
if deps.cuda_aware_mpi_enabled or engine == "numpy":
121+
ncp = get_module(engine)
122+
if recv_buf is None:
123+
if count is None:
124+
raise ValueError("Must provide either recv_buf or count for MPI receive")
125+
# Default to int32 works currently because add_ghost_cells() is called
126+
# with recv_buf and is not affected by this branch. The int32 is for when
127+
# dimension or shape-related integers are send/recv
128+
recv_buf = ncp.zeros(count, dtype=ncp.int32)
129+
mpi_type = MPI._typedict[recv_buf.dtype.char]
130+
base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag)
131+
else:
132+
# Uses CuPy without CUDA-aware MPI
133+
recv_buf = base_comm.recv(source=source, tag=tag)
134+
return recv_buf
135+

0 commit comments

Comments
 (0)