55from enum import Enum
66
77from pylops .utils import DTypeLike , NDArray
8+ from pylops .utils .backend import get_module
89
910
1011class Partition (Enum ):
@@ -78,6 +79,8 @@ class DistributedArray:
7879 Axis along which distribution occurs. Defaults to ``0``.
7980 local_shapes : :obj:`list`, optional
8081 List of tuples representing local shapes at each rank.
82+ engine : :obj:`str`, optional
83+ Engine used to store array (``numpy`` or ``cupy``)
8184 dtype : :obj:`str`, optional
8285 Type of elements in input array. Defaults to ``numpy.float64``.
8386 """
@@ -86,6 +89,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
8689 base_comm : Optional [MPI .Comm ] = MPI .COMM_WORLD ,
8790 partition : Partition = Partition .SCATTER , axis : int = 0 ,
8891 local_shapes : Optional [List [Tuple ]] = None ,
92+ engine : Optional [str ] = "numpy" ,
8993 dtype : Optional [DTypeLike ] = np .float64 ):
9094 if isinstance (global_shape , Integral ):
9195 global_shape = (global_shape ,)
@@ -103,7 +107,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
103107 self ._check_local_shapes (local_shapes )
104108 self ._local_shape = local_shapes [base_comm .rank ] if local_shapes else local_split (global_shape , base_comm ,
105109 partition , axis )
106- self ._local_array = np .empty (shape = self .local_shape , dtype = self .dtype )
110+ self ._engine = engine
111+ self ._local_array = get_module (engine ).empty (shape = self .local_shape , dtype = self .dtype )
107112
108113 def __getitem__ (self , index ):
109114 return self .local_array [index ]
@@ -160,6 +165,16 @@ def local_shape(self):
160165 """
161166 return self ._local_shape
162167
168+ @property
169+ def engine (self ):
170+ """Engine of the Distributed array
171+
172+ Returns
173+ -------
174+ engine : :obj:`str`
175+ """
176+ return self ._engine
177+
163178 @property
164179 def local_array (self ):
165180 """View of the Local Array
@@ -334,6 +349,7 @@ def __neg__(self):
334349 partition = self .partition ,
335350 axis = self .axis ,
336351 local_shapes = self .local_shapes ,
352+ engine = self .engine ,
337353 dtype = self .dtype )
338354 arr [:] = - self .local_array
339355 return arr
@@ -365,6 +381,7 @@ def add(self, dist_array):
365381 dtype = self .dtype ,
366382 partition = self .partition ,
367383 local_shapes = self .local_shapes ,
384+ engine = self .engine ,
368385 axis = self .axis )
369386 SumArray [:] = self .local_array + dist_array .local_array
370387 return SumArray
@@ -387,6 +404,7 @@ def multiply(self, dist_array):
387404 dtype = self .dtype ,
388405 partition = self .partition ,
389406 local_shapes = self .local_shapes ,
407+ engine = self .engine ,
390408 axis = self .axis )
391409 if isinstance (dist_array , DistributedArray ):
392410 # multiply two DistributedArray
@@ -480,6 +498,7 @@ def conj(self):
480498 partition = self .partition ,
481499 axis = self .axis ,
482500 local_shapes = self .local_shapes ,
501+ engine = self .engine ,
483502 dtype = self .dtype )
484503 conj [:] = self .local_array .conj ()
485504 return conj
@@ -492,6 +511,7 @@ def copy(self):
492511 partition = self .partition ,
493512 axis = self .axis ,
494513 local_shapes = self .local_shapes ,
514+ engine = self .engine ,
495515 dtype = self .dtype )
496516 arr [:] = self .local_array
497517 return arr
0 commit comments