diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index 426c1c6d..39f212d7 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -519,6 +519,11 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None): self.maps = [_map_from_axis_dim_dicts(axis_dim_dicts) for axis_dim_dicts in axis_dim_dicts_per_axis] + # check for empty localarrays + sizes = self.localsizes() + if 0 in sizes: + raise ValueError("A localarray has zero size") + return self @classmethod @@ -567,6 +572,12 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None, # List of `ClientMap` objects, one per dimension. self.maps = [map_from_sizes(*args) for args in zip(self.shape, self.dist, self.grid_shape)] + + # check for empty localarrays + sizes = self.localsizes() + if 0 in sizes: + raise ValueError("A localarray has zero size") + return self def __init__(self, context, global_dim_data, targets=None): @@ -673,6 +684,11 @@ def __init__(self, context, global_dim_data, targets=None): nelts = reduce(operator.mul, self.grid_shape, 1) self.rank_from_coords = np.arange(nelts).reshape(self.grid_shape) + # check for empty localarrays + sizes = self.localsizes() + if 0 in sizes: + raise ValueError("A localarray has zero size") + def __getitem__(self, idx): return self.maps[idx] @@ -782,3 +798,10 @@ def reduce(self, axes): def localshapes(self): return shapes_from_dim_data_per_rank(self.get_dim_data_per_rank()) + + def localsizes(self): + lshapes = shapes_from_dim_data_per_rank(self.get_dim_data_per_rank()) + sizes = [] + for shape in lshapes: + sizes.append(reduce(operator.mul, shape, 1)) + return tuple(sizes)