Skip to content

Commit c071678

Browse files
committed
Merge pull request #396 from cowlicks/recurs-bug
Fix bug wth all 'n's dist tuple.
2 parents f284652 + b2e6c31 commit c071678

File tree

5 files changed

+58
-31
lines changed

5 files changed

+58
-31
lines changed

distarray/dist/maps.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
normalize_grid_shape,
3737
make_grid_shape,
3838
positivify,
39-
validate_grid_shape,
4039
_start_stop_block,
4140
normalize_dim_dict,
4241
normalize_reduction_axes)
@@ -412,8 +411,8 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None):
412411
self.ndim = len(dd0)
413412
self.dist = tuple(dd['dist_type'] for dd in dd0)
414413
self.grid_shape = tuple(dd['proc_grid_size'] for dd in dd0)
415-
416-
validate_grid_shape(self.grid_shape, self.dist, len(self.targets))
414+
self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim,
415+
self.dist, len(self.targets))
417416

418417
coords = [tuple(d['proc_grid_rank'] for d in dd) for dd in
419418
dim_data_per_rank]
@@ -439,7 +438,18 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None):
439438
return self
440439

441440
@classmethod
442-
def from_shape(cls, context, shape, dist=None, grid_shape=None, targets=None):
441+
def from_shape(cls, context, shape, dist=None, grid_shape=None,
442+
targets=None):
443+
444+
# special case when dist is all 'n's.
445+
if (dist is not None) and all(d == 'n' for d in dist):
446+
if (targets is not None) and (len(targets) != 1):
447+
raise ValueError('target dist conflict')
448+
elif targets is None:
449+
targets = [context.targets[0]]
450+
else:
451+
# then targets is set correctly
452+
pass
443453

444454
self = cls.__new__(cls)
445455
self.context = context
@@ -448,17 +458,18 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None, targets=None):
448458
self.shape = shape
449459
self.ndim = len(shape)
450460

461+
# dist
451462
if dist is None:
452463
dist = {0: 'b'}
453464
self.dist = normalize_dist(dist, self.ndim)
454465

455-
if grid_shape is None: # Make a new grid_shape if not provided.
456-
self.grid_shape = make_grid_shape(self.shape, self.dist,
457-
len(self.targets))
458-
else: # Otherwise normalize the one passed in.
459-
self.grid_shape = normalize_grid_shape(grid_shape, self.ndim)
460-
# In either case, validate.
461-
validate_grid_shape(self.grid_shape, self.dist, len(self.targets))
466+
# grid_shape
467+
if grid_shape is None:
468+
grid_shape = make_grid_shape(self.shape, self.dist,
469+
len(self.targets))
470+
471+
self.grid_shape = normalize_grid_shape(grid_shape, self.ndim,
472+
self.dist, len(self.targets))
462473

463474
# TODO: FIXME: assert that self.rank_from_coords is valid and conforms
464475
# to how MPI does it.
@@ -568,7 +579,8 @@ def __init__(self, context, global_dim_data, targets=None):
568579
self.dist = tuple(m.dist for m in self.maps)
569580
self.grid_shape = tuple(m.grid_size for m in self.maps)
570581

571-
validate_grid_shape(self.grid_shape, self.dist, len(context.targets))
582+
self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim,
583+
self.dist, len(context.targets))
572584

573585
nelts = reduce(operator.mul, self.grid_shape, 1)
574586
self.rank_from_coords = np.arange(nelts).reshape(self.grid_shape)

distarray/dist/tests/test_maps.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_is_compatible(self):
7171

7272
self.assertTrue(cm0.is_compatible(cm1))
7373
self.assertTrue(cm1.is_compatible(cm0))
74-
74+
7575
nr -= 1; nc -= 1; nd -= 1
7676

