
Commit 64854bb (1 parent: ca558fd)

feat: added _mpi file with actual mpi comm. implementations

File tree: 1 file changed (+95, -0 lines)


pylops_mpi/utils/_mpi.py

Lines changed: 95 additions & 0 deletions
```python
__all__ = [
    # "mpi_allgather",
    "mpi_allreduce",
    # "mpi_bcast",
    # "mpi_asarray",
    "mpi_send",
    # "mpi_recv",
]

from typing import Optional

import numpy as np
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops_mpi.utils import deps


def mpi_allreduce(base_comm: MPI.Comm,
                  send_buf, recv_buf=None,
                  engine: Optional[str] = "numpy",
                  op: MPI.Op = MPI.SUM) -> np.ndarray:
    """MPI_Allreduce/allreduce

    Dispatch allreduce routine based on type of input and availability of
    CUDA-Aware MPI

    Parameters
    ----------
    base_comm : :obj:`MPI.Comm`
        Base MPI Communicator.
    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        The data buffer from the local GPU to be reduced.
    recv_buf : :obj:`cupy.ndarray`, optional
        The buffer to store the result of the reduction. If None,
        a new buffer will be allocated with the appropriate shape.
    engine : :obj:`str`, optional
        Engine used to store array (``numpy`` or ``cupy``)
    op : :obj:`mpi4py.MPI.Op`, optional
        The reduction operation to apply. Defaults to MPI.SUM.

    Returns
    -------
    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        A buffer containing the result of the reduction, broadcasted
        to all GPUs.

    """
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        ncp = get_module(engine)
        recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
        base_comm.Allreduce(send_buf, recv_buf, op)
        return recv_buf
    else:
        # CuPy with non-CUDA-aware MPI
        if recv_buf is None:
            return base_comm.allreduce(send_buf, op)
        # For MIN and MAX which require recv_buf
        base_comm.Allreduce(send_buf, recv_buf, op)
        return recv_buf


def mpi_send(base_comm: MPI.Comm,
             send_buf, dest, count, tag=0,
             engine: Optional[str] = "numpy",
             ) -> None:
    """MPI_Send/send

    Dispatch send routine based on type of input and availability of
    CUDA-Aware MPI

    Parameters
    ----------
    base_comm : :obj:`MPI.Comm`
        Base MPI Communicator.
    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        The array containing data to send.
    dest : :obj:`int`
        The rank of the destination GPU device.
    count : :obj:`int`
        Number of elements to send from ``send_buf``.
    tag : :obj:`int`
        Tag of the message to be sent.
    engine : :obj:`str`, optional
        Engine used to store array (``numpy`` or ``cupy``)

    """
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        # Determine MPI type based on array dtype
        mpi_type = MPI._typedict[send_buf.dtype.char]
        if count is None:
            count = send_buf.size
        base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
    else:
        # Uses CuPy without CUDA-aware MPI
        base_comm.send(send_buf, dest, tag)
```
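
To show how the new dispatch helper might be called downstream, here is a minimal usage sketch (not part of the commit) that sum-reduces a local NumPy vector across all ranks; the script name, array contents, and the direct import from the private module are illustrative assumptions.

```python
# Hypothetical usage sketch: run with e.g. `mpirun -n 4 python sum_example.py`
import numpy as np
from mpi4py import MPI

from pylops_mpi.utils._mpi import mpi_allreduce

comm = MPI.COMM_WORLD

# Each rank contributes a vector filled with its own rank number
local = np.full(5, comm.Get_rank(), dtype=np.float64)

# With engine="numpy" the helper takes the buffered MPI.Allreduce path
# and returns a freshly allocated buffer holding the element-wise sum
total = mpi_allreduce(comm, local, engine="numpy", op=MPI.SUM)
print(f"rank {comm.Get_rank()}: {total}")  # same summed vector on every rank
```

With ``engine="cupy"`` and no CUDA-aware MPI, the helper instead falls back to mpi4py's object-based allreduce, unless a ``recv_buf`` is supplied for operations such as MIN and MAX.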

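A companion sketch for ``mpi_send`` (again illustrative, not part of the commit): since ``mpi_recv`` is still commented out in ``__all__`` at this stage, the receiving rank here calls mpi4py's ``Recv`` directly; the ranks, tag, and payload are made up.

```python
# Hypothetical usage sketch: run with e.g. `mpirun -n 2 python send_example.py`
import numpy as np
from mpi4py import MPI

from pylops_mpi.utils._mpi import mpi_send

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    payload = np.arange(10, dtype=np.float64)
    # NumPy engine: mpi_send uses the buffered MPI.Send path with an
    # explicit element count and the MPI datatype derived from the dtype
    mpi_send(comm, payload, dest=1, count=payload.size, tag=7, engine="numpy")
elif rank == 1:
    out = np.empty(10, dtype=np.float64)
    # Matching buffered receive on the destination rank (plain mpi4py call)
    comm.Recv(out, source=0, tag=7)
    print("rank 1 received", out)
```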