Skip to content

Commit a3ec96b

Browse files
committed
Merge pull request #584 from enthought/feature/block-redistribution
Block redistribution
2 parents 9ed2fba + f61e439 commit a3ec96b

File tree

10 files changed

+648
-43
lines changed

10 files changed

+648
-43
lines changed

Makefile

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ install:
6060
# Testing-related targets.
6161
# ----------------------------------------------------------------------------
6262

63-
test_client:
63+
test_ipython:
6464
${PYTHON} -m unittest discover -c
65-
.PHONY: test_client
65+
.PHONY: test_ipython
6666

67-
test_client_with_coverage:
67+
test_ipython_with_coverage:
6868
${COVERAGE} run -pm unittest discover -cv
69-
.PHONY: test_client_with_coverage
69+
.PHONY: test_ipython_with_coverage
7070

7171
${PARALLEL_OUT_DIR} :
7272
mkdir ${PARALLEL_OUT_DIR}
@@ -95,10 +95,10 @@ test_mpi_with_coverage:
9595
${MPI_ONLY_LAUNCH_TEST}
9696
.PHONY: test_mpi_with_coverage
9797

98-
test: test_client test_engines test_mpi
98+
test: test_ipython test_mpi test_engines
9999
.PHONY: test
100100

101-
test_with_coverage: test_client_with_coverage test_engines_with_coverage test_mpi_with_coverage
101+
test_with_coverage: test_ipython_with_coverage test_mpi_with_coverage test_engines_with_coverage
102102
.PHONY: test_with_coverage
103103

104104
coverage_report:

distarray/globalapi/context.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,7 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize):
772772
"""
773773
from importlib import import_module
774774
import types
775+
from distarray.metadata_utils import arg_kwarg_proxy_converter
775776
from distarray.localapi import LocalArray
776777

777778
main = import_module('__main__')
@@ -793,19 +794,8 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize):
793794
func = types.FunctionType(func_code, new_func_globals,
794795
func_name, func_defaults,
795796
func_closure)
796-
# convert args
797-
args = list(args)
798-
for i, a in enumerate(args):
799-
if isinstance(a, main.Proxy):
800-
args[i] = a.dereference()
801-
args = tuple(args)
802-
803-
# convert kwargs
804-
for k in kwargs.keys():
805-
val = kwargs[k]
806-
if isinstance(val, main.Proxy):
807-
kwargs[k] = val.dereference()
808797

798+
args, kwargs = arg_kwarg_proxy_converter(args, kwargs)
809799
result = func(*args, **kwargs)
810800

811801
if autoproxyize and isinstance(result, LocalArray):

distarray/globalapi/distarray.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import distarray.localapi
2626
from distarray.metadata_utils import sanitize_indices
27-
from distarray.globalapi.maps import Distribution
27+
from distarray.globalapi.maps import Distribution, asdistribution
2828
from distarray.utils import _raise_nie
2929
from distarray.metadata_utils import normalize_reduction_axes
3030

@@ -495,6 +495,65 @@ def local_view(larr, ddpr, dtype):
495495
return DistArray.from_localarrays(key=new_key, distribution=new_dist,
496496
dtype=dtype)
497497

498+
def distribute_as(self, shape_or_dist):
    """
    Redistribute this DistArray, returning a new DistArray that holds the
    same data under a different distribution.

    Parameters
    ----------
    shape_or_dist : shape tuple or Distribution object.
        Layout for the new DistArray.  It is passed through
        `asdistribution` together with this array's context, so a shape
        tuple is converted to a Distribution object with default
        parameters.  The new distribution must have the same number of
        items as this DistArray; the global shape and targets may differ.

    Returns
    -------
    DistArray
        A new DistArray with this array's dtype, distributed according to
        `shape_or_dist`.

    Raises
    ------
    NotImplementedError
        If any dimension of the current or the requested distribution is
        neither block ('b') nor non-distributed ('n').
    ValueError
        If the shapes differ and the total number of items differs too.

    """
    new_dist = asdistribution(self.context, shape_or_dist)

    supported = ('b', 'n')
    if (any(kind not in supported for kind in self.distribution.dist) or
            any(kind not in supported for kind in new_dist.dist)):
        raise NotImplementedError(
            "Only block and non-distributed dimensions currently supported.")

    # Both helpers are handed to `self.context.apply` below; each one
    # imports what it needs inside its own body.
    def _local_redistribute_same_shape(comm, plan, la_from, la_to):
        from distarray.localapi import redistribute
        redistribute(comm, plan, la_from, la_to)

    def _local_redistribute_general(comm, plan, la_from, la_to):
        from distarray.localapi import redistribute_general
        redistribute_general(comm, plan, la_from, la_to)

    old_size = self.global_size
    new_size = reduce(operator.mul, new_dist.shape, 1)

    if self.distribution.shape == new_dist.shape:
        worker = _local_redistribute_same_shape
    elif old_size == new_size:
        # Same item count, different shape: use the general routine.
        worker = _local_redistribute_general
    else:
        raise ValueError(("Original size %d != new size %d,"
                          " and total size of new array must be unchanged.")
                         % (old_size, new_size))

    plan = self.distribution.get_redist_plan(new_dist)
    ubercomm, all_targets = self.distribution.comm_union(new_dist)
    result = DistArray(new_dist, dtype=self.dtype)

    self.context.apply(worker,
                       (ubercomm, plan, self.key, result.key),
                       targets=all_targets)
    return result
556+
498557
# Binary operators
499558

500559
def _binary_op_from_ufunc(self, other, func, rop_str=None, *args, **kwargs):

distarray/globalapi/maps.py

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
sanitize_indices,
4242
_start_stop_block,
4343
tuple_intersection,
44-
shapes_from_dim_data_per_rank)
44+
shapes_from_dim_data_per_rank,
45+
condense,
46+
strides_from_shape)
4547

4648

4749
def _dedup_dim_dicts(dim_dicts):
@@ -551,7 +553,7 @@ def from_maps(cls, context, maps, targets=None):
551553
self = super(Distribution, cls).__new__(cls)
552554
self.context = context
553555
self.targets = sorted(targets or context.targets)
554-
self.comm = self.context.make_subcomm(self.targets)
556+
self._comm = None
555557
self.maps = maps
556558
self.shape = tuple(m.size for m in self.maps)
557559
self.ndim = len(self.maps)
@@ -758,6 +760,12 @@ def __getitem__(self, idx):
758760
def __len__(self):
759761
return len(self.maps)
760762

763+
@property
def comm(self):
    """
    Communicator for this distribution's targets, created on first access.

    The value is obtained from `self.context.make_subcomm(self.targets)`
    and stored in `self._comm`, so later accesses return the stored object
    instead of making a new one.
    """
    existing = self._comm
    if existing is not None:
        return existing
    self._comm = self.context.make_subcomm(self.targets)
    return self._comm
768+
761769
@property
762770
def has_precise_index(self):
763771
"""
@@ -869,3 +877,140 @@ def view(self, new_dimsize=None):
869877