7777
cm2 = client_map.Distribution.from_shape(
@@ -112,3 +112,15 @@ def test_reduce_0D(self):
112112
self.assertSequenceEqual(new_dist.shape, ())
113113
self.assertEqual(new_dist.grid_shape, ())
114114
self.assertEqual(set(new_dist.targets), set(dist.targets[:1]))
115+
116+
117+
class TestDistributionCreation(unittest.TestCase):
118+
def test_all_n_dist(self):
119+
context = Context()
120+
distribution = client_map.Distribution.from_shape(context,
121+
shape=(3, 3),
122+
dist=('n', 'n'))
123+
context.ones(distribution)
124+
125+
if __name__ == '__main__':
126+
unittest.main(verbosity=2)

distarray/local/maps.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
from distarray.externals.six.moves import range, zip
2828

2929
from distarray.local import construct
30-
from distarray.metadata_utils import (validate_grid_shape, make_grid_shape,
31-
normalize_grid_shape, normalize_dist,
32-
distribute_indices, positivify)
30+
from distarray.metadata_utils import (make_grid_shape, normalize_grid_shape,
31+
normalize_dist, distribute_indices,
32+
positivify)
3333

3434

3535
class Distribution(object):
@@ -56,10 +56,8 @@ def from_shape(cls, shape, dist=None, grid_shape=None, comm=None):
5656

5757
if grid_shape is None: # Make a new grid_shape if not provided.
5858
grid_shape = make_grid_shape(shape, dist_tuple, comm_size)
59-
else: # Otherwise normalize the one passed in.
60-
grid_shape = normalize_grid_shape(grid_shape, ndim)
61-
# In either case, validate.
62-
validate_grid_shape(grid_shape, dist_tuple, comm_size)
59+
grid_shape = normalize_grid_shape(grid_shape, ndim,
60+
dist_tuple, comm_size)
6361

6462
comm = construct.init_comm(base_comm, grid_shape)
6563
grid_coords = comm.Get_coords(comm.Get_rank())
@@ -284,7 +282,7 @@ class BlockCyclicMap(MapBase):
284282
"""
285283

286284
dist = 'c'
287-
285+
288286
def __init__(self, global_size, grid_size, grid_rank, start, block_size):
289287
if start % block_size:
290288
msg = "Value of start (%r) does not evenly divide block_size (%r)."

distarray/local/tests/paralleltest_maps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_grid_shape(self):
217217
dist = Distribution.from_shape((20, 20), dist='b', comm=self.comm)
218218
self.assertEqual(dist.grid_shape, (12, 1))
219219
dist = Distribution.from_shape((2*10, 6*10), dist=('b', 'b'),
220-
comm=self.comm)
220+
comm=self.comm)
221221
self.assertEqual(dist.grid_shape, (2, 6))
222222
dist = Distribution.from_shape((6*10, 2*10), dist='bb', comm=self.comm)
223223
self.assertEqual(dist.grid_shape, (6, 2))

distarray/metadata_utils.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,19 @@ class GridShapeError(Exception):
2424
pass
2525

2626

27-
def normalize_grid_shape(grid_shape, ndims):
28-
"""Adds 1s to grid_shape so it has `ndims` dimensions."""
29-
return tuple(grid_shape) + (1,) * (ndims - len(grid_shape))
27+
def normalize_grid_shape(grid_shape, ndims, dist, comm_size):
28+
"""Adds 1s to grid_shape so it has `ndims` dimensions. Validates
29+
`grid_shape` tuple against the `dist` tuple and `comm_size`.
30+
"""
31+
grid_shape = tuple(grid_shape) + (1,) * (ndims - len(grid_shape))
3032

33+
# short circuit for special case
34+
if all(x == 'n' for x in dist):
35+
if not all(x == 1 for x in grid_shape):
36+
raise ValueError("grid shape should be all `1`'s not %s." %
37+
grid_shape)
38+
return grid_shape
3139

32-
def validate_grid_shape(grid_shape, dist, comm_size):
33-
""" Validates `grid_shape` tuple against the `dist` tuple and
34-
`comm_size`.
35-
"""
3640
if len(grid_shape) != len(dist):
3741
msg = "grid_shape's length (%d) not equal to dist's length (%d)"
3842
raise InvalidGridShapeError(msg % (len(grid_shape), len(dist)))
@@ -74,7 +78,9 @@ def make_grid_shape(shape, dist, comm_size):
7478
distdims = tuple(i for (i, v) in enumerate(dist) if v != 'n')
7579
ndistdim = len(distdims)
7680

77-
if ndistdim == 1:
81+
if ndistdim == 0:
82+
dist_grid_shape = ()
83+
elif ndistdim == 1:
7884
# Trivial case: all processes used for the one distributed dimension.
7985
dist_grid_shape = (comm_size,)
8086

@@ -223,4 +229,3 @@ def normalize_reduction_axes(axes, ndim):
223229
else:
224230
axes = tuple(positivify(a, ndim) for a in axes)
225231
return axes
226-

0 commit comments

Comments
 (0)