Skip to content

Commit 93c12ff

Browse files
author
继盛
committed
Make assigner more balanced when encountering reassign_workers
1 parent 1d22198 commit 93c12ff

File tree

2 files changed

+74
-15
lines changed

2 files changed

+74
-15
lines changed

mars/services/task/analyzer/assigner.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from abc import ABC, abstractmethod
1616
from collections import defaultdict
1717
from operator import itemgetter
18-
from typing import List, Dict, Set, Union
18+
from typing import List, Dict, Union
1919

2020
import numpy as np
2121

2222
from ....core import ChunkGraph, ChunkData
2323
from ....core.operand import Operand
24+
from ....lib.ordered_set import OrderedSet
2425
from ....resource import Resource
2526
from ....typing import BandType
2627
from ....utils import implements
@@ -77,8 +78,9 @@ def __init__(
7778
band_resource: Dict[BandType, Resource],
7879
):
7980
super().__init__(chunk_graph, start_ops, band_resource)
80-
self._undirected_chunk_graph = None
81-
self._op_keys: Set[str] = {start_op.key for start_op in start_ops}
81+
self._op_keys: OrderedSet[str] = OrderedSet(
82+
[start_op.key for start_op in start_ops]
83+
)
8284

8385
def _calc_band_assign_limits(
8486
self, initial_count: int, occupied: Dict[BandType, int]
@@ -124,13 +126,15 @@ def _calc_band_assign_limits(
124126
pos = (pos + 1) % len(counts)
125127
return dict(zip(bands, counts))
126128

129+
@classmethod
127130
def _assign_by_bfs(
128-
self,
131+
cls,
132+
undirected_chunk_graph: ChunkGraph,
129133
start: ChunkData,
130134
band: BandType,
131135
initial_sizes: Dict[BandType, int],
132136
spread_limits: Dict[BandType, float],
133-
key_to_assign: Set[str],
137+
key_to_assign: OrderedSet[str],
134138
assigned_record: Dict[str, Union[str, BandType]],
135139
):
136140
"""
@@ -140,19 +144,15 @@ def _assign_by_bfs(
140144
if initial_sizes[band] <= 0:
141145
return
142146

143-
graph = self._chunk_graph
144-
if self._undirected_chunk_graph is None:
145-
self._undirected_chunk_graph = graph.build_undirected()
146-
undirected_chunk_graph = self._undirected_chunk_graph
147-
148147
assigned = 0
149148
spread_range = 0
150149
for chunk in undirected_chunk_graph.bfs(start=start, visit_predicate="all"):
151150
op_key = chunk.op.key
152151
if op_key in assigned_record:
153152
continue
154153
spread_range += 1
155-
# `op_key` may not be in `key_to_assign`, but we need to record it to avoid iterate the node repeatedly.
154+
# `op_key` may not be in `key_to_assign`,
155+
# but we need to record it to avoid iterating over the node repeatedly.
156156
assigned_record[op_key] = band
157157
if op_key not in key_to_assign:
158158
continue
@@ -161,8 +161,23 @@ def _assign_by_bfs(
161161
break
162162
initial_sizes[band] -= assigned
163163

164+
def _build_undirected_chunk_graph(
165+
self, chunk_to_assign: List[ChunkData]
166+
) -> ChunkGraph:
167+
chunk_graph = ChunkGraph()
168+
self._chunk_graph.copyto(chunk_graph)
169+
# remove edges for all chunk_to_assign which may contain chunks
170+
# that need to be reassigned
171+
for chunk in chunk_to_assign:
172+
if chunk_graph.count_predecessors(chunk) > 0:
173+
for pred in list(chunk_graph.predecessors(chunk)):
174+
chunk_graph.remove_edge(pred, chunk)
175+
return chunk_graph.build_undirected()
176+
164177
@implements(AbstractGraphAssigner.assign)
165-
def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType]:
178+
def assign(
179+
self, cur_assigns: Dict[str, BandType] = None
180+
) -> Dict[ChunkData, BandType]:
166181
graph = self._chunk_graph
167182
assign_result = dict()
168183
cur_assigns = cur_assigns or dict()
@@ -173,7 +188,7 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
173188
for chunk in graph:
174189
op_key_to_chunks[chunk.op.key].append(chunk)
175190

176-
op_keys = set(self._op_keys)
191+
op_keys = OrderedSet(self._op_keys)
177192
chunk_to_assign = [
178193
op_key_to_chunks[op_key][0]
179194
for op_key in op_keys
@@ -183,6 +198,9 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
183198
for band in cur_assigns.values():
184199
assigned_counts[band] += 1
185200

201+
# build undirected graph
202+
undirected_chunk_graph = self._build_undirected_chunk_graph(chunk_to_assign)
203+
186204
# calculate the number of chunks to be assigned to each band
187205
# given number of bands and existing assignments
188206
band_quotas = self._calc_band_assign_limits(
@@ -195,14 +213,20 @@ def assign(self, cur_assigns: Dict[str, str] = None) -> Dict[ChunkData, BandType
195213
spread_ranges = defaultdict(lambda: average_spread_range)
196214
# assign from other chunks to be assigned
197215
# TODO: sort by what?
198-
sorted_candidates = [v for v in chunk_to_assign]
216+
sorted_candidates = chunk_to_assign.copy()
199217
while max(band_quotas.values()):
200218
band = max(band_quotas, key=lambda k: band_quotas[k])
201219
cur = sorted_candidates.pop()
202220
while cur.op.key in cur_assigns:
203221
cur = sorted_candidates.pop()
204222
self._assign_by_bfs(
205-
cur, band, band_quotas, spread_ranges, op_keys, cur_assigns
223+
undirected_chunk_graph,
224+
cur,
225+
band,
226+
band_quotas,
227+
spread_ranges,
228+
op_keys,
229+
cur_assigns,
206230
)
207231

208232
key_to_assign = {n.op.key for n in chunk_to_assign} | initial_assigned_op_keys

mars/services/task/analyzer/tests/test_assigner.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
import pandas as pd
1617

18+
from ..... import dataframe as md
1719
from .....config import Config
1820
from .....core import ChunkGraph
21+
from .....core.graph.builder.utils import build_graph
22+
from .....core.operand import OperandStage
1923
from .....tensor.random import TensorRand
2024
from .....tensor.arithmetic import TensorAdd
2125
from .....tensor.fetch import TensorFetch
@@ -71,3 +75,34 @@ def test_assigner_with_fetch_inputs():
7175
for inp in input_chunks:
7276
if not isinstance(inp.op, TensorFetch):
7377
assert subtask.expect_band == key_to_assign[inp.key]
78+
79+
80+
def test_shuffle_assign():
81+
band_num = 8
82+
all_bands = [(f"address_{i}", "numa-0") for i in range(band_num)]
83+
84+
pdf = pd.DataFrame(np.random.rand(32, 4))
85+
df = md.DataFrame(pdf, chunk_size=4)
86+
r = df.groupby(0).sum(method="shuffle")
87+
chunk_graph = build_graph([r], tile=True)
88+
89+
band_resource = dict((band, Resource(num_cpus=1)) for band in all_bands)
90+
91+
reassign_worker_ops = [
92+
chunk.op for chunk in chunk_graph if chunk.op.reassign_worker
93+
]
94+
start_ops = list(GraphAnalyzer._iter_start_ops(chunk_graph))
95+
to_assign_ops = start_ops + reassign_worker_ops
96+
97+
assigner = GraphAssigner(chunk_graph, to_assign_ops, band_resource)
98+
assigns = assigner.assign()
99+
assert len(assigns) == 16
100+
init_assigns = set()
101+
reducer_assigns = set()
102+
for chunk, assign in assigns.items():
103+
if chunk.op.stage == OperandStage.reduce:
104+
reducer_assigns.add(assign)
105+
else:
106+
init_assigns.add(assign)
107+
# initial chunks and reducers are each assigned across all bands
108+
assert len(init_assigns) == len(reducer_assigns) == 8

0 commit comments

Comments
 (0)