55from enum import Enum
66
77from pylops .utils import DTypeLike , NDArray
8+ from pylops .utils ._internal import _value_or_sized_to_tuple
89from pylops .utils .backend import get_module , get_array_module , get_module_name
910
1011
@@ -78,7 +79,7 @@ class DistributedArray:
7879 axis : :obj:`int`, optional
7980 Axis along which distribution occurs. Defaults to ``0``.
8081 local_shapes : :obj:`list`, optional
81- List of tuples representing local shapes at each rank.
82+ List of tuples of integers representing local shapes at each rank.
8283 engine : :obj:`str`, optional
8384 Engine used to store array (``numpy`` or ``cupy``)
8485 dtype : :obj:`str`, optional
@@ -88,7 +89,7 @@ class DistributedArray:
8889 def __init__ (self , global_shape : Union [Tuple , Integral ],
8990 base_comm : Optional [MPI .Comm ] = MPI .COMM_WORLD ,
9091 partition : Partition = Partition .SCATTER , axis : int = 0 ,
91- local_shapes : Optional [List [Tuple ]] = None ,
92+ local_shapes : Optional [List [Union [ Tuple , Integral ] ]] = None ,
9293 engine : Optional [str ] = "numpy" ,
9394 dtype : Optional [DTypeLike ] = np .float64 ):
9495 if isinstance (global_shape , Integral ):
@@ -100,10 +101,13 @@ def __init__(self, global_shape: Union[Tuple, Integral],
100101 raise ValueError (f"Should be either { Partition .BROADCAST } "
101102 f"or { Partition .SCATTER } " )
102103 self .dtype = dtype
103- self ._global_shape = global_shape
104+ self ._global_shape = _value_or_sized_to_tuple ( global_shape )
104105 self ._base_comm = base_comm
105106 self ._partition = partition
106107 self ._axis = axis
108+
109+ if local_shapes is not None :
110+ local_shapes = [_value_or_sized_to_tuple (local_shape ) for local_shape in local_shapes ]
107111 self ._check_local_shapes (local_shapes )
108112 self ._local_shape = local_shapes [base_comm .rank ] if local_shapes else local_split (global_shape , base_comm ,
109113 partition , axis )
0 commit comments