Skip to content

Commit c52c107

Browse files
author
xuye.qin
committed
Make push mapper data work
1 parent b767644 commit c52c107

File tree

13 files changed

+231
-31
lines changed

13 files changed

+231
-31
lines changed

mars/services/subtask/worker/processor.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
import sys
1818
import time
1919
from collections import defaultdict
20-
from typing import Any, Dict, List, Optional, Set, Type
20+
from typing import Any, Dict, List, Optional, Set, Type, Tuple
2121

2222
from .... import oscar as mo
2323
from ....core import ChunkGraph, OperandType, enter_mode, ExecutionError
2424
from ....core.context import get_context, set_context
25-
from ....core.operand import Fetch, FetchShuffle, execute
25+
from ....core.operand import (
26+
Fetch,
27+
FetchShuffle,
28+
execute,
29+
)
30+
from ....lib.aio import alru_cache
2631
from ....metrics import Metrics
2732
from ....optimization.physical import optimize
2833
from ....typing import BandType, ChunkType
@@ -420,26 +425,56 @@ async def set_chunks_meta():
420425
# set result data size
421426
self.result.data_size = result_data_size
422427

423-
async def _push_mapper_data(self, chunk_graph):
424-
# TODO: use task api to get reducer bands
425-
reducer_idx_to_band = dict()
426-
if not reducer_idx_to_band:
427-
return
428+
@classmethod
429+
@alru_cache(cache_exceptions=False)
430+
async def _gen_reducer_index_to_bands(
431+
cls, session_id: str, supervisor_address: str, task_id: str, map_reduce_id: int
432+
) -> Dict[Tuple[int], BandType]:
433+
task_api = await TaskAPI.create(session_id, supervisor_address)
434+
map_reduce_info = await task_api.get_map_reduce_info(task_id, map_reduce_id)
435+
assert len(map_reduce_info.reducer_indexes) == len(
436+
map_reduce_info.reducer_bands
437+
)
438+
return {
439+
reducer_index: band
440+
for reducer_index, band in zip(
441+
map_reduce_info.reducer_indexes, map_reduce_info.reducer_bands
442+
)
443+
}
444+
445+
async def _push_mapper_data(self):
428446
storage_api_to_fetch_tasks = defaultdict(list)
429-
for result_chunk in chunk_graph.result_chunks:
430-
key = result_chunk.key
431-
reducer_idx = key[1]
432-
if isinstance(key, tuple):
447+
skip = True
448+
for result_chunk in self._chunk_graph.result_chunks:
449+
map_reduce_id = getattr(result_chunk.op, "extra_params", dict()).get(
450+
"analyzer_map_reduce_id"
451+
)
452+
if map_reduce_id is None:
453+
continue
454+
skip = False
455+
reducer_index_to_bands = await self._gen_reducer_index_to_bands(
456+
self._session_id,
457+
self._supervisor_address,
458+
self.subtask.task_id,
459+
map_reduce_id,
460+
)
461+
for reducer_index, band in reducer_index_to_bands.items():
433462
# mapper key is a tuple
434-
address, band_name = reducer_idx_to_band[reducer_idx]
435-
storage_api = StorageAPI(address, self._session_id, band_name)
463+
address, band_name = band
464+
storage_api = await StorageAPI.create(
465+
self._session_id, address, band_name
466+
)
436467
fetch_task = storage_api.fetch.delay(
437-
key, band_name=self._band[1], remote_address=self._band[0]
468+
(result_chunk.key, reducer_index),
469+
band_name=self._band[1],
470+
remote_address=self._band[0],
438471
)
439472
storage_api_to_fetch_tasks[storage_api].append(fetch_task)
473+
if skip:
474+
return
440475
batch_tasks = []
441476
for storage_api, tasks in storage_api_to_fetch_tasks.items():
442-
batch_tasks.append(asyncio.create_task(storage_api.fetch.batch(*tasks)))
477+
batch_tasks.append(storage_api.fetch.batch(*tasks))
443478
await asyncio.gather(*batch_tasks)
444479

