|
41 | 41 | sanitize_indices, |
42 | 42 | _start_stop_block, |
43 | 43 | tuple_intersection, |
44 | | - shapes_from_dim_data_per_rank) |
| 44 | + shapes_from_dim_data_per_rank, |
| 45 | + condense, |
| 46 | + strides_from_shape) |
45 | 47 |
|
46 | 48 |
|
47 | 49 | def _dedup_dim_dicts(dim_dicts): |
@@ -551,7 +553,7 @@ def from_maps(cls, context, maps, targets=None): |
551 | 553 | self = super(Distribution, cls).__new__(cls) |
552 | 554 | self.context = context |
553 | 555 | self.targets = sorted(targets or context.targets) |
554 | | - self.comm = self.context.make_subcomm(self.targets) |
| 556 | + self._comm = None |
555 | 557 | self.maps = maps |
556 | 558 | self.shape = tuple(m.size for m in self.maps) |
557 | 559 | self.ndim = len(self.maps) |
@@ -758,6 +760,12 @@ def __getitem__(self, idx): |
758 | 760 | def __len__(self): |
759 | 761 | return len(self.maps) |
760 | 762 |
|
| 763 | + @property |
| 764 | + def comm(self): |
| 765 | + if self._comm is None: |
| 766 | + self._comm = self.context.make_subcomm(self.targets) |
| 767 | + return self._comm |
| 768 | + |
761 | 769 | @property |
762 | 770 | def has_precise_index(self): |
763 | 771 | """ |
@@ -869,3 +877,140 @@ def view(self, new_dimsize=None): |
869 | 877 |
|
870 | 878 | def localshapes(self): |
871 | 879 | return shapes_from_dim_data_per_rank(self.get_dim_data_per_rank()) |
| 880 | + |
| 881 | + def comm_union(self, *dists): |
| 882 | + """ |
| 883 | + Make a communicator that includes the union of all targets in `dists`. |
| 884 | +
|
| 885 | + Parameters |
| 886 | + ---------- |
| 887 | + dists: sequence of distribution objects. |
| 888 | +
|
| 889 | + Returns |
| 890 | + ------- |
| 891 | + tuple |
| 892 | + First element is encompassing communicator proxy; second is a |
| 893 | + sequence of all targets in `dists`. |
| 894 | + |
| 895 | + """ |
| 896 | + dist_targets = [d.targets for d in dists] |
| 897 | + all_targets = sorted(reduce(set.union, dist_targets, set(self.targets))) |
| 898 | + return self.context.make_subcomm(all_targets), all_targets |
| 899 | + |
| 900 | + # ------------------------------------------------------------------------ |
| 901 | + # Redistribution |
| 902 | + # ------------------------------------------------------------------------ |
| 903 | + |
| 904 | + @staticmethod |
| 905 | + def _redist_intersection_same_shape(source_dimdata, dest_dimdata): |
| 906 | + |
| 907 | + intersections = [] |
| 908 | + for source_dimdict, dest_dimdict in zip(source_dimdata, dest_dimdata): |
| 909 | + |
| 910 | + if not (source_dimdict['dist_type'] == |
| 911 | + dest_dimdict['dist_type'] == 'b'): |
| 912 | + raise ValueError("Only 'b' dist_type supported") |
| 913 | + |
| 914 | + source_idxs = source_dimdict['start'], source_dimdict['stop'] |
| 915 | + dest_idxs = dest_dimdict['start'], dest_dimdict['stop'] |
| 916 | + |
| 917 | + intersections.append(tuple_intersection(source_idxs, dest_idxs)) |
| 918 | + |
| 919 | + return intersections |
| 920 | + |
| 921 | + @staticmethod |
| 922 | + def _redist_intersection_reshape(source_dimdata, dest_dimdata): |
| 923 | + source_flat = global_flat_indices(source_dimdata) |
| 924 | + dest_flat = global_flat_indices(dest_dimdata) |
| 925 | + return _global_flat_indices_intersection(source_flat, dest_flat) |
| 926 | + |
| 927 | + def get_redist_plan(self, other_dist): |
| 928 | + # Get all targets |
| 929 | + all_targets = sorted(set(self.targets + other_dist.targets)) |
| 930 | + union_rank_from_target = {t: r for (r, t) in enumerate(all_targets)} |
| 931 | + |
| 932 | + source_ranks = range(len(self.targets)) |
| 933 | + source_targets = self.targets |
| 934 | + union_rank_from_source_rank = {sr: union_rank_from_target[st] |
| 935 | + for (sr, st) in |
| 936 | + zip(source_ranks, source_targets)} |
| 937 | + |
| 938 | + dest_ranks = range(len(other_dist.targets)) |
| 939 | + dest_targets = other_dist.targets |
| 940 | + union_rank_from_dest_rank = {sr: union_rank_from_target[st] |
| 941 | + for (sr, st) in |
| 942 | + zip(dest_ranks, dest_targets)} |
| 943 | + |
| 944 | + source_ddpr = self.get_dim_data_per_rank() |
| 945 | + dest_ddpr = other_dist.get_dim_data_per_rank() |
| 946 | + source_dest_pairs = product(source_ddpr, dest_ddpr) |
| 947 | + |
| 948 | + if self.shape == other_dist.shape: |
| 949 | + _intersection = Distribution._redist_intersection_same_shape |
| 950 | + else: |
| 951 | + _intersection = Distribution._redist_intersection_reshape |
| 952 | + |
| 953 | + plan = [] |
| 954 | + for source_dd, dest_dd in source_dest_pairs: |
| 955 | + intersections = _intersection(source_dd, dest_dd) |
| 956 | + if intersections and all(i for i in intersections): |
| 957 | + source_coords = tuple(dd['proc_grid_rank'] for dd in source_dd) |
| 958 | + source_rank = self.rank_from_coords[source_coords] |
| 959 | + dest_coords = tuple(dd['proc_grid_rank'] for dd in dest_dd) |
| 960 | + dest_rank = other_dist.rank_from_coords[dest_coords] |
| 961 | + plan.append({ |
| 962 | + 'source_rank': union_rank_from_source_rank[source_rank], |
| 963 | + 'dest_rank': union_rank_from_dest_rank[dest_rank], |
| 964 | + 'indices': intersections, |
| 965 | + } |
| 966 | + ) |
| 967 | + |
| 968 | + return plan |
| 969 | + |
| 970 | + |
| 971 | +# ---------------------------------------------------------------------------- |
| 972 | +# Redistribution helper functions. |
| 973 | +# ---------------------------------------------------------------------------- |
| 974 | + |
| 975 | +def global_flat_indices(dim_data): |
| 976 | + """ |
| 977 | + Return a list of tuples of indices into the flattened global array. |
| 978 | +
|
| 979 | + Parameters |
| 980 | + ---------- |
| 981 | + dim_data: dimension dictionary. |
| 982 | +
|
| 983 | + Returns |
| 984 | + ------- |
| 985 | + list of 2-tuples of ints. |
| 986 | + Each tuple is a (start, stop) interval into the flattened global array. |
| 987 | + All selected ranges comprise the indices for this dim_data's sub-array. |
| 988 | +
|
| 989 | + """ |
| 990 | + # TODO: FIXME: can be optimized when the last dimension is 'n'. |
| 991 | + |
| 992 | + for dd in dim_data: |
| 993 | + if dd['dist_type'] == 'n': |
| 994 | + dd['start'] = 0 |
| 995 | + dd['stop'] = dd['size'] |
| 996 | + |
| 997 | + glb_shape = tuple(dd['size'] for dd in dim_data) |
| 998 | + glb_strides = strides_from_shape(glb_shape) |
| 999 | + |
| 1000 | + ranges = [range(dd['start'], dd['stop']) for dd in dim_data[:-1]] |
| 1001 | + start_ranges = ranges + [[dim_data[-1]['start']]] |
| 1002 | + stop_ranges = ranges + [[dim_data[-1]['stop']]] |
| 1003 | + |
| 1004 | + def flatten(idx): |
| 1005 | + return sum(a * b for (a, b) in zip(idx, glb_strides)) |
| 1006 | + |
| 1007 | + starts = map(flatten, product(*start_ranges)) |
| 1008 | + stops = map(flatten, product(*stop_ranges)) |
| 1009 | + |
| 1010 | + intervals = zip(starts, stops) |
| 1011 | + return condense(intervals) |
| 1012 | + |
| 1013 | +def _global_flat_indices_intersection(gfis0, gfis1): |
| 1014 | + intersections = filter(None, [tuple_intersection(a, b) |
| 1015 | + for (a, b) in product(gfis0, gfis1)]) |
| 1016 | + return [i[:2] for i in intersections] |
0 commit comments