@@ -3,12 +3,13 @@
 from typing import Any, List, Optional, Tuple, Union, NewType
 
 import numpy as np
-import os
 from mpi4py import MPI
+from pylops_mpi.Distributed import DistributedMixIn
 from pylops.utils import DTypeLike, NDArray
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_array_module, get_module, get_module_name
+from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send
 from pylops_mpi.utils import deps
 
 cupy_message = pylops_deps.cupy_import("the DistributedArray module")
@@ -22,10 +23,6 @@
 
 NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)
 
-if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)):
-    is_cuda_aware_mpi = True
-else:
-    is_cuda_aware_mpi = False
 
 class Partition(Enum):
     r"""Enum class
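The module-level `PYLOPS_MPI_CUDA_AWARE` check is dropped here in favour of the `deps.cuda_aware_mpi_enabled` flag used further down in this diff. As a hedged sketch of how such a flag could be centralised in `pylops_mpi.utils.deps` (the actual implementation may differ, e.g. it could probe the MPI build rather than an environment variable):

```python
# hypothetical sketch of a centralised flag in pylops_mpi/utils/deps.py;
# assumes the same PYLOPS_MPI_CUDA_AWARE environment variable as the removed code
import os

cuda_aware_mpi_enabled: bool = bool(int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)))
```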
@@ -104,7 +101,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] =
     return sub_comm
 
 
-class DistributedArray:
+class DistributedArray(DistributedMixIn):
     r"""Distributed Numpy Arrays
 
     Multidimensional NumPy-like distributed arrays.
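`DistributedArray` now inherits its communication primitives from `DistributedMixIn` (imported above from `pylops_mpi.Distributed`). The contents of that mix-in are not shown in this diff; a minimal sketch of the shape it plausibly takes, assuming it hosts the helpers removed below and delegates to the new `mpi_*` wrappers (everything except the names `DistributedMixIn`, `mpi_allreduce` and `mpi_send` is an assumption):

```python
# hypothetical sketch of pylops_mpi/Distributed.py; method bodies are assumed
from mpi4py import MPI
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send


class DistributedMixIn:
    """Shared MPI/NCCL communication helpers for distributed objects."""

    def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
        # delegate to the engine-aware wrapper (NCCL handling omitted in this sketch)
        return mpi_allreduce(self.base_comm, send_buf, recv_buf,
                             engine=self.engine, op=op)

    def _send(self, send_buf, dest, count=None, tag=0):
        # matches the call signature visible in the _send hunk below
        mpi_send(self.base_comm, send_buf, dest, count,
                 tag=tag, engine=self.engine)
```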
@@ -477,44 +474,6 @@ def _check_mask(self, dist_array):
         if not np.array_equal(self.mask, dist_array.mask):
             raise ValueError("Mask of both the arrays must be same")
 
-    def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
-        """Allreduce operation
-        """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
-        else:
-            if is_cuda_aware_mpi or self.engine == "numpy":
-                ncp = get_module(self.engine)
-                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
-                self.base_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-            else:
-                # CuPy with non-CUDA-aware MPI
-                if recv_buf is None:
-                    return self.base_comm.allreduce(send_buf, op)
-                # For MIN and MAX which require recv_buf
-                self.base_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-
-    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
-        """Allreduce operation with subcommunicator
-        """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
-        else:
-            if is_cuda_aware_mpi or self.engine == "numpy":
-                ncp = get_module(self.engine)
-                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
-                self.sub_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-            else:
-                # CuPy with non-CUDA-aware MPI
-                if recv_buf is None:
-                    return self.sub_comm.allreduce(send_buf, op)
-                # For MIN and MAX which require recv_buf
-                self.sub_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
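`_allreduce` and `_allreduce_subcomm` move out of `DistributedArray`; the non-NCCL branching they contained is presumably what the newly imported `mpi_allreduce` wraps. A hedged sketch of such a wrapper, with the signature assumed (only the function name and its `pylops_mpi.utils._mpi` location are confirmed by the import above):

```python
# hypothetical sketch of mpi_allreduce in pylops_mpi/utils/_mpi.py;
# signature and keyword names are assumptions, the branching mirrors the removed code
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops_mpi.utils import deps


def mpi_allreduce(base_comm, send_buf, recv_buf=None,
                  engine="numpy", op: MPI.Op = MPI.SUM):
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        # buffered (uppercase) Allreduce into a pre-allocated receive buffer
        ncp = get_module(engine)
        recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
        base_comm.Allreduce(send_buf, recv_buf, op)
        return recv_buf
    # CuPy without CUDA-aware MPI: fall back to pickle-based allreduce
    if recv_buf is None:
        return base_comm.allreduce(send_buf, op)
    # MIN/MAX callers pass recv_buf explicitly
    base_comm.Allreduce(send_buf, recv_buf, op)
    return recv_buf
```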
@@ -556,16 +515,9 @@ def _send(self, send_buf, dest, count=None, tag=0):
                 count = send_buf.size
             nccl_send(self.base_comm_nccl, send_buf, dest, count)
         else:
-            if is_cuda_aware_mpi or self.engine == "numpy":
-                # Determine MPI type based on array dtype
-                mpi_type = MPI._typedict[send_buf.dtype.char]
-                if count is None:
-                    count = send_buf.size
-                self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
-            else:
-                # Uses CuPy without CUDA-aware MPI
-                self.base_comm.send(send_buf, dest, tag)
-
+            mpi_send(self.base_comm,
+                     send_buf, dest, count, tag=tag,
+                     engine=self.engine)
 
     def _recv(self, recv_buf=None, source=0, count=None, tag=0):
         """Receive operation
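The call above fixes the wrapper's interface: `mpi_send(comm, send_buf, dest, count, tag=..., engine=...)`. A hedged sketch of what the helper likely looks like internally, reusing the branching removed from `_send` (the body is an assumption; only the call signature is visible in this diff):

```python
# hypothetical sketch of mpi_send in pylops_mpi/utils/_mpi.py,
# reconstructed from the branches removed above
from mpi4py import MPI
from pylops_mpi.utils import deps


def mpi_send(base_comm, send_buf, dest, count=None, tag=0, engine="numpy"):
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        # buffered send: derive the MPI datatype from the array dtype
        mpi_type = MPI._typedict[send_buf.dtype.char]
        if count is None:
            count = send_buf.size
        base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
    else:
        # CuPy without CUDA-aware MPI: pickle-based send
        base_comm.send(send_buf, dest, tag)
```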
@@ -579,7 +531,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0):
             return recv_buf
         else:
             # NumPy + MPI will benefit from buffered communication regardless of MPI installation
-            if is_cuda_aware_mpi or self.engine == "numpy":
+            if deps.cuda_aware_mpi_enabled or self.engine == "numpy":
                 ncp = get_module(self.engine)
                 if recv_buf is None:
                     if count is None:
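For context, the distinction this flag controls on the receive side is mpi4py's buffered (uppercase) versus pickle-based (lowercase) API. A minimal standalone illustration, independent of pylops_mpi:

```python
# standalone mpi4py example of the two communication styles gated on above;
# run with: mpiexec -n 2 python example.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    buf = np.arange(5, dtype=np.float64)
    comm.Send([buf, MPI.DOUBLE], dest=1, tag=0)   # buffered send
    comm.send(buf, dest=1, tag=1)                 # pickle-based send
elif comm.Get_rank() == 1:
    recv = np.empty(5, dtype=np.float64)
    comm.Recv([recv, MPI.DOUBLE], source=0, tag=0)  # buffered receive into recv
    obj = comm.recv(source=0, tag=1)                # pickle-based receive
```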
@@ -734,7 +686,7 @@ def _compute_vector_norm(self, local_array: NDArray,
             # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
             # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
             send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
                 # CuPy + non-CUDA-aware MPI: This will call non-buffered communication
                 # which return a list of object - must be copied back to a GPU memory.
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
@@ -750,7 +702,7 @@ def _compute_vector_norm(self, local_array: NDArray,
             # Calculate min followed by min reduction
             # See the comment above in +infinity norm
             send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
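Both norm branches keep the same host round-trip for CuPy arrays when MPI is not CUDA-aware: reduce on the device, move the buffers to the host with `.get()`, allreduce there, then bring the result back with `asarray`. A hedged, standalone sketch of that pattern (assuming `cupy` is available and using `MPI.COMM_WORLD` directly):

```python
# standalone sketch of the host-staging pattern used above for MIN/MAX
# reductions with CuPy and a non-CUDA-aware MPI build
import cupy as cp
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
local = cp.random.rand(1000)                                  # device-resident chunk

send_buf = cp.abs(local).max().astype(cp.float64).reshape(1)  # device-side partial max
send_host = send_buf.get()                                    # copy to host
recv_host = np.zeros_like(send_host)
comm.Allreduce(send_host, recv_host, op=MPI.MAX)              # host-to-host reduction
result = cp.asarray(recv_host)                                # reduced value back on GPU
```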