1717import sys
1818import time
1919from collections import defaultdict
20- from typing import Any , Dict , List , Optional , Set , Type
20+ from typing import Any , Dict , List , Optional , Set , Type , Tuple
2121
2222from .... import oscar as mo
2323from ....core import ChunkGraph , OperandType , enter_mode , ExecutionError
2424from ....core .context import get_context , set_context
25- from ....core .operand import Fetch , FetchShuffle , execute
25+ from ....core .operand import (
26+ Fetch ,
27+ FetchShuffle ,
28+ execute ,
29+ )
30+ from ....lib .aio import alru_cache
2631from ....metrics import Metrics
2732from ....optimization .physical import optimize
2833from ....typing import BandType , ChunkType
@@ -420,26 +425,56 @@ async def set_chunks_meta():
420425 # set result data size
421426 self .result .data_size = result_data_size
422427
423- async def _push_mapper_data (self , chunk_graph ):
424- # TODO: use task api to get reducer bands
425- reducer_idx_to_band = dict ()
426- if not reducer_idx_to_band :
427- return
@classmethod
@alru_cache(cache_exceptions=False)
async def _gen_reducer_index_to_bands(
    cls, session_id: str, supervisor_address: str, task_id: str, map_reduce_id: int
) -> Dict[Tuple[int, ...], BandType]:
    """Resolve the reducer-index -> band mapping of one map-reduce stage.

    The result is memoized per argument tuple by ``alru_cache`` so that
    multiple mapper chunks of the same shuffle only query the supervisor
    once.

    Parameters
    ----------
    session_id : str
        Session the subtask belongs to.
    supervisor_address : str
        Address used to create the supervisor-side ``TaskAPI``.
    task_id : str
        Task whose map-reduce info is requested.
    map_reduce_id : int
        Identifier of the shuffle (map-reduce) stage.

    Returns
    -------
    Dict[Tuple[int, ...], BandType]
        Mapping from each reducer's index tuple to the band
        ``(address, band_name)`` assigned to execute it.
    """
    task_api = await TaskAPI.create(session_id, supervisor_address)
    map_reduce_info = await task_api.get_map_reduce_info(task_id, map_reduce_id)
    # indexes and bands are parallel sequences produced by the supervisor;
    # a length mismatch indicates an internal bug, hence the assert
    assert len(map_reduce_info.reducer_indexes) == len(
        map_reduce_info.reducer_bands
    )
    return dict(
        zip(map_reduce_info.reducer_indexes, map_reduce_info.reducer_bands)
    )
444+
async def _push_mapper_data(self):
    """Proactively push this subtask's mapper output toward reducer bands.

    For every result chunk produced by a mapper operand — identified by an
    ``analyzer_map_reduce_id`` entry in the operand's ``extra_params`` —
    look up the band assigned to each reducer index and ask that band's
    storage service to fetch the corresponding shuffle data from this band
    ahead of reducer execution.  Chunks without a map-reduce id are not
    shuffle mappers and are skipped.
    """
    storage_api_to_fetch_tasks = defaultdict(list)
    for result_chunk in self._chunk_graph.result_chunks:
        map_reduce_id = getattr(result_chunk.op, "extra_params", dict()).get(
            "analyzer_map_reduce_id"
        )
        if map_reduce_id is None:
            continue
        # cached per (session, supervisor, task, map_reduce_id), so repeated
        # mapper chunks of the same shuffle do not re-query the supervisor
        reducer_index_to_bands = await self._gen_reducer_index_to_bands(
            self._session_id,
            self._supervisor_address,
            self.subtask.task_id,
            map_reduce_id,
        )
        for reducer_index, band in reducer_index_to_bands.items():
            address, band_name = band
            storage_api = await StorageAPI.create(
                self._session_id, address, band_name
            )
            # mapper data is keyed by a (chunk key, reducer index) tuple
            fetch_task = storage_api.fetch.delay(
                (result_chunk.key, reducer_index),
                band_name=self._band[1],
                remote_address=self._band[0],
            )
            storage_api_to_fetch_tasks[storage_api].append(fetch_task)
    # no mapper chunk found (or no reducers assigned): nothing to push
    if not storage_api_to_fetch_tasks:
        return
    batch_tasks = [
        storage_api.fetch.batch(*tasks)
        for storage_api, tasks in storage_api_to_fetch_tasks.items()
    ]
    await asyncio.gather(*batch_tasks)
444479
445480 async def done (self ):
@@ -513,8 +548,6 @@ async def run(self):
513548 await self ._unpin_data (input_keys )
514549
515550 await self .done ()
516- # after done, we push mapper data to reducers in advance.
517- await self .ref ()._push_mapper_data .tell (chunk_graph )
518551 if self .result .status == SubtaskStatus .succeeded :
519552 cost_time_secs = (
520553 self .result .execution_end_time - self .result .execution_start_time
@@ -536,6 +569,9 @@ async def run(self):
536569 pass
537570 return self .result
538571
async def post_run(self):
    """Post-execution hook: push mapper data to reducer bands in advance."""
    await self._push_mapper_data()
574+
539575 async def report_progress_periodically (self , interval = 0.5 , eps = 0.001 ):
540576 last_progress = self .result .progress
541577 while not self .result .status .is_done :
@@ -618,7 +654,7 @@ async def _init_context(self, session_id: str):
618654 await context .init ()
619655 set_context (context )
620656
621- async def run (self , subtask : Subtask ):
657+ async def run (self , subtask : Subtask , wait_post_run : bool = False ):
622658 logger .info ("Start to run subtask: %r on %s." , subtask , self .address )
623659
624660 assert subtask .session_id == self ._session_id
@@ -639,10 +675,18 @@ async def run(self, subtask: Subtask):
639675 try :
640676 result = yield self ._running_aio_task
641677 logger .info ("Finished subtask: %s" , subtask .subtask_id )
678+ # post run with actor tell which will not block
679+ if not wait_post_run :
680+ await self .ref ().post_run .tell (processor )
681+ else :
682+ await self .post_run (processor )
642683 raise mo .Return (result )
643684 finally :
644685 self ._processor = self ._running_aio_task = None
645686
async def post_run(self, processor: SubtaskProcessor):
    """Run the processor's post-run hook (pushes mapper data to reducers)."""
    await processor.post_run()
689+
646690 async def wait (self ):
647691 return self ._processor .is_done .wait ()
648692
0 commit comments