@@ -483,23 +483,39 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
         else:
-            if recv_buf is None:
-                return self.base_comm.allreduce(send_buf, op)
-            # For MIN and MAX which require recv_buf
-            self.base_comm.Allreduce(send_buf, recv_buf, op)
-            return recv_buf
+            if is_cuda_aware_mpi or self.engine == "numpy":
+                ncp = get_module(self.engine)
+                # mpi_type = MPI._typedict[send_buf.dtype.char]
+                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
+                self.base_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf
+            else:
+                # CuPy with non-CUDA-aware MPI
+                if recv_buf is None:
+                    return self.base_comm.allreduce(send_buf, op)
+                # For MIN and MAX which require recv_buf
+                self.base_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf

     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """Allreduce operation with subcommunicator
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
         else:
-            if recv_buf is None:
-                return self.sub_comm.allreduce(send_buf, op)
-            # For MIN and MAX which require recv_buf
-            self.sub_comm.Allreduce(send_buf, recv_buf, op)
-            return recv_buf
+            if is_cuda_aware_mpi or self.engine == "numpy":
+                ncp = get_module(self.engine)
+                # mpi_type = MPI._typedict[send_buf.dtype.char]
+                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
+                self.sub_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf
+            else:
+                # CuPy with non-CUDA-aware MPI
+                if recv_buf is None:
+                    return self.sub_comm.allreduce(send_buf, op)
+                # For MIN and MAX which require recv_buf
+                self.sub_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf

     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
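
For reference, a minimal standalone sketch (plain mpi4py + NumPy, independent of the class above) of the two communication paths this change now selects between: the buffered Allreduce, which fills a preallocated receive buffer and can only read GPU pointers when the MPI build is CUDA-aware, and the object-based lowercase allreduce, which pickles the data and therefore also works for device arrays under a non-CUDA-aware MPI.

    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD

    # Buffered path: preallocate recv_buf and let MPI fill it in place.
    # Always safe for NumPy arrays; for GPU arrays only with CUDA-aware MPI.
    send_buf = np.arange(4, dtype=np.float64)
    recv_buf = np.zeros(send_buf.size, dtype=send_buf.dtype)
    comm.Allreduce(send_buf, recv_buf, op=MPI.SUM)

    # Object path: mpi4py pickles send_buf and returns a new reduced object.
    # Slower, but it never hands raw device pointers to the MPI library.
    result = comm.allreduce(send_buf, op=MPI.SUM)

Run with, for example, mpirun -n 2 python example.py; both paths yield the elementwise sum across ranks.
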
@@ -717,26 +733,29 @@ def _compute_vector_norm(self, local_array: NDArray,
             recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
         elif ord == ncp.inf:
             # Calculate max followed by max reduction
-            # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
+            # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
             # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
             send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
+                # CuPy + non-CUDA-aware MPI: this will call non-buffered communication,
+                # which returns a list of objects that must be copied back to GPU memory.
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
-                recv_buf = ncp.squeeze(recv_buf, axis=axis)
+                if self.base_comm_nccl:
+                    recv_buf = ncp.squeeze(recv_buf, axis=axis)
         elif ord == -ncp.inf:
             # Calculate min followed by min reduction
-            # TODO (tharitt): see the comment above in infinity norm
+            # See the comment above for the +infinity norm
             send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
-                recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
-
+                if self.base_comm_nccl:
+                    recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
         else:
             recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))
             recv_buf = ncp.power(recv_buf, 1.0 / ord)
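
The GPU branches above stage MAX/MIN reductions through host memory when MPI is not CUDA-aware: send_buf.get() copies the partial result to the CPU, the allreduce then operates on host data, and ncp.asarray moves the result back to the device. A minimal sketch of that pattern in isolation (mpi4py + CuPy; host_staged_allreduce_max is an illustrative helper name, not part of the code above):

    from mpi4py import MPI
    import numpy as np
    import cupy as cp

    def host_staged_allreduce_max(comm, local_gpu_array):
        # Device -> host copy: a non-CUDA-aware MPI cannot dereference GPU pointers.
        send_host = cp.asnumpy(local_gpu_array)
        recv_host = np.zeros_like(send_host)
        # Buffered reduction carried out on the host copies.
        comm.Allreduce(send_host, recv_host, op=MPI.MAX)
        # Host -> device copy of the reduced result.
        return cp.asarray(recv_host)

Each call pays two extra host-device transfers, which is why the code only takes this detour when is_cuda_aware_mpi is False and no NCCL communicator is in use.
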