33from typing import Any , List , Optional , Tuple , Union , NewType
44
55import numpy as np
6- import os
76from mpi4py import MPI
7+ from pylops_mpi .Distributed import DistributedMixIn
88from pylops .utils import DTypeLike , NDArray
99from pylops .utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
1010from pylops .utils ._internal import _value_or_sized_to_tuple
1111from pylops .utils .backend import get_array_module , get_module , get_module_name
12+ from pylops_mpi .utils ._mpi import mpi_allreduce , mpi_send
1213from pylops_mpi .utils import deps
1314
1415cupy_message = pylops_deps .cupy_import ("the DistributedArray module" )
2223
2324NcclCommunicatorType = NewType ("NcclCommunicator" , NcclCommunicator )
2425
# Module-level switch evaluated once at import time: it is True when the
# PYLOPS_MPI_CUDA_AWARE environment variable holds a non-zero integer and
# False when the variable is unset or holds 0. A value that is not an
# integer literal makes int() raise ValueError.
is_cuda_aware_mpi = bool(int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)))
2926
3027class Partition (Enum ):
3128 r"""Enum class
@@ -104,7 +101,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] =
104101 return sub_comm
105102
106103
107- class DistributedArray :
104+ class DistributedArray ( DistributedMixIn ) :
108105 r"""Distributed Numpy Arrays
109106
110107 Multidimensional NumPy-like distributed arrays.
@@ -477,44 +474,6 @@ def _check_mask(self, dist_array):
477474 if not np .array_equal (self .mask , dist_array .mask ):
478475 raise ValueError ("Mask of both the arrays must be same" )
479476
def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
    """Allreduce operation over the base communicator

    The transport is selected in this order:

    1. NCCL, when ``deps.nccl_enabled`` is set and ``self.base_comm_nccl``
       is truthy; the call is delegated to ``nccl_allreduce``.
    2. Buffered ``Allreduce`` on ``self.base_comm``, when MPI is flagged as
       CUDA-aware or the engine is ``"numpy"``.
    3. Otherwise (CuPy with non-CUDA-aware MPI), the object-based
       ``allreduce`` when no ``recv_buf`` is given, or a buffered
       ``Allreduce`` into the supplied ``recv_buf``.

    Parameters
    ----------
    send_buf
        Local array to reduce. The buffered path reads its ``size`` and
        ``dtype`` attributes.
    recv_buf : optional
        Receive buffer. It is forwarded on the NCCL path and used on the
        non-CUDA-aware path; on the buffered path (case 2) it is ignored
        and replaced by a freshly allocated buffer.
    op : :obj:`mpi4py.MPI.Op`, optional
        Reduction operation. Defaults to ``MPI.SUM``.

    Returns
    -------
    recv_buf
        The reduced result. On the buffered path (case 2) this is a new
        1-D array holding ``send_buf.size`` elements.
    """
    if deps.nccl_enabled and self.base_comm_nccl:
        return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)

    if not is_cuda_aware_mpi and self.engine != "numpy":
        # CuPy with non-CUDA-aware MPI
        if recv_buf is None:
            return self.base_comm.allreduce(send_buf, op)
        # For MIN and MAX which require recv_buf
        self.base_comm.Allreduce(send_buf, recv_buf, op)
        return recv_buf

    # Buffered path: the output buffer is always allocated here, so any
    # recv_buf passed by the caller is not written to.
    xp = get_module(self.engine)
    reduced = xp.zeros(send_buf.size, dtype=send_buf.dtype)
    self.base_comm.Allreduce(send_buf, reduced, op)
    return reduced