870878
def localshapes(self):
871879
return shapes_from_dim_data_per_rank(self.get_dim_data_per_rank())
880+
881+
def comm_union(self, *dists):
    """
    Make a communicator covering this distribution's targets together with
    the targets of every distribution in `dists`.

    Parameters
    ----------
    *dists : zero or more distribution objects.
        Only their `targets` attribute is read.

    Returns
    -------
    tuple
        First element is the result of `self.context.make_subcomm` for the
        combined targets; second is the sorted list of those targets.

    """
    others = [d.targets for d in dists]
    combined = set(self.targets)
    for targets in others:
        combined.update(targets)
    all_targets = sorted(combined)
    return self.context.make_subcomm(all_targets), all_targets
899+
900+
# ------------------------------------------------------------------------
901+
# Redistribution
902+
# ------------------------------------------------------------------------
903+
904+
@staticmethod
def _redist_intersection_same_shape(source_dimdata, dest_dimdata):
    """
    Intersect two dimension-data sequences dimension by dimension.

    The sequences are walked in lockstep.  For every pair of dimension
    dictionaries the ('start', 'stop') pairs are handed to
    `tuple_intersection`, and the results are collected in order.

    Raises ValueError when either dictionary of a pair has a 'dist_type'
    other than 'b'.
    """
    # NOTE(review): `distribute_as` lets 'n' dimensions through its own
    # check, but this helper rejects them -- confirm that is intended.
    overlaps = []
    for src, dst in zip(source_dimdata, dest_dimdata):
        src_type, dst_type = src['dist_type'], dst['dist_type']
        if src_type != 'b' or dst_type != 'b':
            raise ValueError("Only 'b' dist_type supported")

        src_span = (src['start'], src['stop'])
        dst_span = (dst['start'], dst['stop'])
        overlaps.append(tuple_intersection(src_span, dst_span))

    return overlaps
920+
921+
@staticmethod
def _redist_intersection_reshape(source_dimdata, dest_dimdata):
    """
    Intersect two dimension-data sequences through flattened indices.

    Each argument is converted with `global_flat_indices` (source first,
    then destination) and the two results are combined by
    `_global_flat_indices_intersection`.
    """
    return _global_flat_indices_intersection(
        global_flat_indices(source_dimdata),
        global_flat_indices(dest_dimdata))
