@@ -495,7 +495,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
495495 def _allgather (self , send_buf , recv_buf = None ):
496496 """Allgather operation
497497 """
498- if deps .nccl_enabled and getattr ( self , " base_comm_nccl" ) :
498+ if deps .nccl_enabled and self . base_comm_nccl :
499499 return nccl_allgather (self .base_comm_nccl , send_buf , recv_buf )
500500 else :
501501 if recv_buf is None :
@@ -518,13 +518,16 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=None):
518518 """ Receive operation
519519 """
520520 # NCCL must be called with recv_buf. Size cannot be inferred from
521- # other arguments and thus cannot dynamically allocated
521+ # other arguments and thus cannot be dynamically allocated
522522 if deps .nccl_enabled and getattr (self , "base_comm_nccl" ) and recv_buf is not None :
523- if count is None :
524- # assuming data will take a space of the whole buffer
525- count = recv_buf .size
526- nccl_recv (self .base_comm_nccl , recv_buf , source , count )
527- return recv_buf
523+ if recv_buf is not None :
524+ if count is None :
525+ # assuming data will take a space of the whole buffer
526+ count = recv_buf .size
527+ nccl_recv (self .base_comm_nccl , recv_buf , source , count )
528+ return recv_buf
529+ else :
530+ raise ValueError ("Using recv with NCCL must also supply receiver buffer " )
528531 else :
529532 # MPI allows a receiver buffer to be optional
530533 if recv_buf is None :
@@ -773,10 +776,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
773776 -------
774777 ghosted_array : :obj:`numpy.ndarray`
775778 Ghosted Array
776-
777779 """
778780 ghosted_array = self .local_array .copy ()
779- ncp = get_module (getattr ( self , " engine" ) )
781+ ncp = get_module (self . engine )
780782 if cells_front is not None :
781783 # cells_front is small array of int. Explicitly use MPI
782784 total_cells_front = self .base_comm .allgather (cells_front ) + [0 ]
@@ -790,7 +792,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
790792 recv_shape = list (recv_shapes [self .rank - 1 ])
791793 recv_shape [self .axis ] = total_cells_front [self .rank ]
792794 recv_buf = ncp .zeros (recv_shape , dtype = ghosted_array .dtype )
793- # Some communication can skip if len(recv_buf) = 0
795+ # Transfer of ghost cells can be skipped if len(recv_buf) = 0
794796 # Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory
795797 if len (recv_buf ) != 0 :
796798 ghosted_array = ncp .concatenate ([self ._recv (recv_buf , source = self .rank - 1 , tag = 1 ), ghosted_array ], axis = self .axis )