
Commit 1ae1568

feat: improved handling of shapes
1 parent 6d3b1e8 · commit 1ae1568


pylops_mpi/DistributedArray.py

Lines changed: 7 additions & 3 deletions
@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name
 
 
@@ -78,7 +79,7 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples of integers representing local shapes at each rank.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -88,7 +89,7 @@ class DistributedArray:
     def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +101,13 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+
+        if local_shapes is not None:
+            local_shapes = [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
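
For context, a minimal usage sketch of what this change enables (an illustration, not part of the commit): global_shape and the entries of local_shapes may now be plain integers, which are normalized to 1-element tuples via _value_or_sized_to_tuple. It assumes a 2-rank MPI run and that DistributedArray is importable from the pylops_mpi package.

# Hypothetical example; run with e.g. `mpiexec -n 2 python example.py`
from mpi4py import MPI
from pylops_mpi import DistributedArray

comm = MPI.COMM_WORLD

# Previously local_shapes had to be a list of tuples, e.g. [(6,), (4,)];
# integer entries are now coerced to tuples inside __init__.
x = DistributedArray(global_shape=10, base_comm=comm, local_shapes=[6, 4])

# Each rank holds its own chunk of the global array.
print(comm.rank, x.local_shape)  # e.g. rank 0 -> (6,), rank 1 -> (4,)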
