@@ -244,16 +244,14 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):
244244
245245 async def finish_subtasks (
246246 self ,
247- subtask_results : List [SubtaskResult ],
247+ subtask_ids : List [str ],
248248 bands : List [BandType ] = None ,
249249 schedule_next : bool = True ,
250250 ):
251- subtask_ids = [result .subtask_id for result in subtask_results ]
252251 logger .debug ("Finished subtasks %s." , subtask_ids )
253252 band_tasks = defaultdict (lambda : 0 )
254253 bands = bands or [None ] * len (subtask_ids )
255- for result , subtask_band in zip (subtask_results , bands ):
256- subtask_id = result .subtask_id
254+ for subtask_id , subtask_band in zip (subtask_ids , bands ):
257255 subtask_info = self ._subtask_infos .get (subtask_id , None )
258256
259257 if subtask_info is not None :
@@ -265,13 +263,16 @@ async def finish_subtasks(
265263 "stage_id" : subtask_info .subtask .stage_id ,
266264 },
267265 )
268- self ._subtask_summaries [subtask_id ] = subtask_info .to_summary (
269- is_finished = True ,
270- is_cancelled = result .status == SubtaskStatus .cancelled ,
271- )
266+ if subtask_id not in self ._subtask_summaries :
267+ summary_kw = dict (is_finished = True )
268+ if subtask_info .cancel_pending :
269+ summary_kw ["is_cancelled" ] = True
270+ self ._subtask_summaries [subtask_id ] = subtask_info .to_summary (
271+ ** summary_kw
272+ )
272273 subtask_info .end_time = time .time ()
273274 self ._speculation_execution_scheduler .finish_subtask (subtask_info )
274- # Cancel subtask on other bands.
275+ # Cancel subtask on other bands.
275276 aio_task = subtask_info .band_futures .pop (subtask_band , None )
276277 if aio_task :
277278 yield aio_task
@@ -321,7 +322,7 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
321322 if info .cancel_pending :
322323 res_release_delays .append (
323324 self ._global_resource_ref .release_subtask_resource .delay (
324- band , info . subtask . session_id , info . subtask . subtask_id
325+ band , self . _session_id , subtask_id
325326 )
326327 )
327328 continue
@@ -330,6 +331,12 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
330331 "Subtask %s is not in added subtasks set, it may be finished or canceled, skip it." ,
331332 subtask_id ,
332333 )
334+ # in case resource already allocated, do deallocate
335+ res_release_delays .append (
336+ self ._global_resource_ref .release_subtask_resource .delay (
337+ band , self ._session_id , subtask_id
338+ )
339+ )
333340 continue
334341 band_to_subtask_ids [band ].append (subtask_id )
335342
@@ -414,9 +421,8 @@ async def cancel_task_in_band(band):
414421
415422 info = self ._subtask_infos [subtask_id ]
416423 info .cancel_pending = True
417- raw_tasks_to_cancel = list (info .band_futures .values ())
418424
419- if not raw_tasks_to_cancel :
425+ if not info . band_futures :
420426 # not submitted yet: mark subtasks as cancelled
421427 result = SubtaskResult (
422428 subtask_id = info .subtask .subtask_id ,
@@ -435,13 +441,13 @@ async def cancel_task_in_band(band):
435441 )
436442 band_to_futures [band ].append (future )
437443
438- for band in band_to_futures :
439- cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
440-
444+ # Dequeue first as it is possible to leak subtasks from queues
441445 if queued_subtask_ids :
442- # Don't use `finish_subtasks` because it may remove queued
443446 await self ._queueing_ref .remove_queued_subtasks (queued_subtask_ids )
444447
448+ for band in band_to_futures :
449+ cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
450+
445451 if cancel_tasks :
446452 yield asyncio .gather (* cancel_tasks )
447453
0 commit comments