@@ -483,23 +483,39 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
         else:
-            if recv_buf is None:
-                return self.base_comm.allreduce(send_buf, op)
-            # For MIN and MAX which require recv_buf
-            self.base_comm.Allreduce(send_buf, recv_buf, op)
-            return recv_buf
+            if is_cuda_aware_mpi or self.engine == "numpy":
+                ncp = get_module(self.engine)
+                # mpi_type = MPI._typedict[send_buf.dtype.char]
+                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
+                self.base_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf
+            else:
+                # CuPy with non-CUDA-aware MPI
+                if recv_buf is None:
+                    return self.base_comm.allreduce(send_buf, op)
+                # For MIN and MAX which require recv_buf
+                self.base_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf

     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """Allreduce operation with subcommunicator
         """
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
         else:
-            if recv_buf is None:
-                return self.sub_comm.allreduce(send_buf, op)
-            # For MIN and MAX which require recv_buf
-            self.sub_comm.Allreduce(send_buf, recv_buf, op)
-            return recv_buf
+            if is_cuda_aware_mpi or self.engine == "numpy":
+                ncp = get_module(self.engine)
+                # mpi_type = MPI._typedict[send_buf.dtype.char]
+                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
+                self.sub_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf
+            else:
+                # CuPy with non-CUDA-aware MPI
+                if recv_buf is None:
+                    return self.sub_comm.allreduce(send_buf, op)
+                # For MIN and MAX which require recv_buf
+                self.sub_comm.Allreduce(send_buf, recv_buf, op)
+                return recv_buf

     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
@@ -717,26 +733,29 @@ def _compute_vector_norm(self, local_array: NDArray,
             recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
         elif ord == ncp.inf:
             # Calculate max followed by max reduction
-            # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
+            # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
             # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
             send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
+                # CuPy + non-CUDA-aware MPI: this will use non-buffered communication,
+                # which returns a list of objects that must be copied back to GPU memory.
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
-                recv_buf = ncp.squeeze(recv_buf, axis=axis)
+                if self.base_comm_nccl:
+                    recv_buf = ncp.squeeze(recv_buf, axis=axis)
         elif ord == -ncp.inf:
             # Calculate min followed by min reduction
-            # TODO (tharitt): see the comment above in infinity norm
+            # See the comment above for the +infinity norm
             send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
-                recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
-
+                if self.base_comm_nccl:
+                    recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
         else:
             recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))
             recv_buf = ncp.power(recv_buf, 1.0 / ord)
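To make the host round-trip in the infinity-norm branches concrete, here is a small standalone sketch under stated assumptions: CuPy is installed, MPI is not CUDA-aware, and the helper name is illustrative rather than part of PyLops-MPI. The local per-axis maximum is computed where the data lives, staged to host memory with `.get()` if it is on the GPU, reduced across ranks with `MPI.MAX`, and moved back to the device with `cupy.asarray`.

```python
import numpy as np
from mpi4py import MPI

try:
    import cupy as cp  # optional; fall back to NumPy-only behaviour
except ImportError:
    cp = None


def allreduce_max_via_host(comm, local_array, axis=-1):
    """Illustrative helper: MAX-reduce a per-axis maximum across ranks,
    staging through host memory when the data lives on the GPU and MPI
    is not CUDA-aware."""
    on_gpu = cp is not None and isinstance(local_array, cp.ndarray)
    xp = cp if on_gpu else np
    # atleast_1d guarantees a contiguous buffer even for scalar reductions
    send_buf = xp.atleast_1d(xp.max(xp.abs(local_array), axis=axis)).astype(xp.float64)
    if on_gpu:
        host_send = send_buf.get()              # device -> host copy
        host_recv = np.zeros_like(host_send)
        comm.Allreduce(host_send, host_recv, op=MPI.MAX)
        return cp.asarray(host_recv)            # host -> device copy
    recv_buf = np.zeros_like(send_buf)
    comm.Allreduce(send_buf, recv_buf, op=MPI.MAX)
    return recv_buf
```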