Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
71f01ea
Dumbest possible redistribution working.
Jun 11, 2014
0b53ccd
Generalize local redistribution to handle slices.
Jun 12, 2014
6abb248
Generalize plan creation for block redistribution.
Jun 13, 2014
04656ec
Cleanup imports.
Jun 13, 2014
3857d02
Use tuple_intersection().
Jun 13, 2014
5120cab
Fix after rebase.
Jun 17, 2014
fb74fe9
Adds redistribution test for many-to-one.
Jun 17, 2014
0f40f84
Multidimensional block redistribution working.
Jun 17, 2014
868d2bb
More tests and debugging of n-dimensional redistribution.
Jun 19, 2014
a027b77
Test for local-only operation in redistribution.
Jun 20, 2014
e846519
More general block redistribution working.
Jun 20, 2014
e318fb1
Fix API changes after rebase.
Jul 3, 2014
d193354
Merge branch 'master' into feature/block-redistribution
Jul 31, 2014
244031f
Add default value back to `Context.apply()`.
Jul 31, 2014
4f341b7
WIP: on our way towards redistribution working.
Aug 6, 2014
14af63a
Uncomment test, skip it instead.
Aug 6, 2014
6470fd4
Fix `reduce` imports for Py3.
Aug 6, 2014
d3ef7ae
More general redistribution.
Aug 6, 2014
a67c843
Cleanup a bit.
Aug 6, 2014
57d5d05
Merge branch 'master' into feature/block-redistribution
Aug 6, 2014
e7b1a23
Fix bug with NoDist dimensions and redistribution.
Aug 6, 2014
f6f4a5c
Put in TODOs for cleanup / refactoring.
Aug 6, 2014
d3d9162
Merge branch 'master' into feature/block-redistribution
Aug 7, 2014
3b2d540
Merge branch 'master' into feature/block-redistribution
Aug 11, 2014
d164ed8
Redistribution cleanup and refactoring.
Sep 1, 2014
7d5e431
Refactor redistribution tests.
Sep 1, 2014
98ad6ea
Merge branch 'master' into feature/block-redistribution
Sep 1, 2014
8c8bc3e
Remove `default` argument to apply.
Sep 5, 2014
75921e1
Adds docstrings.
Sep 5, 2014
1579b1b
Raise `ValueError` if attempting to redistribute to incompatible size.
Sep 5, 2014
4d3c805
Raise `NotImplementedError` in some redistribution cases.
Sep 5, 2014
f61e439
Fix target renaming in Makefile.
Sep 5, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ install:
# Testing-related targets.
# ----------------------------------------------------------------------------

test_client:
test_ipython:
${PYTHON} -m unittest discover -c
.PHONY: test_client
.PHONY: test_ipython

test_client_with_coverage:
test_ipython_with_coverage:
${COVERAGE} run -pm unittest discover -cv
.PHONY: test_client_with_coverage
.PHONY: test_ipython_with_coverage

${PARALLEL_OUT_DIR} :
mkdir ${PARALLEL_OUT_DIR}
Expand Down Expand Up @@ -95,10 +95,10 @@ test_mpi_with_coverage:
${MPI_ONLY_LAUNCH_TEST}
.PHONY: test_mpi_with_coverage

test: test_client test_engines test_mpi
test: test_ipython test_mpi test_engines
.PHONY: test

test_with_coverage: test_client_with_coverage test_engines_with_coverage test_mpi_with_coverage
test_with_coverage: test_ipython_with_coverage test_mpi_with_coverage test_engines_with_coverage
.PHONY: test_with_coverage

