Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import Enum

from pylops.utils import DTypeLike, NDArray
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.backend import get_module, get_array_module, get_module_name


Expand Down Expand Up @@ -78,7 +79,7 @@ class DistributedArray:
axis : :obj:`int`, optional
Axis along which distribution occurs. Defaults to ``0``.
local_shapes : :obj:`list`, optional
List of tuples representing local shapes at each rank.
List of tuples or integers representing local shapes at each rank.
engine : :obj:`str`, optional
Engine used to store array (``numpy`` or ``cupy``)
dtype : :obj:`str`, optional
Expand All @@ -88,7 +89,7 @@ class DistributedArray:
def __init__(self, global_shape: Union[Tuple, Integral],
base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
partition: Partition = Partition.SCATTER, axis: int = 0,
local_shapes: Optional[List[Tuple]] = None,
local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
engine: Optional[str] = "numpy",
dtype: Optional[DTypeLike] = np.float64):
if isinstance(global_shape, Integral):
Expand All @@ -100,10 +101,12 @@ def __init__(self, global_shape: Union[Tuple, Integral],
raise ValueError(f"Should be either {Partition.BROADCAST} "
f"or {Partition.SCATTER}")
self.dtype = dtype
self._global_shape = global_shape
self._global_shape = _value_or_sized_to_tuple(global_shape)
self._base_comm = base_comm
self._partition = partition
self._axis = axis

local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
self._check_local_shapes(local_shapes)
self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
partition, axis)
Expand Down
Loading