
Commit cc699df
feat: added mask to DistributedArray
1 parent 6d3b1e8

File tree: 3 files changed, +86 -18 lines

pylops_mpi/DistributedArray.py

Lines changed: 71 additions & 11 deletions
@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name
 
 
@@ -78,7 +79,10 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples or integers representing local shapes at each rank.
+    mask : :obj:`list`, optional
+        Mask defining subsets of ranks to consider when performing 'global'
+        operations on the distributed array such as dot product or norm.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -88,7 +92,8 @@ class DistributedArray:
     def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
+                 mask: Optional[List[Integral]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +105,14 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+        self._mask = mask
+        self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank)
+
+        local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
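The `mask` argument is a list with one integer per rank: ranks sharing the same value are grouped into the same MPI sub-communicator via `MPI.Comm.Split`, as in the constructor line above. A minimal standalone sketch of this mechanism (demo values, not part of the commit):

```python
# Demo of the mask -> sub-communicator mapping used above (assumed values).
# Run with e.g.: mpiexec -n 4 python mask_split_demo.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.rank

# One integer per rank: ranks sharing the same value form one sub-communicator
mask = [r % 2 for r in range(comm.size)]

# color groups the ranks, key keeps the original rank ordering within a group
sub_comm = comm.Split(color=mask[rank], key=rank)
print(f"rank {rank}: sub-communicator size {sub_comm.size}, color {mask[rank]}")
```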
@@ -165,6 +174,16 @@ def local_shape(self):
         """
         return self._local_shape
 
+    @property
+    def mask(self):
+        """Mask of the Distributed array
+
+        Returns
+        -------
+        engine : :obj:`list`
+        """
+        return self._mask
+
     @property
     def engine(self):
         """Engine of the Distributed array
@@ -246,6 +265,16 @@ def local_shapes(self):
         """
         return self.base_comm.allgather(self.local_shape)
 
+    @property
+    def sub_comm(self):
+        """MPI Sub-Communicator
+
+        Returns
+        -------
+        sub_comm : :obj:`MPI.Comm`
+        """
+        return self._sub_comm
+
     def asarray(self):
         """Global view of the array
 
@@ -269,7 +298,8 @@ def to_dist(cls, x: NDArray,
                 base_comm: MPI.Comm = MPI.COMM_WORLD,
                 partition: Partition = Partition.SCATTER,
                 axis: int = 0,
-                local_shapes: Optional[List[Tuple]] = None):
+                local_shapes: Optional[List[Tuple]] = None,
+                mask: Optional[List[Integral]] = None):
         """Convert A Global Array to a Distributed Array
 
         Parameters
@@ -284,6 +314,9 @@ def to_dist(cls, x: NDArray,
             Axis of Distribution
         local_shapes : :obj:`list`, optional
             Local Shapes at each rank.
+        mask : :obj:`list`, optional
+            Mask defining subsets of ranks to consider when performing 'global'
+            operations on the distributed array such as dot product or norm.
 
         Returns
         ----------
@@ -295,6 +328,7 @@ def to_dist(cls, x: NDArray,
                                     partition=partition,
                                     axis=axis,
                                     local_shapes=local_shapes,
+                                    mask=mask,
                                     engine=get_module_name(get_array_module(x)),
                                     dtype=x.dtype)
         if partition == Partition.BROADCAST:
@@ -336,6 +370,12 @@ def _check_partition_shape(self, dist_array):
             raise ValueError(f"Local Array Shape Mismatch - "
                              f"{self.local_shape} != {dist_array.local_shape}")
 
+    def _check_mask(self, dist_array):
+        """Check mask of the Array
+        """
+        if not np.array_equal(self.mask, dist_array.mask):
+            raise ValueError("Mask of both the arrays must be same")
+
     def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """MPI Allreduce operation
         """
@@ -345,12 +385,22 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         self.base_comm.Allreduce(send_buf, recv_buf, op)
         return recv_buf
 
+    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+        """MPI Allreduce operation with subcommunicator
+        """
+        if recv_buf is None:
+            return self.sub_comm.allreduce(send_buf, op)
+        # For MIN and MAX which require recv_buf
+        self.sub_comm.Allreduce(send_buf, recv_buf, op)
+        return recv_buf
+
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               mask=self.mask,
                                engine=self.engine,
                                dtype=self.dtype)
         arr[:] = -self.local_array
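`_allreduce_subcomm` mirrors `_allreduce` but reduces over `sub_comm`: the lowercase mpi4py `allreduce` handles generic Python objects and returns the result, while the uppercase `Allreduce` writes into a preallocated buffer, which is why the MIN/MAX reductions pass a `recv_buf`. A standalone sketch of the two styles (demo values, not from the commit):

```python
# Sketch (assumed demo) of the two mpi4py reduction styles used by
# _allreduce_subcomm: lowercase allreduce returns the reduced Python object,
# uppercase Allreduce fills a preallocated NumPy buffer (needed for MIN/MAX).
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
sub_comm = comm.Split(color=comm.rank % 2, key=comm.rank)

# Object-based reduction: the result is returned (path taken when recv_buf is None)
total = sub_comm.allreduce(float(comm.rank), op=MPI.SUM)

# Buffer-based reduction: the result is written into recv_buf
send_buf = np.array([float(comm.rank)])
recv_buf = np.empty_like(send_buf)
sub_comm.Allreduce(send_buf, recv_buf, op=MPI.MAX)
print(f"rank {comm.rank}: subset sum={total}, subset max={recv_buf[0]}")
```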
@@ -378,11 +428,13 @@ def add(self, dist_array):
         """Distributed Addition of arrays
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
         SumArray = DistributedArray(global_shape=self.global_shape,
                                     base_comm=self.base_comm,
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
+                                    mask=self.mask,
                                     engine=self.engine,
                                     axis=self.axis)
         SumArray[:] = self.local_array + dist_array.local_array
@@ -392,6 +444,7 @@ def iadd(self, dist_array):
         """Distributed In-place Addition of arrays
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
         self[:] = self.local_array + dist_array.local_array
         return self
 
@@ -400,12 +453,14 @@ def multiply(self, dist_array):
         """
         if isinstance(dist_array, DistributedArray):
             self._check_partition_shape(dist_array)
+            self._check_mask(dist_array)
 
         ProductArray = DistributedArray(global_shape=self.global_shape,
                                         base_comm=self.base_comm,
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
+                                        mask=self.mask,
                                         engine=self.engine,
                                         axis=self.axis)
         if isinstance(dist_array, DistributedArray):
@@ -420,13 +475,15 @@ def dot(self, dist_array):
         """Distributed Dot Product
         """
         self._check_partition_shape(dist_array)
+        self._check_mask(dist_array)
+
         # Convert to Partition.SCATTER if Partition.BROADCAST
         x = DistributedArray.to_dist(x=self.local_array) \
             if self.partition is Partition.BROADCAST else self
         y = DistributedArray.to_dist(x=dist_array.local_array) \
             if self.partition is Partition.BROADCAST else dist_array
         # Flatten the local arrays and calculate dot product
-        return self._allreduce(np.dot(x.local_array.flatten(), y.local_array.flatten()))
+        return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten()))
 
     def _compute_vector_norm(self, local_array: NDArray,
                              axis: int, ord: Optional[int] = None):
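With a mask set, the reduction in `dot` runs over the sub-communicator, so each subset of ranks obtains the dot product of its own portion of the arrays. A usage sketch based on the API added in this commit (array sizes and mask are illustrative):

```python
# Usage sketch of a masked dot product (sizes and mask are demo assumptions).
# Run with e.g.: mpiexec -n 4 python masked_dot_demo.py
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD
mask = [r % 2 for r in range(comm.size)]  # two subsets of ranks

n_local = 5
x = DistributedArray(global_shape=n_local * comm.size, mask=mask)
y = DistributedArray(global_shape=n_local * comm.size, mask=mask)
x[:] = np.arange(n_local, dtype=np.float64)  # fill the local portion
y[:] = np.ones(n_local, dtype=np.float64)

# Reduced over the sub-communicator of each rank subset, not over COMM_WORLD
print(f"rank {comm.rank}: dot = {x.dot(y)}")
```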
@@ -453,20 +510,20 @@ def _compute_vector_norm(self, local_array: NDArray,
             raise ValueError(f"norm-{ord} not possible for vectors")
         elif ord == 0:
             # Count non-zero then sum reduction
-            recv_buf = self._allreduce(np.count_nonzero(local_array, axis=axis).astype(np.float64))
+            recv_buf = self._allreduce_subcomm(np.count_nonzero(local_array, axis=axis).astype(np.float64))
         elif ord == np.inf:
             # Calculate max followed by max reduction
-            recv_buf = self._allreduce(np.max(np.abs(local_array), axis=axis).astype(np.float64),
-                                       recv_buf, op=MPI.MAX)
+            recv_buf = self._allreduce_subcomm(np.max(np.abs(local_array), axis=axis).astype(np.float64),
+                                               recv_buf, op=MPI.MAX)
             recv_buf = np.squeeze(recv_buf, axis=axis)
         elif ord == -np.inf:
             # Calculate min followed by min reduction
-            recv_buf = self._allreduce(np.min(np.abs(local_array), axis=axis).astype(np.float64),
-                                       recv_buf, op=MPI.MIN)
+            recv_buf = self._allreduce_subcomm(np.min(np.abs(local_array), axis=axis).astype(np.float64),
+                                               recv_buf, op=MPI.MIN)
             recv_buf = np.squeeze(recv_buf, axis=axis)
 
         else:
-            recv_buf = self._allreduce(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis))
+            recv_buf = self._allreduce_subcomm(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis))
             recv_buf = np.power(recv_buf, 1. / ord)
         return recv_buf
 
@@ -500,6 +557,7 @@ def conj(self):
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               mask=self.mask,
                                engine=self.engine,
                                dtype=self.dtype)
         conj[:] = self.local_array.conj()
@@ -513,6 +571,7 @@ def copy(self):
                               partition=self.partition,
                               axis=self.axis,
                               local_shapes=self.local_shapes,
+                              mask=self.mask,
                               engine=self.engine,
                               dtype=self.dtype)
         arr[:] = self.local_array
@@ -535,6 +594,7 @@ def ravel(self, order: Optional[str] = "C"):
         local_shapes = [(np.prod(local_shape, axis=-1), ) for local_shape in self.local_shapes]
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
                                local_shapes=local_shapes,
+                               mask=self.mask,
                                partition=self.partition,
                                engine=self.engine,
                                dtype=self.dtype)

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 9 additions & 3 deletions
@@ -1,7 +1,8 @@
 import numpy as np
 from scipy.sparse.linalg._interface import _get_dtype
 from mpi4py import MPI
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union, List
+from numbers import Integral
 
 from pylops import LinearOperator
 from pylops.utils import DTypeLike
@@ -28,6 +29,9 @@ class MPIBlockDiag(MPILinearOperator):
         One or more :class:`pylops.LinearOperator` to be stacked.
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
         Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
+    mask : :obj:`list`, optional
+        Mask defining subsets of ranks to consider when performing 'global' operations on
+        the distributed array such as dot product or norm.
     dtype : :obj:`str`, optional
         Type of elements in input array.
 
@@ -95,8 +99,10 @@ class MPIBlockDiag(MPILinearOperator):
 
     def __init__(self, ops: Sequence[LinearOperator],
                  base_comm: MPI.Comm = MPI.COMM_WORLD,
+                 mask: Optional[List[Integral]] = None,
                  dtype: Optional[DTypeLike] = None):
         self.ops = ops
+        self.mask = mask
         mops = np.zeros(len(self.ops), dtype=np.int64)
         nops = np.zeros(len(self.ops), dtype=np.int64)
         for iop, oper in enumerate(self.ops):
@@ -116,7 +122,7 @@ def __init__(self, ops: Sequence[LinearOperator],
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
-                             engine=x.engine, dtype=self.dtype)
+                             mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.matvec(x.local_array[self.mmops[iop]:
@@ -128,7 +134,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
-                             engine=x.engine, dtype=self.dtype)
+                             mask=self.mask, engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.rmatvec(x.local_array[self.nnops[iop]:
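Since `_matvec` and `_rmatvec` now forward the operator's mask to their output `DistributedArray`s, subsequent reductions (e.g. dot products inside a solver) stay within each rank subset. A hedged construction sketch (operator sizes are illustrative):

```python
# Construction sketch for MPIBlockDiag with a mask (sizes are demo assumptions).
import numpy as np
from mpi4py import MPI
from pylops import MatrixMult
from pylops_mpi import DistributedArray, MPIBlockDiag

comm = MPI.COMM_WORLD
mask = [r % 2 for r in range(comm.size)]  # two rank subsets

n = 4
# One identity block per rank; the mask is forwarded to matvec/rmatvec outputs
Op = MPIBlockDiag(ops=[MatrixMult(np.eye(n))], mask=mask)

x = DistributedArray(global_shape=n * comm.size, mask=mask)
x[:] = np.ones(n)
y = Op.matvec(x)  # y carries the operator's mask
print(f"rank {comm.rank}: y.dot(y) over the rank subset = {y.dot(y)}")
```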

pylops_mpi/waveeqprocessing/MDC.py

Lines changed: 6 additions & 4 deletions
@@ -20,7 +20,7 @@ def _MDC(G, nt, nv, nfmax, dt=1., dr=1., twosided=True,
 
     Used to be able to provide operators from different libraries to
     MDC. It operates in the same way as public method
-    (PoststackLinearModelling) but has additional input parameters allowing
+    (MPIMDC) but has additional input parameters allowing
     passing a different operator and additional arguments to be passed to such
     operator.
 
@@ -81,8 +81,10 @@ def MPIMDC(G, nt, nv, nfreq, dt=1., dr=1., twosided=True,
            base_comm: MPI.Comm = MPI.COMM_WORLD):
     r"""Multi-dimensional convolution.
 
-    Apply multi-dimensional convolution between two datasets. Model and data
-    should be provided after flattening 2- or 3-dimensional arrays of size
+    Apply multi-dimensional convolution between two datasets in a distributed
+    fashion, with ``G`` distributed over ranks across the frequency axis.
+    Model and data are broadcasted and should be provided after flattening
+    2- or 3-dimensional arrays of size
     :math:`[n_t \times n_r (\times n_{vs})]` and
     :math:`[n_t \times n_s (\times n_{vs})]` (or :math:`2*n_t-1` for
     ``twosided=True``), respectively.
@@ -91,7 +93,7 @@ def MPIMDC(G, nt, nv, nfreq, dt=1., dr=1., twosided=True,
     ----------
     G : :obj:`numpy.ndarray`
         Multi-dimensional convolution kernel in frequency domain of size
-        :math:`[n_{fmax} \times n_s \times n_r]`
+        :math:`[n_{f,rank} \times n_s \times n_r]`
     nt : :obj:`int`
         Number of samples along time axis for model and data (note that this
         must be equal to ``2*n_t-1`` when working with ``twosided=True``.
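The revised docstring states that each rank holds only its slab of frequencies of the kernel, i.e. ``G`` has shape :math:`[n_{f,rank} \times n_s \times n_r]` locally. A hypothetical sketch of slicing a global kernel along the frequency axis before building the operator (all sizes are demo assumptions; in practice each rank would typically load only its own slab from disk):

```python
# Hypothetical split of a global frequency-domain kernel into per-rank slabs,
# matching the [n_{f,rank} x n_s x n_r] convention described above.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, nranks = comm.rank, comm.size

nfmax, ns, nr = 16, 8, 8  # assumed global sizes for the demo
Gglobal = np.random.rand(nfmax, ns, nr) + 1j * np.random.rand(nfmax, ns, nr)

# Contiguous split of the frequency axis across ranks
ifreqs = np.array_split(np.arange(nfmax), nranks)[rank]
Grank = Gglobal[ifreqs]  # local kernel of shape [n_{f,rank}, n_s, n_r]
print(f"rank {rank}: {len(ifreqs)} local frequencies, local G shape {Grank.shape}")
```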
