1515from abc import ABC , abstractmethod
1616from collections import defaultdict
1717from operator import itemgetter
18- from typing import List , Dict , Set , Union
18+ from typing import List , Dict , Union
1919
2020import numpy as np
2121
2222from ....core import ChunkGraph , ChunkData
2323from ....core .operand import Operand
24+ from ....lib .ordered_set import OrderedSet
2425from ....resource import Resource
2526from ....typing import BandType
2627from ....utils import implements
@@ -77,8 +78,9 @@ def __init__(
7778 band_resource : Dict [BandType , Resource ],
7879 ):
7980 super ().__init__ (chunk_graph , start_ops , band_resource )
80- self ._undirected_chunk_graph = None
81- self ._op_keys : Set [str ] = {start_op .key for start_op in start_ops }
81+ self ._op_keys : OrderedSet [str ] = OrderedSet (
82+ [start_op .key for start_op in start_ops ]
83+ )
8284
8385 def _calc_band_assign_limits (
8486 self , initial_count : int , occupied : Dict [BandType , int ]
@@ -124,13 +126,15 @@ def _calc_band_assign_limits(
124126 pos = (pos + 1 ) % len (counts )
125127 return dict (zip (bands , counts ))
126128
129+ @classmethod
127130 def _assign_by_bfs (
128- self ,
131+ cls ,
132+ undirected_chunk_graph : ChunkGraph ,
129133 start : ChunkData ,
130134 band : BandType ,
131135 initial_sizes : Dict [BandType , int ],
132136 spread_limits : Dict [BandType , float ],
133- key_to_assign : Set [str ],
137+ key_to_assign : OrderedSet [str ],
134138 assigned_record : Dict [str , Union [str , BandType ]],
135139 ):
136140 """
@@ -140,19 +144,15 @@ def _assign_by_bfs(
140144 if initial_sizes [band ] <= 0 :
141145 return
142146
143- graph = self ._chunk_graph
144- if self ._undirected_chunk_graph is None :
145- self ._undirected_chunk_graph = graph .build_undirected ()
146- undirected_chunk_graph = self ._undirected_chunk_graph
147-
148147 assigned = 0
149148 spread_range = 0
150149 for chunk in undirected_chunk_graph .bfs (start = start , visit_predicate = "all" ):
151150 op_key = chunk .op .key
152151 if op_key in assigned_record :
153152 continue
154153 spread_range += 1
155- # `op_key` may not be in `key_to_assign`, but we need to record it to avoid iterate the node repeatedly.
154+ # `op_key` may not be in `key_to_assign`,
155+ # but we need to record it to avoid iterate the node repeatedly.
156156 assigned_record [op_key ] = band
157157 if op_key not in key_to_assign :
158158 continue
@@ -161,8 +161,23 @@ def _assign_by_bfs(
161161 break
162162 initial_sizes [band ] -= assigned
163163
164+ def _build_undirected_chunk_graph (
165+ self , chunk_to_assign : List [ChunkData ]
166+ ) -> ChunkGraph :
167+ chunk_graph = ChunkGraph ()
168+ self ._chunk_graph .copyto (chunk_graph )
169+ # remove edges for all chunk_to_assign which may contain chunks
170+ # that need be reassigned
171+ for chunk in chunk_to_assign :
172+ if chunk_graph .count_predecessors (chunk ) > 0 :
173+ for pred in list (chunk_graph .predecessors (chunk )):
174+ chunk_graph .remove_edge (pred , chunk )
175+ return chunk_graph .build_undirected ()
176+
164177 @implements (AbstractGraphAssigner .assign )
165- def assign (self , cur_assigns : Dict [str , str ] = None ) -> Dict [ChunkData , BandType ]:
178+ def assign (
179+ self , cur_assigns : Dict [str , BandType ] = None
180+ ) -> Dict [ChunkData , BandType ]:
166181 graph = self ._chunk_graph
167182 assign_result = dict ()
168183 cur_assigns = cur_assigns or dict ()
@@ -173,7 +188,7 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
173188 for chunk in graph :
174189 op_key_to_chunks [chunk .op .key ].append (chunk )
175190
176- op_keys = set (self ._op_keys )
191+ op_keys = OrderedSet (self ._op_keys )
177192 chunk_to_assign = [
178193 op_key_to_chunks [op_key ][0 ]
179194 for op_key in op_keys
@@ -183,6 +198,9 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
183198 for band in cur_assigns .values ():
184199 assigned_counts [band ] += 1
185200
201+ # build undirected graph
202+ undirected_chunk_graph = self ._build_undirected_chunk_graph (chunk_to_assign )
203+
186204 # calculate the number of chunks to be assigned to each band
187205 # given number of bands and existing assignments
188206 band_quotas = self ._calc_band_assign_limits (
@@ -195,14 +213,20 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
195213 spread_ranges = defaultdict (lambda : average_spread_range )
196214 # assign from other chunks to be assigned
197215 # TODO: sort by what?
198- sorted_candidates = [ v for v in chunk_to_assign ]
216+ sorted_candidates = chunk_to_assign . copy ()
199217 while max (band_quotas .values ()):
200218 band = max (band_quotas , key = lambda k : band_quotas [k ])
201219 cur = sorted_candidates .pop ()
202220 while cur .op .key in cur_assigns :
203221 cur = sorted_candidates .pop ()
204222 self ._assign_by_bfs (
205- cur , band , band_quotas , spread_ranges , op_keys , cur_assigns
223+ undirected_chunk_graph ,
224+ cur ,
225+ band ,
226+ band_quotas ,
227+ spread_ranges ,
228+ op_keys ,
229+ cur_assigns ,
206230 )
207231
208232 key_to_assign = {n .op .key for n in chunk_to_assign } | initial_assigned_op_keys
0 commit comments