Skip to content

Commit 2cdb8f7

Browse files
committed
doc: added some docstrings to Distributed
1 parent 473cd97 commit 2cdb8f7

File tree

1 file changed

+72
-3
lines changed

1 file changed

+72
-3
lines changed

pylops_mpi/Distributed.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,41 @@ class DistributedMixIn:
1616
r"""Distributed Mixin class
1717
1818
This class implements all methods associated with communication primitives
19-
from MPI and NCCL. It is mostly charged to identifying which commuicator
19+
from MPI and NCCL. It is mostly charged with identifying which commuicator
2020
to use and whether the buffered or object MPI primitives should be used
2121
(the former in the case of NumPy arrays or CuPy arrays when a CUDA-Aware
2222
MPI installation is available, the latter with CuPy arrays when a CUDA-Aware
2323
MPI installation is not available).
24+
2425
"""
2526
def _allreduce(self, base_comm, base_comm_nccl,
26-
send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
27+
send_buf, recv_buf=None,
28+
op: MPI.Op = MPI.SUM,
2729
engine="numpy"):
2830
"""Allreduce operation
31+
32+
Parameters
33+
----------
34+
base_comm : :obj:`MPI.Comm`
35+
Base MPI Communicator.
36+
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
37+
NCCL Communicator.
38+
send_buf: :obj: `numpy.ndarray` or `cupy.ndarray`
39+
A buffer containing the data to be sent by this rank.
40+
recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional
41+
The buffer to store the result of the reduction. If None,
42+
a new buffer will be allocated with the appropriate shape.
43+
op : :obj: `MPI.Op`, optional
44+
MPI operation to perform.
45+
engine : :obj:`str`, optional
46+
Engine used to store array (``numpy`` or ``cupy``)
47+
48+
Returns
49+
-------
50+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
51+
A buffer containing the result of the reduction, broadcasted
52+
to all GPUs.
53+
2954
"""
3055
if deps.nccl_enabled and base_comm_nccl is not None:
3156
return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
@@ -34,9 +59,33 @@ def _allreduce(self, base_comm, base_comm_nccl,
3459
recv_buf, engine, op)
3560

3661
def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
37-
send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
62+
send_buf, recv_buf=None,
63+
op: MPI.Op = MPI.SUM,
3864
engine="numpy"):
3965
"""Allreduce operation with subcommunicator
66+
67+
Parameters
68+
----------
69+
sub_comm : :obj:`MPI.Comm`
70+
MPI Subcommunicator.
71+
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
72+
NCCL Communicator.
73+
send_buf: :obj: `numpy.ndarray` or `cupy.ndarray`
74+
A buffer containing the data to be sent by this rank.
75+
recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional
76+
The buffer to store the result of the reduction. If None,
77+
a new buffer will be allocated with the appropriate shape.
78+
op : :obj: `MPI.Op`, optional
79+
MPI operation to perform.
80+
engine : :obj:`str`, optional
81+
Engine used to store array (``numpy`` or ``cupy``)
82+
83+
Returns
84+
-------
85+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
86+
A buffer containing the result of the reduction, broadcasted
87+
to all ranks.
88+
4089
"""
4190
if deps.nccl_enabled and base_comm_nccl is not None:
4291
return nccl_allreduce(sub_comm, send_buf, recv_buf, op)
@@ -48,6 +97,26 @@ def _allgather(self, base_comm, base_comm_nccl,
4897
send_buf, recv_buf=None,
4998
engine="numpy"):
5099
"""Allgather operation
100+
101+
Parameters
102+
----------
103+
sub_comm : :obj:`MPI.Comm`
104+
MPI Subcommunicator.
105+
base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
106+
NCCL Communicator.
107+
send_buf: :obj: `numpy.ndarray` or `cupy.ndarray`
108+
A buffer containing the data to be sent by this rank.
109+
recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional
110+
The buffer to store the result of the gathering. If None,
111+
a new buffer will be allocated with the appropriate shape.
112+
engine : :obj:`str`, optional
113+
Engine used to store array (``numpy`` or ``cupy``)
114+
115+
Returns
116+
-------
117+
recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
118+
A buffer containing the gathered data from all ranks.
119+
51120
"""
52121
if deps.nccl_enabled and base_comm_nccl is not None:
53122
if isinstance(send_buf, (tuple, list, int)):

0 commit comments

Comments
 (0)