2222from .... import oscar as mo
2323from ....lib .aio import alru_cache
2424from ....metrics import Metrics
25- from ....oscar .backends .context import ProfilingContext
2625from ....oscar .errors import MarsError
27- from ....oscar .profiling import ProfilingData , MARS_ENABLE_PROFILING
2826from ....typing import BandType
29- from ....utils import dataslots , Timer
27+ from ....utils import dataslots
3028from ...subtask import Subtask , SubtaskResult , SubtaskStatus
3129from ...task import TaskAPI
3230from ..core import SubtaskScheduleSummary
@@ -127,14 +125,6 @@ async def __post_create__(self):
127125 )
128126 await self ._speculation_execution_scheduler .start ()
129127
async def dump_running():
    """Debug aid that periodically reports the tracked subtasks.

    Loops forever: whenever ``self._subtask_infos`` is non-empty, logs a
    warning containing ``list(self._subtask_infos)`` (the keys, assuming it
    is a mapping -- confirm), then sleeps for 5 seconds.  The coroutine
    never returns on its own.
    """
    while True:
        tracked = self._subtask_infos
        if tracked:
            logger.warning("RUNNING: %r", list(tracked))
        await asyncio.sleep(5)
135-
136- asyncio .create_task (dump_running ())
137-
async def __pre_destroy__(self):
    """Stop the speculation execution scheduler.

    NOTE(review): only this part of the method is visible here; presumably
    this is an actor teardown hook -- confirm against the full file.
    """
    scheduler = self._speculation_execution_scheduler
    await scheduler.stop()
