 nccl_message = deps.nccl_import("the DistributedArray module")
 
 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -495,14 +495,46 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+        if deps.nccl_enabled and self.base_comm_nccl:
             return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
         else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
 
+    def _send(self, send_buf, dest, count=None, tag=None):
+        """Send operation
+        """
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if count is None:
+                # assume the whole array is sent
+                count = send_buf.size
+            nccl_send(self.base_comm_nccl, send_buf, dest, count)
+        else:
+            self.base_comm.Send(send_buf, dest, tag)
+
+    def _recv(self, recv_buf=None, source=0, count=None, tag=None):
+        """Receive operation
+        """
+        # NCCL must be called with recv_buf: the size cannot be inferred from
+        # the other arguments, so the buffer cannot be allocated dynamically
+        if deps.nccl_enabled and self.base_comm_nccl:
+            if recv_buf is not None:
+                if count is None:
+                    # assume the data fills the whole receive buffer
+                    count = recv_buf.size
+                nccl_recv(self.base_comm_nccl, recv_buf, source, count)
+                return recv_buf
+            else:
+                raise ValueError("Using recv with NCCL must also supply receiver buffer")
+        else:
+            # MPI allows the receive buffer to be optional
+            if recv_buf is None:
+                return self.base_comm.recv(source=source, tag=tag)
+            self.base_comm.Recv(buf=recv_buf, source=source, tag=tag)
+            return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
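
The new `_send`/`_recv` helpers mirror `_allgather`: they dispatch to NCCL when `base_comm_nccl` is set and fall back to mpi4py otherwise. The standalone sketch below (not part of the PR) illustrates only the MPI fallback of that pattern with a preallocated receive buffer, the same constraint the NCCL path imposes; it assumes a plain `mpirun -n 2` launch.

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

buf = np.arange(4, dtype=np.float64) * (rank + 1)
if rank == 0:
    # sender side: like _send, ship the whole local buffer (count defaults to buf.size)
    comm.Send(buf, dest=1, tag=1)
elif rank == 1:
    # receiver side: like the NCCL branch of _recv, the buffer is preallocated,
    # because the transport cannot infer the message size on its own
    recv_buf = np.zeros_like(buf)
    comm.Recv(recv_buf, source=0, tag=1)
    print("rank 1 received", recv_buf)
```
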
@@ -540,6 +572,7 @@ def add(self, dist_array):
         self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
+                                    base_comm_nccl=self.base_comm_nccl,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
@@ -566,6 +599,7 @@ def multiply(self, dist_array):
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
+                                        base_comm_nccl=self.base_comm_nccl,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
@@ -716,6 +750,8 @@ def ravel(self, order: Optional[str] = "C"):
         """
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
+                               base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                local_shapes=local_shapes,
                                mask=self.mask,
                                partition=self.partition,
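
The `add`, `multiply`, and `ravel` hunks above all forward `base_comm_nccl` (and, for `ravel`, also `base_comm`) into the result array. The hypothetical helper below is not part of the PR; it only sketches the shared pattern, using just the constructor keywords that appear in the diff: a derived array must inherit both communicators, otherwise later collectives on it silently fall back to the default MPI communicator and never use NCCL.

```python
# Hypothetical sketch of the constructor-forwarding pattern shared by add/multiply/ravel
def _like(self, global_shape, local_shapes):
    return DistributedArray(global_shape=global_shape,
                            base_comm=self.base_comm,            # keep the MPI communicator
                            base_comm_nccl=self.base_comm_nccl,  # keep the NCCL communicator (may be None)
                            dtype=self.dtype,
                            partition=self.partition,
                            local_shapes=local_shapes,
                            mask=self.mask)
```
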
@@ -744,41 +780,57 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
         -------
         ghosted_array : :obj:`numpy.ndarray`
             Ghosted Array
-
         """
         ghosted_array = self.local_array.copy()
+        ncp = get_module(self.engine)
         if cells_front is not None:
-            total_cells_front = self._allgather(cells_front) + [0]
+            # cells_front is a small array of ints: explicitly use MPI
+            total_cells_front = self.base_comm.allgather(cells_front) + [0]
             # Read cells_front which needs to be sent to rank + 1(cells_front for rank + 1)
             cells_front = total_cells_front[self.rank + 1]
+            send_buf = ncp.take(self.local_array, ncp.arange(-cells_front, 0), axis=self.axis)
+            recv_shapes = self.local_shapes
             if self.rank != 0:
-                ghosted_array = np.concatenate([self.base_comm.recv(source=self.rank - 1, tag=1), ghosted_array],
-                                               axis=self.axis)
-            if self.rank != self.size - 1:
+                # from the receiver's perspective (rank), the recv buffer has the same shape as the
+                # sender's array (rank - 1) in every dimension except along axis=self.axis
+                recv_shape = list(recv_shapes[self.rank - 1])
+                recv_shape[self.axis] = total_cells_front[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                # Transfer of ghost cells can be skipped if len(recv_buf) == 0.
+                # Moreover, NCCL hangs on zero-sized buffers, so this check is effectively mandatory
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+            # The sender skips the transfer under the same condition described for the receiver
+            if self.rank != self.size - 1 and len(send_buf) != 0:
                 if cells_front > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_front}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_front}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_front)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis),
-                                    dest=self.rank + 1, tag=1)
+                self._send(send_buf, dest=self.rank + 1, tag=1)
         if cells_back is not None:
-            total_cells_back = self._allgather(cells_back) + [0]
+            total_cells_back = self.base_comm.allgather(cells_back) + [0]
             # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1)
             cells_back = total_cells_back[self.rank - 1]
-            if self.rank != 0:
+            send_buf = ncp.take(self.local_array, ncp.arange(cells_back), axis=self.axis)
+            # Same reasoning as for the cells_front exchange applies
+            recv_shapes = self.local_shapes
+            if self.rank != 0 and len(send_buf) != 0:
                 if cells_back > self.local_shape[self.axis]:
                     raise ValueError(f"Local Shape at rank={self.rank} along axis={self.axis} "
                                      f"should be > {cells_back}: dim({self.axis}) "
                                      f"{self.local_shape[self.axis]} < {cells_back}; "
                                      f"to achieve this use NUM_PROCESSES <= "
                                      f"{max(1, self.global_shape[self.axis] // cells_back)}")
-                self.base_comm.send(np.take(self.local_array, np.arange(cells_back), axis=self.axis),
-                                    dest=self.rank - 1, tag=0)
+                self._send(send_buf, dest=self.rank - 1, tag=0)
             if self.rank != self.size - 1:
-                ghosted_array = np.append(ghosted_array, self.base_comm.recv(source=self.rank + 1, tag=0),
-                                          axis=self.axis)
+                recv_shape = list(recv_shapes[self.rank + 1])
+                recv_shape[self.axis] = total_cells_back[self.rank]
+                recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
+                if len(recv_buf) != 0:
+                    ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+                                               axis=self.axis)
         return ghosted_array
 
     def __repr__(self):
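
For reference, the standalone sketch below (not part of the PR) mimics the `cells_front` exchange that `add_ghost_cells` now routes through `_send`/`_recv`: each rank sizes a receive buffer from its left neighbour's local shape before the transfer, which is exactly the shape information NCCL needs since it cannot infer message counts.

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

cells_front = 2                                   # ghost cells taken from the left neighbour
local = np.full((5, 3), rank, dtype=np.float64)   # each rank owns a (5, 3) block
ghosted = local.copy()

if rank != size - 1:
    # send my last `cells_front` rows forward (axis 0 plays the role of self.axis)
    comm.Send(np.ascontiguousarray(local[-cells_front:]), dest=rank + 1, tag=1)
if rank != 0:
    # receive into a buffer preallocated from the neighbour's known shape
    recv_buf = np.zeros((cells_front, 3), dtype=np.float64)
    comm.Recv(recv_buf, source=rank - 1, tag=1)
    ghosted = np.concatenate([recv_buf, ghosted], axis=0)

print(rank, ghosted.shape)                        # (5, 3) on rank 0, (7, 3) on the others
```
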