 nccl_message = deps.nccl_import("the DistributedArray module")
 
 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
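For context, the guarded import above degrades gracefully on CPU-only installations: when CuPy or NCCL cannot be imported, NcclCommunicator falls back to typing.Any so type hints stay valid, and the new nccl_send/nccl_recv helpers are simply never bound. Below is a minimal stand-alone sketch of the same pattern; the nccl_import helper here is a hypothetical stand-in for pylops_mpi.utils.deps.nccl_import, assumed to return None on success and an explanatory message otherwise.

from typing import Any, Optional


def nccl_import(requested_by: str) -> Optional[str]:
    # hypothetical stand-in for pylops_mpi.utils.deps.nccl_import
    try:
        import cupy.cuda.nccl  # noqa: F401
        return None
    except ImportError as exc:
        return f"NCCL is not available for {requested_by}: {exc}"


if nccl_import("the DistributedArray module") is None:
    from cupy.cuda.nccl import NcclCommunicator  # GPU path
else:
    NcclCommunicator = Any  # CPU-only fallback keeps annotations valid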
@@ -495,14 +495,46 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and self.base_comm_nccl:
             return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
         else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
 
+    def _send(self, send_buf, dest, count=None, tag=None):
+        """Send operation
+        """
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if count is None:
+                # assume the whole array is sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """Receive operation
+        """
+        # NCCL requires recv_buf: its size cannot be inferred from the other
+        # arguments, so the buffer cannot be allocated dynamically
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if recv_buf is not None:
+                if count is None:
+                    # assume the data fills the whole buffer
+                    count = recv_buf.size
+                nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+                return recv_buf
+            else:
+                raise ValueError("Using recv with NCCL must also supply a receiver buffer")
+        else:
+            # MPI allows the receiver buffer to be optional
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
+            return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
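The new _send/_recv wrappers above are thin point-to-point counterparts of _allgather: they dispatch to nccl_send/nccl_recv when an NCCL communicator is attached and to MPI Send/Recv otherwise. Because NCCL cannot infer message sizes, the receiver must hand in a preallocated buffer. Below is a hedged usage sketch of the intended handshake, written against plain mpi4py/NumPy (not part of the PR) so it runs without a GPU; buffer names and sizes are illustrative only.

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

local = np.full(5, float(rank))     # each rank owns 5 samples
ghosted = local.copy()
n_ghost = 2                         # illustrative ghost-cell count

# receive ghost cells from the previous rank into a preallocated buffer,
# exactly as the NCCL path requires (no dynamic allocation on the receiver)
if rank != 0:
    recv_buf = np.zeros(n_ghost, dtype=local.dtype)
    comm.Recv(recv_buf, source=rank - 1, tag=1)
    ghosted = np.concatenate([recv_buf, ghosted])

# forward the tail of the unmodified local array to the next rank
if rank != size - 1:
    comm.Send(local[-n_ghost:], dest=rank + 1, tag=1)

print(rank, ghosted)

Each rank posts its receive before its forward send, matching the ordering used in add_ghost_cells further down; the DistributedArray wrappers additionally skip empty buffers because NCCL hangs on zero-sized messages.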
@@ -540,6 +572,7 @@ def add(self, dist_array):
         self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
+                                    base_comm_nccl=self.base_comm_nccl,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
@@ -566,6 +599,7 @@ def multiply(self, dist_array):
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
+                                        base_comm_nccl=self.base_comm_nccl,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
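Forwarding base_comm_nccl into the outputs of add and multiply keeps the result on the same NCCL communicator as its operands, so later collectives do not silently fall back to MPI. A hedged sketch of the intended behaviour follows; it assumes a GPU install with CuPy and NCCL, and the initialize_nccl_comm helper name is taken from the pylops-mpi GPU examples and should be treated as an assumption here.

from pylops_mpi import DistributedArray
from pylops_mpi.utils._nccl import initialize_nccl_comm  # assumption: helper used in the GPU examples

nccl_comm = initialize_nccl_comm()
x = DistributedArray(global_shape=(16,), base_comm_nccl=nccl_comm, engine="cupy")
y = DistributedArray(global_shape=(16,), base_comm_nccl=nccl_comm, engine="cupy")
x[:], y[:] = 1.0, 2.0

z = x + y   # __add__ -> add(): with this patch the sum inherits base_comm_nccl
w = x * y   # __mul__ -> multiply(): same propagation
assert z.base_comm_nccl is nccl_comm and w.base_comm_nccl is nccl_comm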
@@ -716,6 +750,8 @@ def ravel(self, order: Optional[str] = "C"):
         """
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
+                               base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                local_shapes=local_shapes,
                                mask=self.mask,
                                partition=self.partition,
@@ -744,41 +780,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
         -------
         ghosted_array : :obj:`numpy.ndarray`
             Ghosted Array
-
         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(self.engine)
         if cells_front is not None:
-            total_cells_front = self._allgather(cells_front) + [0]
+            # cells_front is a small integer, so gather it explicitly with MPI
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1 (cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # From the receiver's perspective (rank), the recv buffer has the same shape as the
+                # sender's array (rank - 1) in every dimension except along axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # The transfer of ghost cells can be skipped when len(recv_buf) == 0;
+                # moreover, NCCL hangs on zero-sized buffers, so this check is effectively mandatory
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The sender-side skip mirrors the receiver-side check described above
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            total_cells_back = self._allgather(cells_back) + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1 (cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # The same reasoning as for the front cells applies here
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
             if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array
 
     def __repr__(self):
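Taken together, the add_ghost_cells changes swap the raw base_comm.send/recv calls for the engine-aware _send/_recv wrappers, preallocate the receive buffers (mandatory for NCCL), and skip zero-sized transfers that would otherwise hang NCCL. Below is a hedged usage sketch for the MPI/NumPy path (run with something like mpiexec -n 2 python ghost.py); the shapes and ghost-cell counts are illustrative only.

import numpy as np
from pylops_mpi import DistributedArray

x = DistributedArray(global_shape=(8, 4), axis=0)   # block-distributed along axis 0
x[:] = np.arange(x.local_array.size, dtype=float).reshape(x.local_shape)

# pad each local block with one row from the previous rank and one from the next
ghosted = x.add_ghost_cells(cells_front=1, cells_back=1)
print(x.rank, x.local_shape, "->", ghosted.shape)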