445480
async def done(self):
@@ -513,8 +548,6 @@ async def run(self):
513548
await self._unpin_data(input_keys)
514549

515550
await self.done()
516-
# after done, we push mapper data to reducers in advance.
517-
await self.ref()._push_mapper_data.tell(chunk_graph)
518551
if self.result.status == SubtaskStatus.succeeded:
519552
cost_time_secs = (
520553
self.result.execution_end_time - self.result.execution_start_time
@@ -536,6 +569,9 @@ async def run(self):
536569
pass
537570
return self.result
538571

572+
async def post_run(self):
573+
await self._push_mapper_data()
574+
539575
async def report_progress_periodically(self, interval=0.5, eps=0.001):
540576
last_progress = self.result.progress
541577
while not self.result.status.is_done:
@@ -618,7 +654,7 @@ async def _init_context(self, session_id: str):
618654
await context.init()
619655
set_context(context)
620656

621-
async def run(self, subtask: Subtask):
657+
async def run(self, subtask: Subtask, wait_post_run: bool = False):
622658
logger.info(
623659
"Start to run subtask: %r on %s. chunk graph contains %s",
624660
subtask,
@@ -644,10 +680,18 @@ async def run(self, subtask: Subtask):
644680
try:
645681
result = yield self._running_aio_task
646682
logger.info("Finished subtask: %s", subtask.subtask_id)
683+
# post run with actor tell which will not block
684+
if not wait_post_run:
685+
await self.ref().post_run.tell(processor)
686+
else:
687+
await self.post_run(processor)
647688
raise mo.Return(result)
648689
finally:
649690
self._processor = self._running_aio_task = None
650691

692+
async def post_run(self, processor: SubtaskProcessor):
693+
await processor.post_run()
694+
651695
async def wait(self):
652696
return self._processor.is_done.wait()
653697

mars/services/subtask/worker/runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def _get_supervisor_address(self, session_id: str):
8989
[address] = await self._cluster_api.get_supervisors_by_keys([session_id])
9090
return address
9191

92-
async def run_subtask(self, subtask: Subtask):
92+
async def run_subtask(self, subtask: Subtask, wait_post_run: bool = False):
9393
if self._running_processor is not None: # pragma: no cover
9494
running_subtask_id = await self._running_processor.get_running_subtask_id()
9595
# current subtask is still running
@@ -122,7 +122,9 @@ async def run_subtask(self, subtask: Subtask):
122122
processor = self._session_id_to_processors[session_id]
123123
try:
124124
self._running_processor = self._last_processor = processor
125-
result = yield self._running_processor.run(subtask)
125+
result = yield self._running_processor.run(
126+
subtask, wait_post_run=wait_post_run
127+
)
126128
finally:
127129
self._running_processor = None
128130
raise mo.Return(result)

mars/services/subtask/worker/tests/test_subtask.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@
1818
import time
1919

2020
import numpy as np
21+
import pandas as pd
2122
import pytest
2223

2324
from ..... import oscar as mo
25+
from ..... import dataframe as md
2426
from ..... import tensor as mt
2527
from ..... import remote as mr
26-
from .....core import ExecutionError
28+
from .....core import ExecutionError, ChunkGraph
2729
from .....core.context import get_context
2830
from .....core.graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
31+
from .....core.operand import OperandStage
2932
from .....resource import Resource
3033
from .....utils import Timer
3134
from ....cluster import MockClusterAPI
@@ -34,7 +37,7 @@
3437
from ....scheduling import MockSchedulingAPI
3538
from ....session import MockSessionAPI
3639
from ....storage import MockStorageAPI
37-
from ....task import new_task_id
40+
from ....task import new_task_id, MapReduceInfo
3841
from ....task.supervisor.manager import TaskManagerActor, TaskConfigurationActor
3942
from ....mutable import MockMutableAPI
4043
from ... import Subtask, SubtaskStatus, SubtaskResult
@@ -46,6 +49,13 @@ class FakeTaskManager(TaskManagerActor):
4649
def set_subtask_result(self, subtask_result: SubtaskResult):
4750
return
4851

52+
def get_map_reduce_info(self, task_id: str, map_reduce_id: int) -> MapReduceInfo:
53+
return MapReduceInfo(
54+
map_reduce_id=0,
55+
reducer_indexes=[(0, 0)],
56+
reducer_bands=[(self.address, "numa-0")],
57+
)
58+
4959

5060
@pytest.fixture
5161
async def actor_pool():
@@ -142,6 +152,39 @@ async def test_subtask_success(actor_pool):
142152
assert await subtask_runner.is_runner_free() is True
143153

144154

155+
@pytest.mark.asyncio
156+
async def test_shuffle_subtask(actor_pool):
157+
pool, session_id, meta_api, storage_api, manager = actor_pool
158+
159+
pdf = pd.DataFrame({"f1": ["a", "b", "a"], "f2": [1, 2, 3]})
160+
df = md.DataFrame(pdf)
161+
result = df.groupby("f1").sum(method="shuffle")
162+
163+
graph = TileableGraph([result.data])
164+
next(TileableGraphBuilder(graph).build())
165+
chunk_graph = next(ChunkGraphBuilder(graph, fuse_enabled=False).build())
166+
result_chunks = []
167+
new_chunk_graph = ChunkGraph(result_chunks)
168+
chunk_graph_iter = chunk_graph.topological_iter()
169+
curr = None
170+
for _ in range(3):
171+
prev = curr
172+
curr = next(chunk_graph_iter)
173+
new_chunk_graph.add_node(curr)
174+
if prev is not None:
175+
new_chunk_graph.add_edge(prev, curr)
176+
assert curr.op.stage == OperandStage.map
177+
curr.op.extra_params = {"analyzer_map_reduce_id": 0}
178+
result_chunks.append(curr)
179+
subtask = Subtask(new_task_id(), session_id, new_task_id(), new_chunk_graph)
180+
subtask_runner: SubtaskRunnerRef = await mo.actor_ref(
181+
SubtaskRunnerActor.gen_uid("numa-0", 0), address=pool.external_address
182+
)
183+
await subtask_runner.run_subtask(subtask, wait_post_run=True)
184+
result = await subtask_runner.get_subtask_result()
185+
assert result.status == SubtaskStatus.succeeded
186+
187+
145188
@pytest.mark.asyncio
146189
async def test_subtask_failure(actor_pool):
147190
pool, session_id, meta_api, storage_api, manager = actor_pool

mars/services/task/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
from .api import AbstractTaskAPI, TaskAPI, WebTaskAPI
1616
from .config import task_options
17-
from .core import Task, TaskStatus, TaskResult, new_task_id
17+
from .core import Task, TaskStatus, TaskResult, new_task_id, MapReduceInfo
1818
from .errors import TaskNotExist

mars/services/task/analyzer/analyzer.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
LogicKeyGenerator,
2626
MapReduceOperand,
2727
OperandStage,
28+
ShuffleProxy,
2829
)
30+
from ....lib.ordered_set import OrderedSet
2931
from ....resource import Resource
3032
from ....typing import BandType, OperandType
3133
from ....utils import build_fetch, tokenize
3234
from ...subtask import SubtaskGraph, Subtask
33-
from ..core import Task, new_task_id
35+
from ..core import Task, new_task_id, MapReduceInfo
3436
from .assigner import AbstractGraphAssigner, GraphAssigner
3537
from .fusion import Coloring
3638

@@ -50,6 +52,8 @@ def need_reassign_worker(op: OperandType) -> bool:
5052

5153

5254
class GraphAnalyzer:
55+
_map_reduce_id = itertools.count()
56+
5357
def __init__(
5458
self,
5559
chunk_graph: ChunkGraph,
@@ -59,6 +63,7 @@ def __init__(
5963
chunk_to_subtasks: Dict[ChunkType, Subtask],
6064
graph_assigner_cls: Type[AbstractGraphAssigner] = None,
6165
stage_id: str = None,
66+
map_reduce_id_to_infos: Dict[int, MapReduceInfo] = None,
6267
):
6368
self._chunk_graph = chunk_graph
6469
self._band_resource = band_resource
@@ -68,12 +73,17 @@ def __init__(
6873
self._fuse_enabled = task.fuse_enabled
6974
self._extra_config = task.extra_config
7075
self._chunk_to_subtasks = chunk_to_subtasks
76+
self._map_reduce_id_to_infos = map_reduce_id_to_infos
7177
if graph_assigner_cls is None:
7278
graph_assigner_cls = GraphAssigner
7379
self._graph_assigner_cls = graph_assigner_cls
7480
self._chunk_to_copied = dict()
7581
self._logic_key_generator = LogicKeyGenerator()
7682

83+
@classmethod
84+
def next_map_reduce_id(cls) -> int:
85+
return next(cls._map_reduce_id)
86+
7787
@classmethod
7888
def _iter_start_ops(cls, chunk_graph: ChunkGraph):
7989
visited = set()
@@ -300,6 +310,38 @@ def _gen_logic_key(self, chunks: List[ChunkType]):
300310
*[self._logic_key_generator.get_logic_key(chunk.op) for chunk in chunks]
301311
)
302312

313+
def _gen_map_reduce_info(
314+
self, chunk: ChunkType, assign_results: Dict[ChunkType, BandType]
315+
):
316+
reducer_ops = OrderedSet(
317+
[
318+
c.op
319+
for c in self._chunk_graph.successors(chunk)
320+
if c.op.stage == OperandStage.reduce
321+
]
322+
)
323+
map_chunks = [
324+
c
325+
for c in self._chunk_graph.predecessors(chunk)
326+
if c.op.stage == OperandStage.map
327+
]
328+
map_reduce_id = self.next_map_reduce_id()
329+
for map_chunk in map_chunks:
330+
# record analyzer map reduce id for mapper op
331+
# copied chunk exists because map chunk must have
332+
# been processed before shuffle proxy
333+
copied_map_chunk_op = self._chunk_to_copied[map_chunk].op
334+
if not hasattr(copied_map_chunk_op, "extra_params"):
335+
copied_map_chunk_op.extra_params = dict()
336+
copied_map_chunk_op.extra_params["analyzer_map_reduce_id"] = map_reduce_id
337+
reducer_bands = [assign_results[r.outputs[0]] for r in reducer_ops]
338+
map_reduce_info = MapReduceInfo(
339+
map_reduce_id=map_reduce_id,
340+
reducer_indexes=[reducer_op.reducer_index for reducer_op in reducer_ops],
341+
reducer_bands=reducer_bands,
342+
)
343+
self._map_reduce_id_to_infos[map_reduce_id] = map_reduce_info
344+
303345
@enter_mode(build=True)
304346
def gen_subtask_graph(
305347
self, op_to_bands: Dict[str, BandType] = None
@@ -420,6 +462,10 @@ def gen_subtask_graph(
420462

421463
for c in same_color_chunks:
422464
chunk_to_subtask[c] = subtask
465+
if self._map_reduce_id_to_infos is not None and isinstance(
466+
chunk.op, ShuffleProxy
467+
):
468+
self._gen_map_reduce_info(chunk, chunk_to_bands)
423469
visited.update(same_color_chunks)
424470

425471
for subtasks in logic_key_to_subtasks.values():

mars/services/task/api/oscar.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ....core import Tileable
1919
from ....lib.aio import alru_cache
2020
from ...subtask import SubtaskResult
21-
from ..core import TileableGraph, TaskResult
21+
from ..core import TileableGraph, TaskResult, MapReduceInfo
2222
from ..supervisor.manager import TaskManagerActor
2323
from .core import AbstractTaskAPI
2424

@@ -104,3 +104,8 @@ async def set_subtask_result(self, subtask_result: SubtaskResult):
104104

105105
async def get_last_idle_time(self) -> Union[float, None]:
106106
return await self._task_manager_ref.get_last_idle_time()
107+
108+
async def get_map_reduce_info(
109+
self, task_id: str, map_reduce_id: int
110+
) -> MapReduceInfo:
111+
return await self._task_manager_ref.get_map_reduce_info(task_id, map_reduce_id)

0 commit comments

Comments
 (0)