498-
def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
    """Allreduce operation over the sub-communicator

    Every path reduces over ``self.sub_comm``. The transport is selected
    in this order:

    1. NCCL, when ``deps.nccl_enabled`` is set and ``self.base_comm_nccl``
       is truthy; ``self.sub_comm`` is handed to ``nccl_allreduce``.
    2. Buffered ``Allreduce``, when MPI is flagged as CUDA-aware or the
       engine is ``"numpy"``.
    3. Otherwise (CuPy with non-CUDA-aware MPI), the object-based
       ``allreduce`` when no ``recv_buf`` is given, or a buffered
       ``Allreduce`` into the supplied ``recv_buf``.

    Parameters
    ----------
    send_buf
        Local array to reduce. The buffered path reads its ``size`` and
        ``dtype`` attributes.
    recv_buf : optional
        Receive buffer. It is forwarded on the NCCL path and used on the
        non-CUDA-aware path; on the buffered path (case 2) it is ignored
        and replaced by a freshly allocated buffer.
    op : :obj:`mpi4py.MPI.Op`, optional
        Reduction operation. Defaults to ``MPI.SUM``.

    Returns
    -------
    recv_buf
        The reduced result. On the buffered path (case 2) this is a new
        1-D array holding ``send_buf.size`` elements.
    """
    use_nccl = deps.nccl_enabled and self.base_comm_nccl
    if use_nccl:
        return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)

    use_buffered = is_cuda_aware_mpi or self.engine == "numpy"
    if use_buffered:
        # The output buffer is always allocated here, so any recv_buf
        # passed by the caller is not written to.
        array_module = get_module(self.engine)
        output = array_module.zeros(send_buf.size, dtype=send_buf.dtype)
        self.sub_comm.Allreduce(send_buf, output, op)
        return output

    # CuPy with non-CUDA-aware MPI
    if recv_buf is None:
        return self.sub_comm.allreduce(send_buf, op)
    # For MIN and MAX which require recv_buf
    self.sub_comm.Allreduce(send_buf, recv_buf, op)
    return recv_buf
517-
518477 def _allgather (self , send_buf , recv_buf = None ):
519478 """Allgather operation
520479 """
@@ -556,16 +515,9 @@ def _send(self, send_buf, dest, count=None, tag=0):
556515 count = send_buf .size
557516 nccl_send (self .base_comm_nccl , send_buf , dest , count )
558517 else :
559- if is_cuda_aware_mpi or self .engine == "numpy" :
560- # Determine MPI type based on array dtype
561- mpi_type = MPI ._typedict [send_buf .dtype .char ]
562- if count is None :
563- count = send_buf .size
564- self .base_comm .Send ([send_buf , count , mpi_type ], dest = dest , tag = tag )
565- else :
566- # Uses CuPy without CUDA-aware MPI
567- self .base_comm .send (send_buf , dest , tag )
568-
518+ mpi_send (self .base_comm ,
519+ send_buf , dest , count , tag = tag ,
520+ engine = self .engine )
569521
570522 def _recv (self , recv_buf = None , source = 0 , count = None , tag = 0 ):
571523 """Receive operation
@@ -579,7 +531,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0):
579531 return recv_buf
580532 else :
581533 # NumPy + MPI will benefit from buffered communication regardless of MPI installation
582- if is_cuda_aware_mpi or self .engine == "numpy" :
534+ if deps . cuda_aware_mpi_enabled or self .engine == "numpy" :
583535 ncp = get_module (self .engine )
584536 if recv_buf is None :
585537 if count is None :
@@ -734,7 +686,7 @@ def _compute_vector_norm(self, local_array: NDArray,
734686 # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
735687 # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
736688 send_buf = ncp .max (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 )
737- if self .engine == "cupy" and self .base_comm_nccl is None and not is_cuda_aware_mpi :
689+ if self .engine == "cupy" and self .base_comm_nccl is None and not deps . cuda_aware_mpi_enabled :
738690 # CuPy + non-CUDA-aware MPI: This will call non-buffered communication
739691 # which return a list of object - must be copied back to a GPU memory.
740692 recv_buf = self ._allreduce_subcomm (send_buf .get (), recv_buf .get (), op = MPI .MAX )
@@ -750,7 +702,7 @@ def _compute_vector_norm(self, local_array: NDArray,
750702 # Calculate min followed by min reduction
751703 # See the comment above in +infinity norm
752704 send_buf = ncp .min (ncp .abs (local_array ), axis = axis ).astype (ncp .float64 )
753- if self .engine == "cupy" and self .base_comm_nccl is None and not is_cuda_aware_mpi :
705+ if self .engine == "cupy" and self .base_comm_nccl is None and not deps . cuda_aware_mpi_enabled :
754706 recv_buf = self ._allreduce_subcomm (send_buf .get (), recv_buf .get (), op = MPI .MIN )
755707 recv_buf = ncp .asarray (ncp .squeeze (recv_buf , axis = axis ))
756708 else :
0 commit comments