Commit c8098e8

feature: enabled use of cupy arrays

1 parent a82d8b1
File tree: 5 files changed, +46 -15 lines

pylops_mpi/DistributedArray.py

Lines changed: 21 additions & 1 deletion

@@ -5,6 +5,7 @@
 from enum import Enum

 from pylops.utils import DTypeLike, NDArray
+from pylops.utils.backend import get_module


 class Partition(Enum):
@@ -78,6 +79,8 @@ class DistributedArray:
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
         List of tuples representing local shapes at each rank.
+    engine : :obj:`str`, optional
+        Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
         Type of elements in input array. Defaults to ``numpy.float64``.
     """
@@ -86,6 +89,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
                  local_shapes: Optional[List[Tuple]] = None,
+                 engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
             global_shape = (global_shape,)
@@ -103,7 +107,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                          partition, axis)
-        self._local_array = np.empty(shape=self.local_shape, dtype=self.dtype)
+        self._engine = engine
+        self._local_array = get_module(engine).empty(shape=self.local_shape, dtype=self.dtype)

     def __getitem__(self, index):
         return self.local_array[index]
@@ -160,6 +165,16 @@ def local_shape(self):
         """
         return self._local_shape

+    @property
+    def engine(self):
+        """Engine of the Distributed array
+
+        Returns
+        -------
+        engine : :obj:`str`
+        """
+        return self._engine
+
     @property
     def local_array(self):
         """View of the Local Array
@@ -334,6 +349,7 @@ def __neg__(self):
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               engine=self.engine,
                                dtype=self.dtype)
         arr[:] = -self.local_array
         return arr
@@ -365,6 +381,7 @@ def add(self, dist_array):
                                  dtype=self.dtype,
                                  partition=self.partition,
                                  local_shapes=self.local_shapes,
+                                 engine=self.engine,
                                  axis=self.axis)
         SumArray[:] = self.local_array + dist_array.local_array
         return SumArray
@@ -387,6 +404,7 @@ def multiply(self, dist_array):
                                      dtype=self.dtype,
                                      partition=self.partition,
                                      local_shapes=self.local_shapes,
+                                     engine=self.engine,
                                      axis=self.axis)
         if isinstance(dist_array, DistributedArray):
             # multiply two DistributedArray
@@ -480,6 +498,7 @@ def conj(self):
                              partition=self.partition,
                              axis=self.axis,
                              local_shapes=self.local_shapes,
+                             engine=self.engine,
                              dtype=self.dtype)
         conj[:] = self.local_array.conj()
         return conj
@@ -492,6 +511,7 @@ def copy(self):
                               partition=self.partition,
                               axis=self.axis,
                               local_shapes=self.local_shapes,
+                              engine=self.engine,
                               dtype=self.dtype)
         arr[:] = self.local_array
         return arr
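
For context, a minimal sketch of how the new engine argument is meant to be used (illustrative; assumes CuPy is installed and the script is launched with an MPI runner such as mpiexec):

    import numpy as np
    from pylops_mpi import DistributedArray

    # default engine: each rank's shard lives in a numpy buffer
    x = DistributedArray(global_shape=100)
    x[:] = np.ones(x.local_shape)
    print(x.engine)   # "numpy"

    # engine="cupy": the shard is allocated with cupy.empty on the GPU
    y = DistributedArray(global_shape=100, engine="cupy")
    print(y.engine)   # "cupy"; y.local_array is a cupy.ndarray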

pylops_mpi/basicoperators/BlockDiag.py

Lines changed: 10 additions & 5 deletions

@@ -3,8 +3,9 @@
 from mpi4py import MPI
 from typing import Optional, Sequence

-from pylops.utils import DTypeLike
 from pylops import LinearOperator
+from pylops.utils import DTypeLike
+from pylops.utils.backend import get_module

 from pylops_mpi import MPILinearOperator, MPIStackedLinearOperator
 from pylops_mpi import DistributedArray, StackedDistributedArray
@@ -113,22 +114,26 @@ def __init__(self, ops: Sequence[LinearOperator],

     @reshaped(forward=True, stacking=True)
     def _matvec(self, x: DistributedArray) -> DistributedArray:
-        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, dtype=self.dtype)
+        ncp = get_module(x.engine)
+        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
+                             engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.matvec(x.local_array[self.mmops[iop]:
                                                 self.mmops[iop + 1]]))
-        y[:] = np.concatenate(y1)
+        y[:] = ncp.concatenate(ncp.asarray(y1))
         return y

     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
-        y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m, dtype=self.dtype)
+        ncp = get_module(x.engine)
+        y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m,
+                             engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.rmatvec(x.local_array[self.nnops[iop]:
                                                  self.nnops[iop + 1]]))
-        y[:] = np.concatenate(y1)
+        y[:] = ncp.concatenate(ncp.asarray(y1))
         return y
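
The ncp alias above follows the standard NumPy/CuPy dispatch idiom: get_module returns the numpy module for engine "numpy" and the cupy module for "cupy", so the same concatenation code serves both backends. A minimal sketch (the values are illustrative; the cupy branch assumes CuPy is installed):

    from pylops.utils.backend import get_module

    ncp = get_module("numpy")             # swap in "cupy" for GPU arrays
    y1 = [ncp.ones(3), ncp.zeros(3)]      # stand-ins for per-operator outputs
    y = ncp.concatenate(ncp.asarray(y1))  # identical call on either backend
    print(y)                              # [1. 1. 1. 0. 0. 0.]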

pylops_mpi/basicoperators/VStack.py

Lines changed: 9 additions & 4 deletions

@@ -5,6 +5,7 @@

 from pylops import LinearOperator
 from pylops.utils import DTypeLike
+from pylops.utils.backend import get_module

 from pylops_mpi import (
     MPILinearOperator,
@@ -116,22 +117,26 @@ def __init__(self, ops: Sequence[LinearOperator],
         super().__init__(shape=shape, dtype=dtype, base_comm=base_comm)

     def _matvec(self, x: DistributedArray) -> DistributedArray:
+        ncp = get_module(x.engine)
         if x.partition is not Partition.BROADCAST:
             raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
-        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, dtype=self.dtype)
+        y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
+                             engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.matvec(x.local_array))
-        y[:] = np.concatenate(y1)
+        y[:] = ncp.concatenate(y1)
         return y

     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
-        y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, dtype=self.dtype)
+        ncp = get_module(x.engine)
+        y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
+                             engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
-        y1 = np.sum(y1, axis=0)
+        y1 = ncp.sum(ncp.asarray(y1), axis=0)
         y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
         return y
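
A sketch of how an engine-aware input flows through this operator (run with an MPI launcher, e.g. mpiexec -n 2; the single MatrixMult block and its sizes are illustrative assumptions):

    import numpy as np
    from pylops import MatrixMult
    from pylops_mpi import DistributedArray, Partition
    from pylops_mpi.basicoperators import MPIVStack

    # each rank contributes one block of the stack
    VStack = MPIVStack(ops=[MatrixMult(np.ones((4, 3)))])

    # the forward pass requires a broadcast (replicated) model vector
    x = DistributedArray(global_shape=3, partition=Partition.BROADCAST)
    x[:] = np.arange(3, dtype=np.float64)
    y = VStack.matvec(x)   # y inherits x.engine ("numpy" here)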

pylops_mpi/optimization/cls_basic.py

Lines changed: 4 additions & 4 deletions

@@ -337,7 +337,7 @@ def setup(self,
         self.rank = x.rank
         self.c = r.copy()
         self.q = self.Op.matvec(self.c)
-        self.kold = np.abs(r.dot(r.conj()))
+        self.kold = float(np.abs(r.dot(r.conj())))

         # create variables to track the residual norm and iterations
         self.cost = []
@@ -373,13 +373,13 @@ def step(self, x: Union[DistributedArray, StackedDistributedArray],

         """

-        a = self.kold / (self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj()))
+        a = float(self.kold / (self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj())))
         x += a * self.c
         self.s -= a * self.q
         damped_x = self.damp * x
         r = self.Op.rmatvec(self.s) - damped_x
-        k = np.abs(r.dot(r.conj()))
-        b = k / self.kold
+        k = float(np.abs(r.dot(r.conj())))
+        b = float(k / self.kold)
         self.c = r + b * self.c
         self.q = self.Op.matvec(self.c)
         self.kold = k
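
These float(...) casts are what keep the scalar bookkeeping engine-agnostic: with CuPy, a reduction such as dot returns a 0-d array on the device rather than a Python number, and float() forces a device-to-host transfer before the value feeds back into the step-length arithmetic. A minimal sketch of the difference (the commented lines assume CuPy is installed):

    import numpy as np

    r = np.ones(5)
    k = np.abs(r.dot(r.conj()))   # numpy.float64, already a host scalar
    print(float(k))               # 5.0

    # import cupy as cp
    # r = cp.ones(5)
    # k = cp.abs(r.dot(r.conj())) # 0-d cupy.ndarray living on the GPU
    # float(k)                    # explicit device-to-host copy -> 5.0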

pylops_mpi/utils/decorators.py

Lines changed: 2 additions & 1 deletion

@@ -54,7 +54,8 @@ def wrapper(self, x: DistributedArray):
             local_shapes = None
             global_shape = getattr(self, "dims")
         arr = DistributedArray(global_shape=global_shape,
-                               local_shapes=local_shapes, axis=0, dtype=x.dtype)
+                               local_shapes=local_shapes, axis=0,
+                               engine=x.engine, dtype=x.dtype)
         arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape)))
         x_local_shapes = np.asarray(x.base_comm.allgather(np.prod(x.local_shape)))
         # Calculate num_ghost_cells required for each rank
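
The same propagation rule as in the operators applies here: the temporary array built inside the wrapper now mirrors the engine of its input instead of silently defaulting to numpy. A hypothetical, heavily condensed version of the pattern (engine_preserving_reshape and the plain reshape are illustrative; the real wrapper also handles ghost cells):

    from pylops_mpi import DistributedArray

    def engine_preserving_reshape(op, x):
        # any intermediate DistributedArray created on behalf of x must
        # inherit x.engine, or a cupy-backed input would be copied into
        # a numpy buffer on its way through the operator
        arr = DistributedArray(global_shape=getattr(op, "dims"),
                               axis=0, engine=x.engine, dtype=x.dtype)
        arr[:] = x.local_array.reshape(arr.local_shape)
        return arr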
