@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name
 
 
@@ -78,7 +79,10 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples or integers representing local shapes at each rank.
+    mask : :obj:`list`, optional
+        Mask defining subsets of ranks to consider when performing 'global'
+        operations on the distributed array such as dot product or norm.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
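To make the new ``mask`` argument concrete, here is a minimal usage sketch, not part of this diff. It assumes four MPI ranks, that ``DistributedArray`` and ``Partition`` are importable from ``pylops_mpi``, and an illustrative global shape; the mask must provide one group label per rank.

```python
# Hypothetical example, run e.g. with: mpiexec -n 4 python mask_example.py
# Ranks sharing the same mask value form a group; 'global' reductions such
# as dot or norm are then performed within each group rather than over
# MPI.COMM_WORLD.
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray, Partition

comm = MPI.COMM_WORLD
mask = [0, 0, 1, 1]  # assumed 4 ranks: ranks 0-1 in group 0, ranks 2-3 in group 1

arr = DistributedArray(global_shape=(40,), partition=Partition.SCATTER,
                       mask=mask, dtype=np.float64)
arr[:] = (comm.Get_rank() + 1) * np.ones(arr.local_shape)
```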
@@ -88,7 +92,8 @@ class DistributedArray:
     def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
+                 mask: Optional[List[Integral]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +105,14 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+        self._mask = mask
+        self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank)
+
+        local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
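The ``base_comm.Split(color=..., key=...)`` call above is standard mpi4py: ranks passing the same color end up in the same sub-communicator, and ``key`` orders them within it. A standalone sketch of that behaviour, with an illustrative mask and rank count:

```python
# Illustration of MPI.Comm.Split as used above (assumes 4 ranks).
from mpi4py import MPI

comm = MPI.COMM_WORLD
mask = [0, 0, 1, 1]              # color per rank
color = mask[comm.Get_rank()]
sub_comm = comm.Split(color=color, key=comm.Get_rank())

# Each sub-communicator contains only the ranks with the same color.
print(f"world rank {comm.Get_rank()} -> group {color}, "
      f"sub rank {sub_comm.Get_rank()} of {sub_comm.Get_size()}")
```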
@@ -165,6 +174,16 @@ def local_shape(self):
         """
         return self._local_shape
 
+    @property
+    def mask(self):
+        """Mask of the Distributed array
+
+        Returns
+        -------
+        mask : :obj:`list`
+        """
+        return self._mask
+
     @property
     def engine(self):
         """Engine of the Distributed array
@@ -246,6 +265,16 @@ def local_shapes(self):
         """
         return self.base_comm.allgather(self.local_shape)
 
+    @property
+    def sub_comm(self):
+        """MPI Sub-Communicator
+
+        Returns
+        -------
+        sub_comm : :obj:`MPI.Comm`
+        """
+        return self._sub_comm
+
     def asarray(self):
         """Global view of the array
 
@@ -269,7 +298,8 @@ def to_dist(cls, x: NDArray,
                 base_comm: MPI.Comm = MPI.COMM_WORLD,
                 partition: Partition = Partition.SCATTER,
                 axis: int = 0,
-                local_shapes: Optional[List[Tuple]] = None):
+                local_shapes: Optional[List[Tuple]] = None,
+                mask: Optional[List[Integral]] = None):
         """Convert A Global Array to a Distributed Array
 
         Parameters
@@ -284,6 +314,9 @@ def to_dist(cls, x: NDArray,
             Axis of Distribution
         local_shapes : :obj:`list`, optional
             Local Shapes at each rank.
+        mask : :obj:`list`, optional
+            Mask defining subsets of ranks to consider when performing 'global'
+            operations on the distributed array such as dot product or norm.
 
         Returns
         ----------
@@ -295,6 +328,7 @@ def to_dist(cls, x: NDArray,
                                       partition=partition,
                                       axis=axis,
                                       local_shapes=local_shapes,
+                                      mask=mask,
                                       engine=get_module_name(get_array_module(x)),
                                       dtype=x.dtype)
         if partition == Partition.BROADCAST:
@@ -336,6 +370,12 @@ def _check_partition_shape(self, dist_array):
             raise ValueError(f"Local Array Shape Mismatch - "
                              f"{self.local_shape} != {dist_array.local_shape}")
 
+    def _check_mask(self, dist_array):
+        """Check that masks of both arrays match
+        """
+        if not np.array_equal(self.mask, dist_array.mask):
+            raise ValueError("Masks of both the arrays must be the same")
+
     def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """MPI Allreduce operation
         """
@@ -345,12 +385,22 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         self.base_comm.Allreduce(send_buf, recv_buf, op)
         return recv_buf
 
+    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+        """MPI Allreduce operation with subcommunicator
+        """
+        if recv_buf is None:
+            return self.sub_comm.allreduce(send_buf, op)
+        # For MIN and MAX which require recv_buf
+        self.sub_comm.Allreduce(send_buf, recv_buf, op)
+        return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               mask=self.mask,
                                engine=self.engine,
                                dtype=self.dtype)
         arr[:] = -self.local_array
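A hedged, mpi4py-only sketch of the difference between ``_allreduce`` and the new ``_allreduce_subcomm`` (mask and rank count are illustrative): reductions on the sub-communicator combine contributions only from ranks in the same mask group.

```python
# Assumes 4 ranks split into two groups, as in the earlier examples.
from mpi4py import MPI

comm = MPI.COMM_WORLD
mask = [0, 0, 1, 1]
sub_comm = comm.Split(color=mask[comm.Get_rank()], key=comm.Get_rank())

local_value = float(comm.Get_rank() + 1)             # 1, 2, 3, 4
world_sum = comm.allreduce(local_value, op=MPI.SUM)  # 10.0 on every rank
group_sum = sub_comm.allreduce(local_value, op=MPI.SUM)
# group_sum == 3.0 on ranks 0-1 and 7.0 on ranks 2-3
```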
@@ -378,11 +428,13 @@ def add(self, dist_array):
         """Distributed Addition of arrays
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
+                                    mask=self.mask,
                                     engine=self.engine,
                                     axis=self.axis)
         SumArray[:] = self.local_array + dist_array.local_array
@@ -392,6 +444,7 @@ def iadd(self, dist_array):
         """Distributed In-place Addition of arrays
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
         self[:] = self.local_array + dist_array.local_array
         return self
 
@@ -400,12 +453,14 @@ def multiply(self, dist_array):
         """
         if isinstance(dist_array, DistributedArray):
             self._check_partition_shape(dist_array)
+            self._check_mask(dist_array)
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
+                                        mask=self.mask,
                                         engine=self.engine,
                                         axis=self.axis)
         if isinstance(dist_array, DistributedArray):
@@ -420,13 +475,15 @@ def dot(self, dist_array):
         """Distributed Dot Product
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
+
         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
             if self.partition is Partition.BROADCAST else self
         y = DistributedArray.to_dist(x=dist_array.local_array) \
             if self.partition is Partition.BROADCAST else dist_array
         # Flatten the local arrays and calculate dot product
-        return self._allreduce(np.dot(x.local_array.flatten(), y.local_array.flatten()))
+        return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten()))
 
     def _compute_vector_norm(self, local_array: NDArray,
                              axis: int, ord: Optional[int] = None):
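With the mask check and the sub-communicator reduction in place, ``dot`` now returns a per-group result. A hedged sketch, with illustrative shapes and four assumed ranks:

```python
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD
mask = [0, 0, 1, 1]  # assumed 4 ranks, two groups of two

# 40 global samples scattered as 10 per rank.
x = DistributedArray(global_shape=(40,), mask=mask, dtype=np.float64)
y = DistributedArray(global_shape=(40,), mask=mask, dtype=np.float64)
x[:] = np.ones(x.local_shape)
y[:] = 2.0 * np.ones(y.local_shape)

# Each rank contributes 10 * 1 * 2 = 20; summing over the 2 ranks of a group
# gives 40.0, i.e. the dot product of that group's 20-sample portion.
print(comm.Get_rank(), x.dot(y))
```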
@@ -453,20 +510,20 @@ def _compute_vector_norm(self, local_array: NDArray,
             raise ValueError(f"norm-{ord} not possible for vectors")
         elif ord == 0:
             # Count non-zero then sum reduction
-            recv_buf = self._allreduce(np.count_nonzero(local_array, axis=axis).astype(np.float64))
+            recv_buf = self._allreduce_subcomm(np.count_nonzero(local_array, axis=axis).astype(np.float64))
         elif ord == np.inf:
             # Calculate max followed by max reduction
-            recv_buf = self._allreduce(np.max(np.abs(local_array), axis=axis).astype(np.float64),
-                                       recv_buf, op=MPI.MAX)
+            recv_buf = self._allreduce_subcomm(np.max(np.abs(local_array), axis=axis).astype(np.float64),
+                                               recv_buf, op=MPI.MAX)
             recv_buf = np.squeeze(recv_buf, axis=axis)
         elif ord == -np.inf:
             # Calculate min followed by min reduction
-            recv_buf = self._allreduce(np.min(np.abs(local_array), axis=axis).astype(np.float64),
-                                       recv_buf, op=MPI.MIN)
+            recv_buf = self._allreduce_subcomm(np.min(np.abs(local_array), axis=axis).astype(np.float64),
+                                               recv_buf, op=MPI.MIN)
             recv_buf = np.squeeze(recv_buf, axis=axis)
 
         else:
-            recv_buf = self._allreduce(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis))
+            recv_buf = self._allreduce_subcomm(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis))
             recv_buf = np.power(recv_buf, 1. / ord)
         return recv_buf
 
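For finite, nonzero ``ord`` the pattern above reduces to: sum ``|x|**ord`` locally, allreduce the partial sums over the sub-communicator, then take the ``1/ord`` power. A numpy/mpi4py-only sketch of that reduction, with illustrative values and mask:

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
mask = [0, 0, 1, 1]  # assumed 4 ranks, two groups
sub_comm = comm.Split(color=mask[comm.Get_rank()], key=comm.Get_rank())

ord = 2
x_local = np.full(10, comm.Get_rank() + 1.0)             # this rank's chunk
partial = np.sum(np.abs(np.float_power(x_local, ord)))   # local |x|**ord sum
group_norm = np.power(sub_comm.allreduce(partial, op=MPI.SUM), 1.0 / ord)
# group_norm is the ord-norm of the chunks owned by this rank's group only
```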
@@ -500,6 +557,7 @@ def conj(self):
                                 partition=self.partition,
                                 axis=self.axis,
                                 local_shapes=self.local_shapes,
+                                mask=self.mask,
                                 engine=self.engine,
                                 dtype=self.dtype)
         conj[:] = self.local_array.conj()
@@ -513,6 +571,7 @@ def copy(self):
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               mask=self.mask,
                                engine=self.engine,
                                dtype=self.dtype)
         arr[:] = self.local_array
@@ -535,6 +594,7 @@ def ravel(self, order: Optional[str] = "C"):
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
                                local_shapes=local_shapes,
+                               mask=self.mask,
                                partition=self.partition,
                                engine=self.engine,
                                dtype=self.dtype)