Commit e33e8f1

doc: added documentation and type hints to _mpi
1 parent 2c67755 commit e33e8f1

File tree: 1 file changed

pylops_mpi/utils/_mpi.py

Lines changed: 81 additions & 16 deletions
@@ -2,31 +2,54 @@
     "mpi_allgather",
     "mpi_allreduce",
     "mpi_bcast",
-    # "mpi_asarray",
     "mpi_send",
     "mpi_recv",
 ]

-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 from mpi4py import MPI
+from pylops.utils import NDArray
 from pylops.utils.backend import get_module
 from pylops_mpi.utils import deps
 from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv


 def mpi_allgather(base_comm: MPI.Comm,
-                  send_buf, recv_buf=None,
-                  engine: Optional[str] = "numpy") -> np.ndarray:
+                  send_buf: NDArray,
+                  recv_buf: Optional[NDArray] = None,
+                  engine: str = "numpy",
+                  ) -> NDArray:
+    """MPI_Allgather/allgather

+    Dispatch allgather routine based on type of input and availability of
+    CUDA-Aware MPI
+
+    Parameters
+    ----------
+    base_comm : :obj:`MPI.Comm`
+        Base MPI Communicator.
+    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+        The data buffer from the local rank to be gathered.
+    recv_buf : :obj:`cupy.ndarray`, optional
+        The buffer to store the result of the gathering. If None,
+        a new buffer will be allocated with the appropriate shape.
+    engine : :obj:`str`, optional
+        Engine used to store array (``numpy`` or ``cupy``)
+
+    Returns
+    -------
+    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+        A buffer containing the gathered data from all ranks.
+
+    """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         send_shapes = base_comm.allgather(send_buf.shape)
         (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine)
         recv_buffer_to_use = recv_buf if recv_buf else padded_recv
         base_comm.Allgather(padded_send, recv_buffer_to_use)
         return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)
-
     else:
         # CuPy with non-CUDA-aware MPI
         if recv_buf is None:
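
For context, a minimal usage sketch of mpi_allgather (assuming the helpers can be imported directly from the private pylops_mpi.utils._mpi module shown in this diff and the script is launched with e.g. mpirun -n 2 python example.py; the exact layout of the returned buffer is whatever _unroll_allgather_recv produces):

# Hedged sketch: gather a rank-dependent NumPy array from every rank.
import numpy as np
from mpi4py import MPI
from pylops_mpi.utils._mpi import mpi_allgather  # private module path, as in this diff

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Each rank contributes an array whose length depends on its rank.
send_buf = np.arange(rank + 1, dtype=np.float64)
gathered = mpi_allgather(comm, send_buf, engine="numpy")
print(f"rank {rank}: gathered {gathered}")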
@@ -36,9 +59,11 @@ def mpi_allgather(base_comm: MPI.Comm,


 def mpi_allreduce(base_comm: MPI.Comm,
-                  send_buf, recv_buf=None,
-                  engine: Optional[str] = "numpy",
-                  op: MPI.Op = MPI.SUM) -> np.ndarray:
+                  send_buf: NDArray,
+                  recv_buf: Optional[NDArray] = None,
+                  engine: str = "numpy",
+                  op: MPI.Op = MPI.SUM,
+                  ) -> NDArray:
     """MPI_Allreduce/allreduce

     Dispatch allreduce routine based on type of input and availability of
@@ -49,7 +74,7 @@ def mpi_allreduce(base_comm: MPI.Comm,
     base_comm : :obj:`MPI.Comm`
         Base MPI Communicator.
     send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
-        The data buffer from the local GPU to be reduced.
+        The data buffer from the local rank to be reduced.
     recv_buf : :obj:`cupy.ndarray`, optional
         The buffer to store the result of the reduction. If None,
         a new buffer will be allocated with the appropriate shape.
@@ -62,7 +87,7 @@ def mpi_allreduce(base_comm: MPI.Comm,
     -------
     recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
         A buffer containing the result of the reduction, broadcasted
-        to all GPUs.
+        to all ranks.

     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
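
A minimal sketch of mpi_allreduce under the same assumptions (private import path, mpirun launch); every rank should end up with the element-wise sum of the local arrays:

# Hedged sketch: sum a local array across all ranks with the default MPI.SUM.
import numpy as np
from mpi4py import MPI
from pylops_mpi.utils._mpi import mpi_allreduce  # private module path, as in this diff

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Each rank holds a constant array filled with its own rank number.
local = np.full(4, rank, dtype=np.float64)
total = mpi_allreduce(comm, local, engine="numpy", op=MPI.SUM)
print(f"rank {rank}: reduced {total}")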
@@ -80,8 +105,34 @@ def mpi_allreduce(base_comm: MPI.Comm,


 def mpi_bcast(base_comm: MPI.Comm,
-              rank, local_array, index, value,
-              engine: Optional[str] = "numpy") -> np.ndarray:
+              rank: int,
+              local_array: NDArray,
+              index: int,
+              value: Union[int, NDArray],
+              engine: Optional[str] = "numpy",
+              ) -> None:
+    """MPI_Bcast/bcast
+
+    Dispatch bcast routine based on type of input and availability of
+    CUDA-Aware MPI
+
+    Parameters
+    ----------
+    base_comm : :obj:`MPI.Comm`
+        Base MPI Communicator.
+    rank : :obj:`int`
+        Rank of the calling process.
+    local_array : :obj:`numpy.ndarray`
+        Local array to be broadcasted.
+    index : :obj:`int` or :obj:`slice`
+        Represents the index positions where a value needs to be assigned.
+    value : :obj:`int` or :obj:`numpy.ndarray`
+        Represents the value that will be assigned to the local array at
+        the specified index positions.
+    engine : :obj:`str`, optional
+        Engine used to store array (``numpy`` or ``cupy``)
+
+    """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         if rank == 0:
             local_array[index] = value
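
A minimal sketch of mpi_bcast, assuming (per the docstring above) that the assignment made on rank 0 is propagated so that local_array holds the same value at the given index on every rank after the call:

# Hedged sketch: assign value 3.0 at index 2 on rank 0 and broadcast it.
import numpy as np
from mpi4py import MPI
from pylops_mpi.utils._mpi import mpi_bcast  # private module path, as in this diff

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

local_array = np.zeros(5)
mpi_bcast(comm, rank, local_array, index=2, value=3.0, engine="numpy")
# Expected (per the docstring): local_array[2] == 3.0 on every rank.
print(f"rank {rank}: {local_array}")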
@@ -92,8 +143,11 @@ def mpi_bcast(base_comm: MPI.Comm,


 def mpi_send(base_comm: MPI.Comm,
-             send_buf, dest, count, tag=0,
-             engine: Optional[str] = "numpy",
+             send_buf: NDArray,
+             dest: int,
+             count: Optional[int] = None,
+             tag: int = 0,
+             engine: str = "numpy",
              ) -> None:
     """MPI_Send/send

@@ -114,6 +168,7 @@ def mpi_send(base_comm: MPI.Comm,
         Tag of the message to be sent.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
+
     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         # Determine MPI type based on array dtype
@@ -127,8 +182,12 @@ def mpi_send(base_comm: MPI.Comm,


 def mpi_recv(base_comm: MPI.Comm,
-             recv_buf=None, source=0, count=None, tag=0,
-             engine: Optional[str] = "numpy") -> np.ndarray:
+             recv_buf: Optional[NDArray] = None,
+             source: int = 0,
+             count: Optional[int] = None,
+             tag: int = 0,
+             engine: Optional[str] = "numpy",
+             ) -> NDArray:
     """ MPI_Recv/recv
     Dispatch receive routine based on type of input and availability of
     CUDA-Aware MPI
@@ -147,6 +206,12 @@ def mpi_recv(base_comm: MPI.Comm,
         Tag of the message to be sent.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
+
+    Returns
+    -------
+    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+        The buffer containing the received data.
+
     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         ncp = get_module(engine)
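
Finally, a minimal point-to-point sketch pairing mpi_send and mpi_recv (assuming count is the number of elements to transfer and that the receive side allocates a matching float64 buffer when recv_buf is None); run with at least two ranks:

# Hedged sketch: send a 3-element array from rank 0 to rank 1.
import numpy as np
from mpi4py import MPI
from pylops_mpi.utils._mpi import mpi_recv, mpi_send  # private module path, as in this diff

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    payload = np.array([1.0, 2.0, 3.0])
    mpi_send(comm, payload, dest=1, count=payload.size, tag=7, engine="numpy")
elif rank == 1:
    received = mpi_recv(comm, source=0, count=3, tag=7, engine="numpy")
    print(f"rank 1 received {received}")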

0 commit comments