926+
927+
def get_redist_plan(self, other_dist):
    """
    Work out which (source, destination) pairs overlap when moving from
    this distribution to `other_dist`.

    Parameters
    ----------
    other_dist : distribution object.
        Must provide `targets`, `shape`, `rank_from_coords` and
        `get_dim_data_per_rank()`.

    Returns
    -------
    list of dict
        One entry per overlapping pair, with keys 'source_rank' and
        'dest_rank' (positions in the sorted union of both target lists)
        and 'indices' (the non-empty intersection for that pair).

    """
    # Position of every target within the sorted union of both target lists.
    all_targets = sorted(set(self.targets + other_dist.targets))
    union_rank_of = {target: rank
                     for (rank, target) in enumerate(all_targets)}

    # Translate each side's own rank numbering into union ranks.
    source_to_union = {rank: union_rank_of[target]
                       for (rank, target) in enumerate(self.targets)}
    dest_to_union = {rank: union_rank_of[target]
                     for (rank, target) in enumerate(other_dist.targets)}

    source_ddpr = self.get_dim_data_per_rank()
    dest_ddpr = other_dist.get_dim_data_per_rank()
    pairs = product(source_ddpr, dest_ddpr)

    intersect = (Distribution._redist_intersection_same_shape
                 if self.shape == other_dist.shape
                 else Distribution._redist_intersection_reshape)

    plan = []
    for source_dd, dest_dd in pairs:
        overlap = intersect(source_dd, dest_dd)
        # Skip pairs with no overlap at all or with any empty component.
        if not overlap or not all(overlap):
            continue

        source_coords = tuple(dd['proc_grid_rank'] for dd in source_dd)
        source_rank = self.rank_from_coords[source_coords]
        dest_coords = tuple(dd['proc_grid_rank'] for dd in dest_dd)
        dest_rank = other_dist.rank_from_coords[dest_coords]

        entry = {
            'source_rank': source_to_union[source_rank],
            'dest_rank': dest_to_union[dest_rank],
            'indices': overlap,
        }
        plan.append(entry)

    return plan
969+
970+
971+
# ----------------------------------------------------------------------------
972+
# Redistribution helper functions.
973+
# ----------------------------------------------------------------------------
974+
975+
def global_flat_indices(dim_data):
    """
    Return a list of tuples of indices into the flattened global array.

    Parameters
    ----------
    dim_data : sequence of dimension dictionaries.
        One dictionary per dimension.  Each needs 'dist_type' and 'size';
        every dimension whose 'dist_type' is not 'n' also needs 'start'
        and 'stop'.  An 'n' dimension is treated as spanning
        ``(0, size)``.  The dictionaries are only read, never modified.

    Returns
    -------
    list of 2-tuples of ints.
        Each tuple is a (start, stop) interval into the flattened global array.
        All selected ranges comprise the indices for this dim_data's sub-array.
        The value is whatever `condense` produces from the raw intervals.

    """
    # TODO: FIXME: can be optimized when the last dimension is 'n'.

    # Keep the per-dimension bounds in a local list instead of writing
    # 'start'/'stop' back into the caller's dictionaries.
    bounds = [(0, dd['size']) if dd['dist_type'] == 'n'
              else (dd['start'], dd['stop'])
              for dd in dim_data]

    glb_shape = tuple(dd['size'] for dd in dim_data)
    glb_strides = strides_from_shape(glb_shape)

    # Every index combination of the leading dimensions, paired once with
    # the last dimension's start and once with its stop.
    leading = [range(start, stop) for (start, stop) in bounds[:-1]]
    last_start, last_stop = bounds[-1]
    start_ranges = leading + [[last_start]]
    stop_ranges = leading + [[last_stop]]

    def flatten(idx):
        return sum(a * b for (a, b) in zip(idx, glb_strides))

    # Build a real list: under Python 3 `map`/`zip` are one-shot iterators,
    # whereas `condense` received a list under Python 2.
    intervals = [(flatten(lo), flatten(hi))
                 for (lo, hi) in zip(product(*start_ranges),
                                     product(*stop_ranges))]
    return condense(intervals)
1012+
1013+
def _global_flat_indices_intersection(gfis0, gfis1):
    """
    Intersect every interval of `gfis0` with every interval of `gfis1`.

    Each pair is passed to `tuple_intersection`; falsy results are dropped
    and the first two items of every remaining result are returned, in
    pairing order.
    """
    candidates = [tuple_intersection(first, second)
                  for (first, second) in product(gfis0, gfis1)]
    kept = []
    for overlap in candidates:
        if overlap:
            kept.append(overlap[:2])
    return kept

0 commit comments

Comments
 (0)