55from enum import Enum
66
77from pylops .utils import DTypeLike , NDArray
8+ from pylops .utils .backend import get_module , get_array_module , get_module_name
89
910
1011class Partition (Enum ):
@@ -78,6 +79,8 @@ class DistributedArray:
7879 Axis along which distribution occurs. Defaults to ``0``.
7980 local_shapes : :obj:`list`, optional
8081 List of tuples representing local shapes at each rank.
82+ engine : :obj:`str`, optional
83+ Engine used to store array (``numpy`` or ``cupy``)
8184 dtype : :obj:`str`, optional
8285 Type of elements in input array. Defaults to ``numpy.float64``.
8386 """
@@ -86,6 +89,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
8689 base_comm : Optional [MPI .Comm ] = MPI .COMM_WORLD ,
8790 partition : Partition = Partition .SCATTER , axis : int = 0 ,
8891 local_shapes : Optional [List [Tuple ]] = None ,
92+ engine : Optional [str ] = "numpy" ,
8993 dtype : Optional [DTypeLike ] = np .float64 ):
9094 if isinstance (global_shape , Integral ):
9195 global_shape = (global_shape ,)
@@ -103,7 +107,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
103107 self ._check_local_shapes (local_shapes )
104108 self ._local_shape = local_shapes [base_comm .rank ] if local_shapes else local_split (global_shape , base_comm ,
105109 partition , axis )
106- self ._local_array = np .empty (shape = self .local_shape , dtype = self .dtype )
110+ self ._engine = engine
111+ self ._local_array = get_module (engine ).empty (shape = self .local_shape , dtype = self .dtype )
107112
108113 def __getitem__ (self , index ):
109114 return self .local_array [index ]
@@ -160,6 +165,16 @@ def local_shape(self):
160165 """
161166 return self ._local_shape
162167
168+ @property
169+ def engine (self ):
170+ """Engine of the Distributed array
171+
172+ Returns
173+ -------
174+ engine : :obj:`str`
175+ """
176+ return self ._engine
177+
163178 @property
164179 def local_array (self ):
165180 """View of the Local Array
@@ -269,6 +284,7 @@ def to_dist(cls, x: NDArray,
269284 Axis of Distribution
270285 local_shapes : :obj:`list`, optional
271286 Local Shapes at each rank.
287+
272288 Returns
273289 ----------
274290 dist_array : :obj:`DistributedArray`
@@ -279,6 +295,7 @@ def to_dist(cls, x: NDArray,
279295 partition = partition ,
280296 axis = axis ,
281297 local_shapes = local_shapes ,
298+ engine = get_module_name (get_array_module (x )),
282299 dtype = x .dtype )
283300 if partition == Partition .BROADCAST :
284301 dist_array [:] = x
@@ -334,6 +351,7 @@ def __neg__(self):
334351 partition = self .partition ,
335352 axis = self .axis ,
336353 local_shapes = self .local_shapes ,
354+ engine = self .engine ,
337355 dtype = self .dtype )
338356 arr [:] = - self .local_array
339357 return arr
@@ -365,6 +383,7 @@ def add(self, dist_array):
365383 dtype = self .dtype ,
366384 partition = self .partition ,
367385 local_shapes = self .local_shapes ,
386+ engine = self .engine ,
368387 axis = self .axis )
369388 SumArray [:] = self .local_array + dist_array .local_array
370389 return SumArray
@@ -387,6 +406,7 @@ def multiply(self, dist_array):
387406 dtype = self .dtype ,
388407 partition = self .partition ,
389408 local_shapes = self .local_shapes ,
409+ engine = self .engine ,
390410 axis = self .axis )
391411 if isinstance (dist_array , DistributedArray ):
392412 # multiply two DistributedArray
@@ -480,6 +500,7 @@ def conj(self):
480500 partition = self .partition ,
481501 axis = self .axis ,
482502 local_shapes = self .local_shapes ,
503+ engine = self .engine ,
483504 dtype = self .dtype )
484505 conj [:] = self .local_array .conj ()
485506 return conj
@@ -492,6 +513,7 @@ def copy(self):
492513 partition = self .partition ,
493514 axis = self .axis ,
494515 local_shapes = self .local_shapes ,
516+ engine = self .engine ,
495517 dtype = self .dtype )
496518 arr [:] = self .local_array
497519 return arr
@@ -514,6 +536,7 @@ def ravel(self, order: Optional[str] = "C"):
514536 arr = DistributedArray (global_shape = np .prod (self .global_shape ),
515537 local_shapes = local_shapes ,
516538 partition = self .partition ,
539+ engine = self .engine ,
517540 dtype = self .dtype )
518541 local_array = np .ravel (self .local_array , order = order )
519542 x = local_array .copy ()
0 commit comments