@@ -14,7 +14,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -504,7 +504,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and self.base_comm_nccl:
             return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
         else:
             if recv_buf is None:
@@ -521,6 +521,37 @@ def _allgather_subcomm(self, send_buf, recv_buf=None):
         if recv_buf is None:
             return self.sub_comm.allgather(send_buf)
         self.sub_comm.Allgather(send_buf, recv_buf)
+
+    def _send(self, send_buf, dest, count=None, tag=None):
+        """Send operation
+        """
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if count is None:
+                # assume the whole array is being sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """Receive operation
+        """
+        if deps.nccl_enabled and self.base_comm_nccl:
+            # NCCL requires a pre-allocated recv_buf: the message size cannot be
+            # inferred from the other arguments, so the buffer cannot be allocated here
+            if recv_buf is None:
+                raise ValueError("Receiving with NCCL requires a pre-allocated receive buffer")
+            if count is None:
+                # assume the data fills the whole buffer
+                count = recv_buf.size
+            nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+            return recv_buf
+        else:
+            # MPI allows the receive buffer to be omitted and allocated on the fly
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
         return recv_buf

     def __neg__(self):
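For orientation, here is a minimal, hypothetical sketch of how these private helpers might be exercised on two ranks with the plain MPI/NumPy backend (with the NCCL backend the buffers would be CuPy arrays, and the pre-allocated receive buffer is mandatory). It assumes the standard `DistributedArray` constructor; the script name and tag value are arbitrary:

```python
import numpy as np
from pylops_mpi import DistributedArray

# run with e.g. `mpirun -n 2 python send_recv_sketch.py`
x = DistributedArray(global_shape=(8,))
x[:] = np.full(x.local_shape, x.rank, dtype=x.dtype)

if x.rank == 0:
    # ship the last sample of rank 0's local chunk to rank 1
    x._send(x.local_array[-1:], dest=1, tag=7)
elif x.rank == 1:
    # pre-allocating the buffer is required on the NCCL path and harmless with MPI
    buf = np.zeros(1, dtype=x.dtype)
    halo = x._recv(buf, source=0, tag=7)
```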
@@ -560,6 +591,7 @@ def add(self, dist_array):
         self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
+                                    base_comm_nccl=self.base_comm_nccl,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
@@ -586,6 +618,7 @@ def multiply(self, dist_array):

         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
+                                        base_comm_nccl=self.base_comm_nccl,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
@@ -736,6 +769,8 @@ def ravel(self, order: Optional[str] = "C"):
         """
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
+                               base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                local_shapes=local_shapes,
                                mask=self.mask,
                                partition=self.partition,
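As a quick, hedged illustration of what forwarding `base_comm_nccl` above buys (names below are illustrative and assume a plain MPI run, where the NCCL communicator is simply `None`): results of `add`, `multiply`, and `ravel` now reference the same communicators as their operands.

```python
import numpy as np
from pylops_mpi import DistributedArray

x = DistributedArray(global_shape=(8,))
x[:] = np.ones(x.local_shape)

y = x + x          # dispatches to DistributedArray.add
z = x.ravel()

# the (possibly None) NCCL communicator travels with the result
assert y.base_comm is x.base_comm
assert y.base_comm_nccl is x.base_comm_nccl
assert z.base_comm_nccl is x.base_comm_nccl
```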
@@ -764,41 +799,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
         -------
         ghosted_array : :obj:`numpy.ndarray`
             Ghosted Array
-
         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(self.engine)
         if cells_front is not None:
-            total_cells_front = self._allgather(cells_front) + [0]
+            # cells_front is a small amount of integer data, so gather it with MPI explicitly
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # from the receiver's perspective (rank), the recv buffer has the same shape as the
+                # sender's array (rank - 1) in every dimension except along axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # The transfer can be skipped when the buffer is empty; NCCL also hangs on
+                # zero-sized buffers, so this check is mandatory on that path
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The sender-side skip mirrors the receiver-side skip above
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            total_cells_back = self._allgather(cells_back) + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # Same reasoning as for the front cells applies here
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
             if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array

     def __repr__(self):
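Finally, an illustrative (not authoritative) example of the ghosting logic above, assuming three ranks and the default partition: with a global length of 12 each rank holds 4 samples, and `cells_front=2` makes every rank except the first prepend the last two samples of its left neighbour.

```python
import numpy as np
from pylops_mpi import DistributedArray

# run with e.g. `mpirun -n 3 python ghost_sketch.py`
x = DistributedArray(global_shape=(12,))
x[:] = np.arange(x.local_shape[0]) + 100 * x.rank

ghosted = x.add_ghost_cells(cells_front=2)
# rank 0 keeps shape (4,); ranks 1 and 2 end up with (6,) = 2 ghost + 4 local samples
print(x.rank, ghosted.shape)
```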