coverage_report:
Expand Down
14 changes: 2 additions & 12 deletions distarray/globalapi/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,7 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize):
"""
from importlib import import_module
import types
from distarray.metadata_utils import arg_kwarg_proxy_converter
from distarray.localapi import LocalArray

main = import_module('__main__')
Expand All @@ -793,19 +794,8 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize):
func = types.FunctionType(func_code, new_func_globals,
func_name, func_defaults,
func_closure)
# convert args
args = list(args)
for i, a in enumerate(args):
if isinstance(a, main.Proxy):
args[i] = a.dereference()
args = tuple(args)

# convert kwargs
for k in kwargs.keys():
val = kwargs[k]
if isinstance(val, main.Proxy):
kwargs[k] = val.dereference()

args, kwargs = arg_kwarg_proxy_converter(args, kwargs)
result = func(*args, **kwargs)

if autoproxyize and isinstance(result, LocalArray):
Expand Down
61 changes: 60 additions & 1 deletion distarray/globalapi/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import distarray.localapi
from distarray.metadata_utils import sanitize_indices
from distarray.globalapi.maps import Distribution
from distarray.globalapi.maps import Distribution, asdistribution
from distarray.utils import _raise_nie
from distarray.metadata_utils import normalize_reduction_axes

Expand Down Expand Up @@ -495,6 +495,65 @@ def local_view(larr, ddpr, dtype):
return DistArray.from_localarrays(key=new_key, distribution=new_dist,
dtype=dtype)

def distribute_as(self, shape_or_dist):
    """
    Return a new DistArray holding this array's data under another
    distribution.

    Parameters
    ----------
    shape_or_dist : shape tuple or Distribution object.
        The target layout.  It is converted with `asdistribution` (using
        this array's context) before use.  Its total number of elements
        must equal this array's `global_size`; its global shape and
        targets may differ from this array's.

    Returns
    -------
    DistArray
        A newly created DistArray with this array's dtype and the target
        distribution.

    Raises
    ------
    NotImplementedError
        If either distribution has a dimension type other than 'b' or 'n'.
    ValueError
        If the shapes differ and the total sizes are not equal.

    """
    target_dist = asdistribution(self.context, shape_or_dist)

    def _has_unsupported_dim(dist_types):
        return any(code not in ('b', 'n') for code in dist_types)

    if (_has_unsupported_dim(self.distribution.dist) or
            _has_unsupported_dim(target_dist.dist)):
        msg = "Only block and non-distributed dimensions currently supported."
        raise NotImplementedError(msg)

    # These two functions are handed to `context.apply` below; their imports
    # are kept inside the function bodies, as in the original.
    def _local_redistribute_same_shape(comm, plan, la_from, la_to):
        from distarray.localapi import redistribute
        redistribute(comm, plan, la_from, la_to)

    def _local_redistribute_general(comm, plan, la_from, la_to):
        from distarray.localapi import redistribute_general
        redistribute_general(comm, plan, la_from, la_to)

    source_size = self.global_size
    dest_size = reduce(operator.mul, target_dist.shape, 1)

    same_shape = self.distribution.shape == target_dist.shape
    if not same_shape and source_size != dest_size:
        msg = ("Original size %d != new size %d,"
               " and total size of new array must be unchanged.")
        raise ValueError(msg % (source_size, dest_size))

    if same_shape:
        worker = _local_redistribute_same_shape
    else:
        worker = _local_redistribute_general

    plan = self.distribution.get_redist_plan(target_dist)
    ubercomm, all_targets = self.distribution.comm_union(target_dist)
    result = DistArray(target_dist, dtype=self.dtype)

    apply_args = (ubercomm, plan, self.key, result.key)
    self.context.apply(worker, apply_args, targets=all_targets)
    return result

# Binary operators

def _binary_op_from_ufunc(self, other, func, rop_str=None, *args, **kwargs):
Expand Down
149 changes: 147 additions & 2 deletions distarray/globalapi/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
sanitize_indices,
_start_stop_block,
tuple_intersection,
shapes_from_dim_data_per_rank)
shapes_from_dim_data_per_rank,
condense,
strides_from_shape)


def _dedup_dim_dicts(dim_dicts):
Expand Down Expand Up @@ -551,7 +553,7 @@ def from_maps(cls, context, maps, targets=None):
self = super(Distribution, cls).__new__(cls)
self.context = context
self.targets = sorted(targets or context.targets)
self.comm = self.context.make_subcomm(self.targets)
self._comm = None
self.maps = maps
self.shape = tuple(m.size for m in self.maps)
self.ndim = len(self.maps)
Expand Down Expand Up @@ -758,6 +760,12 @@ def __getitem__(self, idx):
def __len__(self):
    """Return the number of entries in `self.maps`."""
    maps = self.maps
    return len(maps)

@property
def comm(self):
    """
    Sub-communicator for `self.targets`.

    Created with `self.context.make_subcomm` on first access and stored in
    `self._comm` so later accesses reuse it.
    """
    cached = self._comm
    if cached is None:
        cached = self._comm = self.context.make_subcomm(self.targets)
    return cached

@property
def has_precise_index(self):
"""
Expand Down Expand Up @@ -869,3 +877,140 @@ def view(self, new_dimsize=None):

def localshapes(self):
    """
    Return `shapes_from_dim_data_per_rank` applied to this distribution's
    per-rank dimension data.
    """
    dim_data_per_rank = self.get_dim_data_per_rank()
    return shapes_from_dim_data_per_rank(dim_data_per_rank)

def comm_union(self, *dists):
    """
    Make a communicator that includes the union of all targets in `dists`.

    This distribution's own targets are always part of the union.

    Parameters
    ----------
    dists: sequence of distribution objects.
        Only their `targets` attribute is read.

    Returns
    -------
    tuple
        First element is the result of `self.context.make_subcomm` for the
        combined targets; second is the sorted list of all targets in
        `self` and `dists`.

    """
    # NOTE(review): a reviewer suggested this might be better as a private
    # method (``_comm_union``); the public name is kept so callers still work.

    # `set.union` accepts any number of iterables, so no `reduce` is needed
    # (`reduce` is not a builtin on Python 3).
    combined = set(self.targets).union(*(d.targets for d in dists))
    all_targets = sorted(combined)
    return self.context.make_subcomm(all_targets), all_targets

# ------------------------------------------------------------------------
# Redistribution
# ------------------------------------------------------------------------

@staticmethod
def _redist_intersection_same_shape(source_dimdata, dest_dimdata):
    """
    Intersect two sub-arrays dimension by dimension.

    Parameters
    ----------
    source_dimdata, dest_dimdata : sequences of dimension dictionaries.
        Paired up positionally.  Every dictionary must have
        ``dist_type == 'b'`` and provide 'start' and 'stop'.

    Returns
    -------
    list
        One `tuple_intersection` result per dimension pair, computed from
        the (start, stop) pairs of the two dictionaries.

    Raises
    ------
    ValueError
        If any dimension in either argument is not of dist_type 'b'.
    """
    overlaps = []
    for src, dst in zip(source_dimdata, dest_dimdata):
        src_type, dst_type = src['dist_type'], dst['dist_type']
        if src_type != 'b' or dst_type != 'b':
            raise ValueError("Only 'b' dist_type supported")

        src_interval = (src['start'], src['stop'])
        dst_interval = (dst['start'], dst['stop'])
        overlaps.append(tuple_intersection(src_interval, dst_interval))

    return overlaps

@staticmethod
def _redist_intersection_reshape(source_dimdata, dest_dimdata):
    """
    Intersect two sub-arrays via their flattened global index intervals.

    Each argument is converted with `global_flat_indices`, and the two
    results are combined with `_global_flat_indices_intersection`.
    """
    flat_pair = (global_flat_indices(source_dimdata),
                 global_flat_indices(dest_dimdata))
    return _global_flat_indices_intersection(*flat_pair)

def get_redist_plan(self, other_dist):
    """
    Build the list of transfers needed to redistribute to `other_dist`.

    Parameters
    ----------
    other_dist : distribution object.
        The destination distribution.

    Returns
    -------
    list of dict
        One entry per (source rank, destination rank) pair whose
        intersection is non-empty in every component.  Each entry has the
        keys 'source_rank' and 'dest_rank' (ranks numbered by position in
        the sorted union of both distributions' targets) and 'indices'
        (the intersections for that pair).

    """
    # Rank of every target within the sorted union of both target lists.
    merged_targets = sorted(set(self.targets + other_dist.targets))
    union_rank_of = {target: rank
                     for (rank, target) in enumerate(merged_targets)}

    # Translate each distribution's own rank numbering into union ranks.
    union_rank_from_source_rank = {
        rank: union_rank_of[target]
        for (rank, target) in enumerate(self.targets)}
    union_rank_from_dest_rank = {
        rank: union_rank_of[target]
        for (rank, target) in enumerate(other_dist.targets)}

    source_ddpr = self.get_dim_data_per_rank()
    dest_ddpr = other_dist.get_dim_data_per_rank()

    # Equal shapes are intersected per dimension; otherwise the comparison
    # is done on flattened global indices.
    if self.shape == other_dist.shape:
        intersect = Distribution._redist_intersection_same_shape
    else:
        intersect = Distribution._redist_intersection_reshape

    plan = []
    for source_dd, dest_dd in product(source_ddpr, dest_ddpr):
        overlap = intersect(source_dd, dest_dd)
        if not overlap or not all(overlap):
            # Nothing is shared by this pair of ranks.
            continue

        source_coords = tuple(dd['proc_grid_rank'] for dd in source_dd)
        source_rank = self.rank_from_coords[source_coords]
        dest_coords = tuple(dd['proc_grid_rank'] for dd in dest_dd)
        dest_rank = other_dist.rank_from_coords[dest_coords]

        plan.append({
            'source_rank': union_rank_from_source_rank[source_rank],
            'dest_rank': union_rank_from_dest_rank[dest_rank],
            'indices': overlap,
        })

    return plan


# ----------------------------------------------------------------------------
# Redistribution helper functions.
# ----------------------------------------------------------------------------

def global_flat_indices(dim_data):
    """
    Return a list of tuples of indices into the flattened global array.

    Parameters
    ----------
    dim_data: sequence of dimension dictionaries.
        One dictionary per dimension.  Each must provide 'dist_type' and
        'size'; dimensions whose dist_type is not 'n' must also provide
        'start' and 'stop'.  A dimension of dist_type 'n' is treated as
        spanning ``0`` to ``size``.  The dictionaries are not modified.

    Returns
    -------
    list of 2-tuples of ints.
        Each tuple is a (start, stop) interval into the flattened global array.
        All selected ranges comprise the indices for this dim_data's sub-array.
        The intervals are passed through `condense` before being returned.

    """
    # TODO: FIXME: can be optimized when the last dimension is 'n'.

    # Work out per-dimension bounds locally instead of writing 'start' and
    # 'stop' back into the caller's dictionaries.
    bounds = [(0, dd['size']) if dd['dist_type'] == 'n'
              else (dd['start'], dd['stop'])
              for dd in dim_data]

    glb_shape = tuple(dd['size'] for dd in dim_data)
    glb_strides = strides_from_shape(glb_shape)

    # Every index combination of the leading dimensions yields one interval
    # along the last dimension.
    leading = [range(start, stop) for (start, stop) in bounds[:-1]]
    last_start, last_stop = bounds[-1]

    def flatten(idx):
        return sum(a * b for (a, b) in zip(idx, glb_strides))

    starts = [flatten(idx) for idx in product(*(leading + [[last_start]]))]
    stops = [flatten(idx) for idx in product(*(leading + [[last_stop]]))]

    # Hand `condense` a real list: on Python 3 `zip` is a one-shot iterator.
    intervals = list(zip(starts, stops))
    return condense(intervals)

def _global_flat_indices_intersection(gfis0, gfis1):
    """
    Intersect every interval in `gfis0` with every interval in `gfis1`.

    Returns a list with the first two items of each truthy
    `tuple_intersection` result, in `product(gfis0, gfis1)` order.
    """
    overlaps = []
    for left, right in product(gfis0, gfis1):
        common = tuple_intersection(left, right)
        if common:
            overlaps.append(common[:2])
    return overlaps
Loading