@@ -24,7 +24,7 @@ class DistributedMixIn:
2424
2525 """
2626 def _allreduce (self , base_comm , base_comm_nccl ,
27- send_buf , recv_buf = None ,
27+ send_buf , recv_buf = None ,
2828 op : MPI .Op = MPI .SUM ,
2929 engine = "numpy" ):
3030 """Allreduce operation
@@ -44,13 +44,13 @@ def _allreduce(self, base_comm, base_comm_nccl,
4444 MPI operation to perform.
4545 engine : :obj:`str`, optional
4646 Engine used to store array (``numpy`` or ``cupy``)
47-
47+
4848 Returns
4949 -------
5050 recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
5151 A buffer containing the result of the reduction, broadcasted
5252 to all GPUs.
53-
53+
5454 """
5555 if deps .nccl_enabled and base_comm_nccl is not None :
5656 return nccl_allreduce (base_comm_nccl , send_buf , recv_buf , op )
@@ -59,7 +59,7 @@ def _allreduce(self, base_comm, base_comm_nccl,
5959 recv_buf , engine , op )
6060
6161 def _allreduce_subcomm (self , sub_comm , base_comm_nccl ,
62- send_buf , recv_buf = None ,
62+ send_buf , recv_buf = None ,
6363 op : MPI .Op = MPI .SUM ,
6464 engine = "numpy" ):
6565 """Allreduce operation with subcommunicator
@@ -79,13 +79,13 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
7979 MPI operation to perform.
8080 engine : :obj:`str`, optional
8181 Engine used to store array (``numpy`` or ``cupy``)
82-
82+
8383 Returns
8484 -------
8585 recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
8686 A buffer containing the result of the reduction, broadcasted
8787 to all ranks.
88-
88+
8989 """
9090 if deps .nccl_enabled and base_comm_nccl is not None :
9191 return nccl_allreduce (sub_comm , send_buf , recv_buf , op )
@@ -111,12 +111,12 @@ def _allgather(self, base_comm, base_comm_nccl,
111111 a new buffer will be allocated with the appropriate shape.
112112 engine : :obj:`str`, optional
113113 Engine used to store array (``numpy`` or ``cupy``)
114-
114+
115115 Returns
116116 -------
117117 recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
118118 A buffer containing the gathered data from all ranks.
119-
119+
120120 """
121121 if deps .nccl_enabled and base_comm_nccl is not None :
122122 if isinstance (send_buf , (tuple , list , int )):
0 commit comments