140130
@@ -186,7 +176,7 @@ async def _handle_subtask_result(
186176 self , info : SubtaskScheduleInfo , result : SubtaskResult , band : BandType
187177 ):
188178 subtask_id = info .subtask .subtask_id
189- async with redirect_subtask_errors (self , [info .subtask ]):
179+ async with redirect_subtask_errors (self , [info .subtask ], reraise = False ):
190180 try :
191181 info .band_futures [band ].set_result (result )
192182 if result .error is not None :
@@ -262,9 +252,9 @@ async def finish_subtasks(
262252
263253 if subtask_info is not None :
264254 if subtask_band is not None :
265- logger . warning ( "BEFORE await self._handle_subtask_result(subtask_info, result, subtask_band)" )
266- await self . _handle_subtask_result ( subtask_info , result , subtask_band )
267- logger . warning ( "AFTER await self._handle_subtask_result(subtask_info, result, subtask_band)" )
255+ await self ._handle_subtask_result (
256+ subtask_info , result , subtask_band
257+ )
268258
269259 self ._finished_subtask_count .record (
270260 1 ,
@@ -275,16 +265,15 @@ async def finish_subtasks(
275265 },
276266 )
277267 self ._subtask_summaries [subtask_id ] = subtask_info .to_summary (
278- is_finished = True , is_cancelled = result .status == SubtaskStatus .cancelled
268+ is_finished = True ,
269+ is_cancelled = result .status == SubtaskStatus .cancelled ,
279270 )
280271 subtask_info .end_time = time .time ()
281272 self ._speculation_execution_scheduler .finish_subtask (subtask_info )
282273 # Cancel subtask on other bands.
283274 aio_task = subtask_info .band_futures .pop (subtask_band , None )
284275 if aio_task :
285- logger .warning ("BEFORE await aio_task" )
286276 await aio_task
287- logger .warning ("AFTER await aio_task" )
288277 if schedule_next :
289278 band_tasks [subtask_band ] += 1
290279 if subtask_info .band_futures :
@@ -304,7 +293,6 @@ async def finish_subtasks(
304293 if schedule_next :
305294 for band in subtask_info .band_futures .keys ():
306295 band_tasks [band ] += 1
307- # await self._queueing_ref.remove_queued_subtasks(subtask_ids)
308296 if band_tasks :
309297 await self ._queueing_ref .submit_subtasks .tell (dict (band_tasks ))
310298
@@ -345,7 +333,9 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
345333 band_to_subtask_ids [band ].append (subtask_id )
346334
347335 if res_release_delays :
348- await self ._global_resource_ref .release_subtask_resource .batch (* res_release_delays )
336+ await self ._global_resource_ref .release_subtask_resource .batch (
337+ * res_release_delays
338+ )
349339
350340 for band , subtask_ids in band_to_subtask_ids .items ():
351341 asyncio .create_task (self ._submit_subtasks_to_band (band , subtask_ids ))
@@ -386,29 +376,22 @@ async def cancel_subtasks(
386376 subtask_ids ,
387377 kill_timeout ,
388378 )
389- queued_subtask_ids = []
390- single_cancel_tasks = []
391379
392380 task_api = await self ._get_task_api ()
393381
async def cancel_single_task(subtask, raw_tasks, cancel_tasks):
    """Wait for one subtask's cancellation and report it when appropriate.

    Args:
        subtask: the subtask being cancelled; its ``subtask_id``,
            ``session_id``, ``task_id`` and ``stage_id`` are copied into
            the reported result.
        raw_tasks: awaitables for the subtask's in-flight submissions; may
            be empty.
        cancel_tasks: awaitables for the cancel requests already issued;
            may be empty.

    A ``SubtaskStatus.cancelled`` result is pushed through
    ``task_api.set_subtask_result`` only when no awaited submission
    finished in a non-cancelled state (which includes the case where there
    was nothing to await).
    """
    # First let every issued cancel request complete.
    if cancel_tasks:
        await asyncio.wait(cancel_tasks)

    # Then collect the submissions themselves; nothing to collect when the
    # subtask was never submitted.
    finished = []
    if raw_tasks:
        finished, _ = await asyncio.wait(raw_tasks)

    # If any submission completed without being cancelled, leave result
    # reporting to whatever handles that completion.
    if any(not fut.cancelled() for fut in finished):
        return

    cancelled_result = SubtaskResult(
        subtask_id=subtask.subtask_id,
        session_id=subtask.session_id,
        task_id=subtask.task_id,
        stage_id=subtask.stage_id,
        status=SubtaskStatus.cancelled,
    )
    await task_api.set_subtask_result(cancelled_result)
async def cancel_task_in_band(band):
    """Cancel the subtasks collected for ``band`` and wait for them.

    Sends every cancel call recorded for the band in
    ``band_to_cancel_delays`` as a single ``cancel_subtask.batch`` request,
    then waits for the futures recorded for the band in ``band_to_futures``.

    Args:
        band: key identifying the band in both lookup tables; also passed
            to ``self._get_execution_ref``.
    """
    # `.get` is used so that the lookup itself does not add a new entry.
    pending_cancels = band_to_cancel_delays.get(band) or []
    # The execution ref is resolved even when there is nothing to send,
    # matching the established order of operations.
    band_execution_ref = await self._get_execution_ref(band)
    if pending_cancels:
        await band_execution_ref.cancel_subtask.batch(*pending_cancels)

    futures_in_band = band_to_futures.get(band)
    if futures_in_band:
        # asyncio.wait rejects an empty collection, hence the guard.
        await asyncio.wait(futures_in_band)
411390
391+ queued_subtask_ids = []
392+ cancel_tasks = []
393+ band_to_cancel_delays = defaultdict (list )
394+ band_to_futures = defaultdict (list )
412395 for subtask_id in subtask_ids :
413396 if subtask_id not in self ._subtask_infos :
414397 # subtask may already finished or not submitted at all
@@ -423,35 +406,33 @@ async def cancel_single_task(subtask, raw_tasks, cancel_tasks):
423406 raw_tasks_to_cancel = list (info .band_futures .values ())
424407
425408 if not raw_tasks_to_cancel :
426- queued_subtask_ids .append (subtask_id )
427- single_cancel_tasks .append (
428- asyncio .create_task (
429- cancel_single_task (info .subtask , [], [])
430- )
409+ # not submitted yet: mark subtasks as cancelled
410+ result = SubtaskResult (
411+ subtask_id = info .subtask .subtask_id ,
412+ session_id = info .subtask .session_id ,
413+ task_id = info .subtask .task_id ,
414+ stage_id = info .subtask .stage_id ,
415+ status = SubtaskStatus .cancelled ,
431416 )
417+ cancel_tasks .append (task_api .set_subtask_result (result ))
418+ queued_subtask_ids .append (subtask_id )
432419 else :
433- cancel_tasks = []
434- for band in info .band_futures .keys ():
420+ for band , future in info .band_futures .items ():
435421 execution_ref = await self ._get_execution_ref (band )
436- cancel_tasks .append (
437- asyncio .create_task (
438- execution_ref .cancel_subtask (
439- subtask_id , kill_timeout = kill_timeout
440- )
441- )
422+ band_to_cancel_delays [band ].append (
423+ execution_ref .cancel_subtask .delay (subtask_id , kill_timeout )
442424 )
443- single_cancel_tasks .append (
444- asyncio .create_task (
445- cancel_single_task (
446- info .subtask , raw_tasks_to_cancel , cancel_tasks
447- )
448- )
449- )
425+ band_to_futures [band ].append (future )
426+
427+ for band in band_to_futures :
428+ cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
429+
450430 if queued_subtask_ids :
451431 # Don't use `finish_subtasks` because it may remove queued
452432 await self ._queueing_ref .remove_queued_subtasks (queued_subtask_ids )
453- if single_cancel_tasks :
454- yield asyncio .wait (single_cancel_tasks )
433+
434+ if cancel_tasks :
435+ yield asyncio .gather (* cancel_tasks )
455436
456437 for subtask_id in subtask_ids :
457438 info = self ._subtask_infos .pop (subtask_id , None )
0 commit comments