2626from ....cluster import ClusterAPI
2727from ....core import ActorCallback
2828from ....subtask import Subtask , SubtaskAPI , SubtaskResult , SubtaskStatus
29+ from ...utils import ResultCache
2930from ..queues import SubtaskPrepareQueueActor , SubtaskExecutionQueueActor
3031from ..quota import QuotaActor
3132from ..slotmanager import SlotManagerActor
@@ -49,18 +50,21 @@ class SubtaskExecutionActor(mo.Actor):
4950
5051 _subtask_api : SubtaskAPI
5152 _subtask_preparer : SubtaskPreparer
53+ _subtask_result_cache : ResultCache [SubtaskResult ]
5254
5355 def __init__ (
5456 self ,
5557 subtask_max_retries : int = None ,
5658 enable_kill_slot : bool = True ,
5759 ):
5860 self ._pred_key_mapping_dag = DAG ()
59- self ._subtask_caches = dict ()
60- self ._subtask_executions = dict ()
6161 self ._prepare_queue_ref = None
6262 self ._execution_queue_ref = None
6363
64+ self ._subtask_caches = dict ()
65+ self ._subtask_executions = dict ()
66+ self ._subtask_result_cache = ResultCache ()
67+
6468 self ._subtask_max_retries = subtask_max_retries or DEFAULT_SUBTASK_MAX_RETRIES
6569 self ._enable_kill_slot = enable_kill_slot
6670
@@ -222,6 +226,7 @@ async def submit_subtasks(
222226 priorities : List [Tuple ],
223227 supervisor_address : str ,
224228 band_name : str ,
229+ reschedule : bool = False ,
225230 ):
226231 assert len (subtasks ) == len (priorities )
227232 logger .debug ("%d subtasks submitted to SubtaskExecutionActor" , len (subtasks ))
@@ -230,12 +235,15 @@ async def submit_subtasks(
230235 for subtask , priority in zip (subtasks , priorities ):
231236 if isinstance (subtask , str ):
232237 try :
233- subtask = self ._subtask_caches [subtask ].subtask
234- except KeyError :
235238 subtask = self ._subtask_executions [subtask ].subtask
239+ except KeyError :
240+ subtask = self ._subtask_caches [subtask ].subtask
236241 try :
237- info = self ._subtask_executions [subtask .subtask_id ]
238- if info .result .status not in (
242+ if subtask .subtask_id in self ._subtask_executions :
243+ result = self ._subtask_executions [subtask .subtask_id ].result
244+ else :
245+ result = self ._subtask_result_cache [subtask .subtask_id ]
246+ if result .status not in (
239247 SubtaskStatus .cancelled ,
240248 SubtaskStatus .errored ,
241249 ):
@@ -249,7 +257,6 @@ async def submit_subtasks(
249257 supervisor_address = supervisor_address ,
250258 band_name = band_name ,
251259 )
252- self ._subtask_caches .pop (subtask .subtask_id , None )
253260 self ._subtask_executions [subtask .subtask_id ] = subtask_info
254261 put_delays .append (
255262 self ._prepare_queue_ref .put .delay (
@@ -322,7 +329,7 @@ async def cancel_subtasks(
322329 continue
323330 if not subtask_info .result .status .is_done :
324331 self ._fill_result_with_exc (subtask_info , exc_cls = asyncio .CancelledError )
325- infos_to_report .append (subtask_info )
332+ infos_to_report .append (subtask_info )
326333 await self ._report_subtask_results (infos_to_report )
327334
328335 async def wait_subtasks (self , subtask_ids : List [str ]):
@@ -488,6 +495,7 @@ async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo):
488495 subtask_info ,
489496 max_retries = subtask_info .max_retries if subtask .retryable else 0 ,
490497 )
498+ self ._subtask_result_cache [subtask .subtask_id ] = subtask_info .result
491499 except Exception as ex : # noqa: E722 # nosec # pylint: disable=bare-except
492500 if not subtask .retryable :
493501 unretryable_op = [
@@ -654,6 +662,5 @@ async def _forward_subtask_info(self, subtask_info: SubtaskExecutionInfo):
654662 await self ._execution_queue_ref .put (
655663 subtask_id , subtask_info .band_name , subtask_info .priority
656664 )
657- self .uncache_subtasks ([subtask_id ])
658665 except PrepareFastFailed :
659666 self ._subtask_executions .pop (subtask_id , None )
0 commit comments