16 changes: 16 additions & 0 deletions distarray/metadata_utils.py
@@ -139,6 +139,9 @@ def make_grid_shape(shape, dist, comm_size):
# Trivial case: all processes used for the one distributed dimension.
if comm_size >= shape[distdims[0]]:
dist_grid_shape = (shape[distdims[0]],)
elif (('b' == shape[distdims[0]]) and
Contributor

Won't this issue show up for multi-dimensional block-distributed arrays for certain sizes? The test here needs to be more general to cover the multi-dimensional cases. If it can be refactored so it handles things on a per-dimension basis, that would be ideal.

For example, say we have a 16-engine cluster running. Currently, a 6 x 9 block-distributed array gives us a lot of empty localarrays:

In [5]: d = Distribution.from_shape(c, (6, 9), ('b', 'b'))

In [6]: d.localshapes()
Out[6]:
[(2, 3),
 (2, 3),
 (2, 3),
 (2, 0),
 (2, 3),
 (2, 3),
 (2, 3),
 (2, 0),
 (2, 3),
 (2, 3),
 (2, 3),
 (2, 0),
 (0, 3),
 (0, 3),
 (0, 3),
 (0, 0)]
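
To make the pattern above easier to see, here is a minimal sketch that reproduces the same list of local shapes. It rests on two assumptions that are not confirmed by the diff: that the 16 engines are arranged as a 4 x 4 process grid, and that each block holds ceil(d / p) elements, with trailing processes getting whatever is left. `block_sizes` and `local_shapes` are hypothetical helpers written for illustration; they are not part of distarray.

```python
from itertools import product


def block_sizes(d, p):
    """Per-process extents for d elements split over p processes.

    Assumption: each block holds ceil(d / p) elements and trailing
    processes receive whatever is left, possibly nothing.
    """
    block = -(-d // p)  # ceiling division
    return [max(0, min(block, d - r * block)) for r in range(p)]


def local_shapes(shape, grid_shape):
    """Local shapes in row-major rank order (hypothetical helper)."""
    per_dim = [block_sizes(d, p) for d, p in zip(shape, grid_shape)]
    return list(product(*per_dim))


shapes = local_shapes((6, 9), (4, 4))
empty = sum(1 for s in shapes if 0 in s)
print(shapes)
print(empty, "of", len(shapes), "local arrays are empty")
```

Under those assumptions, each axis leaves its fourth process with nothing (6 splits as 2, 2, 2, 0 and 9 splits as 3, 3, 3, 0), which is why every fourth rank and the whole last row of the grid come out empty.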

Contributor Author

Yes, this needs to be done on a per-dimension basis.
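
One possible shape for a per-dimension check is sketched below. It is not the change made in this PR. It assumes blocks of ceil(d / p) elements along each axis, and both `usable_procs` and `trim_grid_shape` are hypothetical names introduced only for this illustration.

```python
def usable_procs(d, p):
    """Number of processes that receive a non-empty block along one axis.

    Assumption: an axis of length d split over p processes uses blocks
    of ceil(d / p) elements, so only ceil(d / block) processes get data.
    """
    if d == 0 or p == 0:
        return 0
    block = -(-d // p)    # ceiling division
    return -(-d // block)


def trim_grid_shape(shape, grid_shape):
    """Shrink each grid dimension independently (hypothetical helper)."""
    return tuple(usable_procs(d, p) for d, p in zip(shape, grid_shape))


print(trim_grid_shape((6, 9), (4, 4)))
```

Because each axis is handled on its own, the same helper covers the one-dimensional case and any mix of distributed dimensions. For the 6 x 9 example on a 4 x 4 grid it suggests a 3 x 3 grid, which leaves some engines unused but no empty localarrays.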

check_bad_dims(comm_size, shape[distdims[0]])):
return make_grid_shape(shape, dist, comm_size - 1)
else:
dist_grid_shape = (comm_size,)

@@ -472,3 +475,16 @@ def shapes_from_dim_data_per_rank(ddpr): # ddpr = dim_data_per_rank
shape.append(size_from_dim_data(dd))
shape_list.append(tuple(shape))
return shape_list


def check_bad_dims(p, d):
    """Check for block distributions that would create empty localarrays
    along some axis. For more information see gh-issue 442.
    """
    n = d // p
    num = d - p + 1
    div = p - 1
    if (num / div) == n:
        return True
    else:
        return False