Skip to content

Commit 1011b3d

Browse files
author
Kurt Smith
committed
Adds workaround to remove empty local arrays.
Slices the distribution object to remove all empty localarrays before returning. Only done if all dimensions use 'b' or 'n' in the dist tuple.
1 parent ed8d139 commit 1011b3d

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

distarray/dist/maps.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,17 @@ def __new__(cls, context, shape, dist=None, grid_shape=None, targets=None):
612612

613613
# list of `ClientMap` objects, one per dimension.
614614
maps = [map_from_sizes(*args) for args in zip(shape, dist, grid_shape)]
615-
return cls.from_maps(context=context, maps=maps, targets=targets)
615+
616+
self = cls.from_maps(context=context, maps=maps, targets=targets)
617+
618+
# TODO: FIXME: this is a workaround. The reason we slice here is to
619+
# return a distribution with no empty local shapes. The `from_maps()`
620+
# classmethod should be fixed to ensure no empty local arrays are
621+
# created in the first place. That will remove the need to slice the
622+
# distribution to remove empty localshapes.
623+
if all(d in ('n', 'b') for d in self.dist):
624+
self = self.slice((slice(None),)*self.ndim)
625+
return self
616626

617627
@classmethod
618628
def from_global_dim_data(cls, context, global_dim_data, targets=None):

distarray/dist/tests/test_maps.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,5 +300,23 @@ def test_all_n_dist(self):
300300
self.context.ones(distribution)
301301

302302

303+
class TestNoEmptyLocals(ContextTestCase):
304+
305+
def test_no_empty_local_arrays_4_targets(self):
306+
for n in range(1, 20):
307+
dist = Distribution(self.context, shape=(n,),
308+
dist=('b',),
309+
targets=self.context.targets[:4])
310+
for ls in dist.localshapes():
311+
self.assertNotIn(0, ls)
312+
313+
def test_no_empty_local_arrays_3_targets(self):
314+
for n in range(1, 20):
315+
dist = Distribution(self.context, shape=(n,),
316+
dist=('b',),
317+
targets=self.context.targets[:3])
318+
for ls in dist.localshapes():
319+
self.assertNotIn(0, ls)
320+
303321
if __name__ == '__main__':
304322
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)