@@ -24,7 +24,7 @@ class DistributedMixIn:
2424
2525    """ 
2626    def  _allreduce (self , base_comm , base_comm_nccl ,
27-                    send_buf , recv_buf = None ,  
27+                    send_buf , recv_buf = None ,
2828                   op : MPI .Op  =  MPI .SUM ,
2929                   engine = "numpy" ):
3030        """Allreduce operation 
@@ -44,13 +44,13 @@ def _allreduce(self, base_comm, base_comm_nccl,
4444            MPI operation to perform. 
4545        engine : :obj:`str`, optional 
4646            Engine used to store array (``numpy`` or ``cupy``) 
47-      
47+ 
4848        Returns 
4949        ------- 
5050        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` 
5151            A buffer containing the result of the reduction, broadcasted 
5252            to all GPUs. 
53-          
53+ 
5454        """ 
5555        if  deps .nccl_enabled  and  base_comm_nccl  is  not None :
5656            return  nccl_allreduce (base_comm_nccl , send_buf , recv_buf , op )
@@ -59,7 +59,7 @@ def _allreduce(self, base_comm, base_comm_nccl,
5959                                 recv_buf , engine , op )
6060
6161    def  _allreduce_subcomm (self , sub_comm , base_comm_nccl ,
62-                            send_buf , recv_buf = None ,  
62+                            send_buf , recv_buf = None ,
6363                           op : MPI .Op  =  MPI .SUM ,
6464                           engine = "numpy" ):
6565        """Allreduce operation with subcommunicator 
@@ -79,13 +79,13 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl,
7979            MPI operation to perform. 
8080        engine : :obj:`str`, optional 
8181            Engine used to store array (``numpy`` or ``cupy``) 
82-      
82+ 
8383        Returns 
8484        ------- 
8585        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` 
8686            A buffer containing the result of the reduction, broadcasted 
8787            to all ranks. 
88-          
88+ 
8989        """ 
9090        if  deps .nccl_enabled  and  base_comm_nccl  is  not None :
9191            return  nccl_allreduce (sub_comm , send_buf , recv_buf , op )
@@ -111,12 +111,12 @@ def _allgather(self, base_comm, base_comm_nccl,
111111            a new buffer will be allocated with the appropriate shape. 
112112        engine : :obj:`str`, optional 
113113            Engine used to store array (``numpy`` or ``cupy``) 
114-      
114+ 
115115        Returns 
116116        ------- 
117117        recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` 
118118            A buffer containing the gathered data from all ranks. 
119-          
119+ 
120120        """ 
121121        if  deps .nccl_enabled  and  base_comm_nccl  is  not None :
122122            if  isinstance (send_buf , (tuple , list , int )):
0 commit comments