@@ -244,16 +244,14 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):
244244
245245 async def finish_subtasks (
246246 self ,
247- subtask_results : List [SubtaskResult ],
247+ subtask_ids : List [str ],
248248 bands : List [BandType ] = None ,
249249 schedule_next : bool = True ,
250250 ):
251- subtask_ids = [result .subtask_id for result in subtask_results ]
252251 logger .debug ("Finished subtasks %s." , subtask_ids )
253252 band_tasks = defaultdict (lambda : 0 )
254253 bands = bands or [None ] * len (subtask_ids )
255- for result , subtask_band in zip (subtask_results , bands ):
256- subtask_id = result .subtask_id
254+ for subtask_id , subtask_band in zip (subtask_ids , bands ):
257255 subtask_info = self ._subtask_infos .get (subtask_id , None )
258256
259257 if subtask_info is not None :
@@ -265,13 +263,13 @@ async def finish_subtasks(
265263 "stage_id" : subtask_info .subtask .stage_id ,
266264 },
267265 )
268- self . _subtask_summaries [ subtask_id ] = subtask_info . to_summary (
269- is_finished = True ,
270- is_cancelled = result . status == SubtaskStatus . cancelled ,
271- )
266+ if subtask_id not in self . _subtask_summaries :
267+ self . _subtask_summaries [ subtask_id ] = subtask_info . to_summary (
268+ is_finished = True ,
269+ )
272270 subtask_info .end_time = time .time ()
273271 self ._speculation_execution_scheduler .finish_subtask (subtask_info )
274- # Cancel subtask on other bands.
272+ # Cancel subtask on other bands.
275273 aio_task = subtask_info .band_futures .pop (subtask_band , None )
276274 if aio_task :
277275 yield aio_task
@@ -414,9 +412,8 @@ async def cancel_task_in_band(band):
414412
415413 info = self ._subtask_infos [subtask_id ]
416414 info .cancel_pending = True
417- raw_tasks_to_cancel = list (info .band_futures .values ())
418415
419- if not raw_tasks_to_cancel :
416+ if not info . band_futures :
420417 # not submitted yet: mark subtasks as cancelled
421418 result = SubtaskResult (
422419 subtask_id = info .subtask .subtask_id ,
@@ -435,13 +432,13 @@ async def cancel_task_in_band(band):
435432 )
436433 band_to_futures [band ].append (future )
437434
438- for band in band_to_futures :
439- cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
440-
435+ # Dequeue first as it is possible to leak subtasks from queues
441436 if queued_subtask_ids :
442- # Don't use `finish_subtasks` because it may remove queued
443437 await self ._queueing_ref .remove_queued_subtasks (queued_subtask_ids )
444438
439+ for band in band_to_futures :
440+ cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
441+
445442 if cancel_tasks :
446443 yield asyncio .gather (* cancel_tasks )
447444
0 commit comments