1616import functools
1717import logging
1818import operator
19+ import sys
1920from dataclasses import dataclass
2021from typing import List , Dict , Any , Set , Callable
2122from .....core import ChunkGraph , Chunk , TileContext
4950 ExecutionChunkResult ,
5051 register_executor_cls ,
5152)
52- from .config import RayExecutionConfig
53+ from .config import RayExecutionConfig , IN_RAY_CI
5354from .context import (
5455 RayExecutionContext ,
5556 RayExecutionWorkerContext ,
@@ -314,35 +315,54 @@ async def execute_subtask_graph(
314315 ) -> Dict [Chunk , ExecutionChunkResult ]:
315316 if self ._cancelled is True : # pragma: no cover
316317 raise asyncio .CancelledError ()
318+ logger .info ("Stage %s start." , stage_id )
319+ # Make sure each stage use a clean dict.
320+ self ._cur_stage_first_output_object_ref_to_subtask = dict ()
317321
318- def _on_monitor_task_done (fut ):
322+ def _on_monitor_aiotask_done (fut ):
319323 # Print the error of monitor task.
320324 try :
321325 fut .result ()
322326 except asyncio .CancelledError :
323327 pass
328+ except Exception : # pragma: no cover
329+ logger .exception (
330+ "The monitor task of stage %s is done with exception." , stage_id
331+ )
332+ if IN_RAY_CI : # pragma: no cover
333+ logger .warning (
334+ "The process will be exit due to the monitor task exception "
335+ "when MARS_CI_BACKEND=ray."
336+ )
337+ sys .exit (- 1 )
324338
339+ result_meta_keys = {
340+ chunk .key
341+ for chunk in chunk_graph .result_chunks
342+ if not isinstance (chunk .op , Fetch )
343+ }
325344 # Create a monitor task to update progress and collect garbage.
326- monitor_task = asyncio .create_task (
345+ monitor_aiotask = asyncio .create_task (
327346 self ._update_progress_and_collect_garbage (
328- subtask_graph , self ._config .get_subtask_monitor_interval ()
347+ stage_id ,
348+ subtask_graph ,
349+ result_meta_keys ,
350+ self ._config .get_subtask_monitor_interval (),
329351 )
330352 )
331- monitor_task .add_done_callback (_on_monitor_task_done )
353+ monitor_aiotask .add_done_callback (_on_monitor_aiotask_done )
332354
333- def _on_execute_task_done ( fut ):
355+ def _on_execute_aiotask_done ( _ ):
334356 # Make sure the monitor task is cancelled.
335- monitor_task .cancel ()
357+ monitor_aiotask .cancel ()
336358 # Just use `self._cur_stage_tile_progress` as current stage progress
337359 # because current stage is completed, its progress is 1.0.
338360 self ._cur_stage_progress = 1.0
339361 self ._pre_all_stages_progress += self ._cur_stage_tile_progress
340- self ._cur_stage_first_output_object_ref_to_subtask .clear ()
341362
342363 self ._execute_subtask_graph_aiotask = asyncio .current_task ()
343- self ._execute_subtask_graph_aiotask .add_done_callback (_on_execute_task_done )
364+ self ._execute_subtask_graph_aiotask .add_done_callback (_on_execute_aiotask_done )
344365
345- logger .info ("Stage %s start." , stage_id )
346366 task_context = self ._task_context
347367 output_meta_object_refs = []
348368 self ._pre_all_stages_tile_progress = (
@@ -352,11 +372,6 @@ def _on_execute_task_done(fut):
352372 self ._tile_context .get_all_progress () - self ._pre_all_stages_tile_progress
353373 )
354374 logger .info ("Submitting %s subtasks of stage %s." , len (subtask_graph ), stage_id )
355- result_meta_keys = {
356- chunk .key
357- for chunk in chunk_graph .result_chunks
358- if not isinstance (chunk .op , Fetch )
359- }
360375 subtask_max_retries = self ._config .get_subtask_max_retries ()
361376 for subtask in subtask_graph .topological_iter ():
362377 subtask_chunk_graph = subtask .chunk_graph
@@ -555,7 +570,11 @@ def _get_subtask_output_keys(chunk_graph: ChunkGraph):
555570 return output_keys .keys ()
556571
557572 async def _update_progress_and_collect_garbage (
558- self , subtask_graph : SubtaskGraph , interval_seconds : float
573+ self ,
574+ stage_id : str ,
575+ subtask_graph : SubtaskGraph ,
576+ result_meta_keys : Set [str ],
577+ interval_seconds : float ,
559578 ):
560579 object_ref_to_subtask = self ._cur_stage_first_output_object_ref_to_subtask
561580 total = len (subtask_graph )
@@ -579,7 +598,7 @@ def gc():
579598 # Iterate the completed subtasks once.
580599 subtask = completed_subtasks [i ]
581600 i += 1
582- logger .debug ("GC: %s" , subtask )
601+ logger .debug ("GC[stage=%s] : %s" , stage_id , subtask )
583602
584603 # Note: There may be a scenario in which delayed gc occurs.
585604 # When a subtask has more than one predecessor, like A, B,
@@ -595,15 +614,23 @@ def gc():
595614 ):
596615 yield
597616 for chunk in pred .chunk_graph .results :
598- self ._task_context .pop (chunk .key , None )
617+ chunk_key = chunk .key
618+ # We need to check the GC chunk key is not in the
619+ # result meta keys, because there are some special
620+ # cases that the result meta keys are not the leaves.
621+ #
622+ # example: test_cut_execution
623+ if chunk_key not in result_meta_keys :
624+ logger .debug ("GC[stage=%s]: %s" , stage_id , chunk )
625+ self ._task_context .pop (chunk_key , None )
599626 gc_subtasks .add (pred )
600627
601628 # TODO(fyrestone): Check the remaining self._task_context.keys()
602629 # in the result subtasks
603630
604631 collect_garbage = gc ()
605632
606- while len (completed_subtasks ) != total :
633+ while len (completed_subtasks ) < total :
607634 if len (object_ref_to_subtask ) <= 0 : # pragma: no cover
608635 await asyncio .sleep (interval_seconds )
609636
0 commit comments