@@ -16,16 +16,41 @@ class DistributedMixIn:
1616 r"""Distributed Mixin class
1717
1818 This class implements all methods associated with communication primitives
19- from MPI and NCCL. It is mostly charged to identifying which commuicator
+    from MPI and NCCL. It is mostly charged with identifying which communicator
     to use and whether the buffered or object MPI primitives should be used
     (the former in the case of NumPy arrays or CuPy arrays when a CUDA-Aware
     MPI installation is available, the latter with CuPy arrays when a CUDA-Aware
     MPI installation is not available).
+
     """
     def _allreduce(self, base_comm, base_comm_nccl,
-                   send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
+                   send_buf, recv_buf=None,
+                   op: MPI.Op = MPI.SUM,
                    engine="numpy"):
         """Allreduce operation
+
+        Parameters
+        ----------
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the data to be sent by this rank.
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+            The buffer to store the result of the reduction. If None,
+            a new buffer will be allocated with the appropriate shape.
+        op : :obj:`MPI.Op`, optional
+            MPI operation to perform.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``).
+
+        Returns
+        -------
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the result of the reduction, broadcast
+            to all ranks.
+
         """
         if deps.nccl_enabled and base_comm_nccl is not None:
             return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
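
A note for context: the buffered-versus-object distinction described in the class docstring maps onto mpi4py's uppercase and lowercase primitives. A minimal sketch of that distinction (assuming a 4-rank MPI job, the NumPy engine, and no NCCL; the script name is hypothetical):

# sketch.py -- illustrative only; run with: mpirun -n 4 python sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Buffered primitive (uppercase Allreduce): operates directly on array
# memory, so it suits NumPy arrays (or CuPy arrays with CUDA-aware MPI).
send_buf = np.full(3, rank, dtype=np.float64)
recv_buf = np.empty_like(send_buf)
comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)

# Object primitive (lowercase allreduce): pickles its argument, which is
# why it also works for CuPy arrays without a CUDA-aware MPI build.
total = comm.allreduce(rank, op=MPI.SUM)

print(rank, recv_buf, total)  # identical results on every rank

The buffered form avoids pickling overhead; the object form is the fallback the docstring describes for CuPy arrays when no CUDA-aware MPI is available.
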
@@ -34,9 +59,33 @@ def _allreduce(self, base_comm, base_comm_nccl,
                              recv_buf, engine, op)

     def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
-                           send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
+                           send_buf, recv_buf=None,
+                           op: MPI.Op = MPI.SUM,
                            engine="numpy"):
         """Allreduce operation with subcommunicator
+
+        Parameters
+        ----------
+        sub_comm : :obj:`MPI.Comm`
+            MPI Subcommunicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the data to be sent by this rank.
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+            The buffer to store the result of the reduction. If None,
+            a new buffer will be allocated with the appropriate shape.
+        op : :obj:`MPI.Op`, optional
+            MPI operation to perform.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``).
+
+        Returns
+        -------
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the result of the reduction, broadcast
+            to all ranks.
+
         """
         if deps.nccl_enabled and base_comm_nccl is not None:
             return nccl_allreduce(sub_comm, send_buf, recv_buf, op)
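
The subcommunicator variant behaves the same way, except that only the ranks belonging to sub_comm participate in, and receive, the reduction. A minimal sketch, assuming a 4-rank job split into even and odd groups (script name hypothetical):

# sketch_subcomm.py -- illustrative only; run with: mpirun -n 4 python sketch_subcomm.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Split COMM_WORLD into two subcommunicators: even ranks and odd ranks.
sub_comm = comm.Split(color=rank % 2, key=rank)

send_buf = np.full(3, rank, dtype=np.float64)
recv_buf = np.empty_like(send_buf)
sub_comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)

# Each group sees the sum over its own members only:
# even ranks get 0+2=2, odd ranks get 1+3=4 (with 4 ranks total).
print(rank, recv_buf)
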
@@ -48,6 +97,26 @@ def _allgather(self, base_comm, base_comm_nccl,
                    send_buf, recv_buf=None,
                    engine="numpy"):
         """Allgather operation
+
+        Parameters
+        ----------
+        base_comm : :obj:`MPI.Comm`
+            Base MPI Communicator.
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+            NCCL Communicator.
+        send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the data to be sent by this rank.
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+            The buffer to store the result of the gathering. If None,
+            a new buffer will be allocated with the appropriate shape.
+        engine : :obj:`str`, optional
+            Engine used to store array (``numpy`` or ``cupy``).
+
+        Returns
+        -------
+        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+            A buffer containing the gathered data from all ranks.
+
         """
         if deps.nccl_enabled and base_comm_nccl is not None:
             if isinstance(send_buf, (tuple, list, int)):
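
For completeness, a sketch of the allgather semantics this last hunk documents: every rank contributes a buffer and receives the concatenation of all contributions. The `isinstance(send_buf, (tuple, list, int))` branch above suggests the method also gathers plain Python payloads (for example, local shapes), which the lowercase object primitive covers in mpi4py; that reading is an assumption, since the hunk is truncated here.

# sketch_allgather.py -- illustrative only; run with: mpirun -n 4 python sketch_allgather.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Buffered allgather: each rank sends 2 values, every rank receives 2 * size.
send_buf = np.full(2, rank, dtype=np.float64)
recv_buf = np.empty(2 * comm.Get_size(), dtype=np.float64)
comm.Allgather(send_buf, recv_buf)
print(rank, recv_buf)  # [0. 0. 1. 1. 2. 2. 3. 3.] on every rank

# Object allgather: handles non-array payloads such as tuples of shapes.
shapes = comm.allgather(send_buf.shape)  # [(2,), (2,), (2,), (2,)]
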