1515from abc import ABC , abstractmethod
1616from collections import defaultdict
1717from operator import itemgetter
18- from typing import List , Dict , Set , Union
18+ from typing import List , Dict , Union
1919
2020import numpy as np
2121
2222from ....core import ChunkGraph , ChunkData
2323from ....core .operand import Operand
24+ from ....lib .ordered_set import OrderedSet
2425from ....resource import Resource
2526from ....typing import BandType
2627from ....utils import implements
@@ -77,8 +78,9 @@ def __init__(
7778 band_resource : Dict [BandType , Resource ],
7879 ):
7980 super ().__init__ (chunk_graph , start_ops , band_resource )
80- self ._undirected_chunk_graph = None
81- self ._op_keys : Set [str ] = {start_op .key for start_op in start_ops }
81+ self ._op_keys : OrderedSet [str ] = OrderedSet (
82+ [start_op .key for start_op in start_ops ]
83+ )
8284
8385 def _calc_band_assign_limits (
8486 self , initial_count : int , occupied : Dict [BandType , int ]
@@ -124,13 +126,15 @@ def _calc_band_assign_limits(
124126 pos = (pos + 1 ) % len (counts )
125127 return dict (zip (bands , counts ))
126128
129+ @classmethod
127130 def _assign_by_bfs (
128- self ,
131+ cls ,
132+ undirected_chunk_graph : ChunkGraph ,
129133 start : ChunkData ,
130134 band : BandType ,
131135 initial_sizes : Dict [BandType , int ],
132136 spread_limits : Dict [BandType , float ],
133- key_to_assign : Set [str ],
137+ key_to_assign : OrderedSet [str ],
134138 assigned_record : Dict [str , Union [str , BandType ]],
135139 ):
136140 """
@@ -140,19 +144,15 @@ def _assign_by_bfs(
140144 if initial_sizes [band ] <= 0 :
141145 return
142146
143- graph = self ._chunk_graph
144- if self ._undirected_chunk_graph is None :
145- self ._undirected_chunk_graph = graph .build_undirected ()
146- undirected_chunk_graph = self ._undirected_chunk_graph
147-
148147 assigned = 0
149148 spread_range = 0
150149 for chunk in undirected_chunk_graph .bfs (start = start , visit_predicate = "all" ):
151150 op_key = chunk .op .key
152151 if op_key in assigned_record :
153152 continue
154153 spread_range += 1
155- # `op_key` may not be in `key_to_assign`, but we need to record it to avoid iterate the node repeatedly.
154+ # `op_key` may not be in `key_to_assign`,
155+ # but we need to record it to avoid iterate the node repeatedly.
156156 assigned_record [op_key ] = band
157157 if op_key not in key_to_assign :
158158 continue
@@ -161,8 +161,22 @@ def _assign_by_bfs(
161161 break
162162 initial_sizes [band ] -= assigned
163163
164+ def _build_undirected_chunk_graph (
165+ self , chunk_to_assign : List [ChunkData ]
166+ ) -> ChunkGraph :
167+ chunk_graph = self ._chunk_graph .copy ()
168+ # remove edges for all chunk_to_assign which may contain chunks
169+ # that need be reassigned
170+ for chunk in chunk_to_assign :
171+ if chunk_graph .count_predecessors (chunk ) > 0 :
172+ for pred in list (chunk_graph .predecessors (chunk )):
173+ chunk_graph .remove_edge (pred , chunk )
174+ return chunk_graph .build_undirected ()
175+
164176 @implements (AbstractGraphAssigner .assign )
165- def assign (self , cur_assigns : Dict [str , str ] = None ) -> Dict [ChunkData , BandType ]:
177+ def assign (
178+ self , cur_assigns : Dict [str , BandType ] = None
179+ ) -> Dict [ChunkData , BandType ]:
166180 graph = self ._chunk_graph
167181 assign_result = dict ()
168182 cur_assigns = cur_assigns or dict ()
@@ -173,7 +187,7 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
173187 for chunk in graph :
174188 op_key_to_chunks [chunk .op .key ].append (chunk )
175189
176- op_keys = set (self ._op_keys )
190+ op_keys = OrderedSet (self ._op_keys )
177191 chunk_to_assign = [
178192 op_key_to_chunks [op_key ][0 ]
179193 for op_key in op_keys
@@ -183,6 +197,9 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
183197 for band in cur_assigns .values ():
184198 assigned_counts [band ] += 1
185199
200+ # build undirected graph
201+ undirected_chunk_graph = self ._build_undirected_chunk_graph (chunk_to_assign )
202+
186203 # calculate the number of chunks to be assigned to each band
187204 # given number of bands and existing assignments
188205 band_quotas = self ._calc_band_assign_limits (
@@ -195,14 +212,20 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
195212 spread_ranges = defaultdict (lambda : average_spread_range )
196213 # assign from other chunks to be assigned
197214 # TODO: sort by what?
198- sorted_candidates = [ v for v in chunk_to_assign ]
215+ sorted_candidates = chunk_to_assign . copy ()
199216 while max (band_quotas .values ()):
200217 band = max (band_quotas , key = lambda k : band_quotas [k ])
201218 cur = sorted_candidates .pop ()
202219 while cur .op .key in cur_assigns :
203220 cur = sorted_candidates .pop ()
204221 self ._assign_by_bfs (
205- cur , band , band_quotas , spread_ranges , op_keys , cur_assigns
222+ undirected_chunk_graph ,
223+ cur ,
224+ band ,
225+ band_quotas ,
226+ spread_ranges ,
227+ op_keys ,
228+ cur_assigns ,
206229 )
207230
208231 key_to_assign = {n .op .key for n in chunk_to_assign } | initial_assigned_op_keys
0 commit comments