@@ -541,7 +541,12 @@ async def iter_graph( # noqa C901
541541 else :
542542 maybe_overridden_result = yield task_result .result
543543 if isinstance (maybe_overridden_result , EndMarker ):
544- self .task_group .cancel_scope .cancel ()
544+ # If we got an end marker, this task is definitely done, and we're ready to
545+ # start cleaning everything up
546+ await self ._finish_task (task_result .source .task_id )
547+ if self .active_tasks :
548+ # Cancel the remaining tasks
549+ self .task_group .cancel_scope .cancel ()
545550 return
546551 elif isinstance (maybe_overridden_result , JoinItem ):
547552 result = maybe_overridden_result
@@ -600,11 +605,11 @@ async def iter_graph( # noqa C901
600605 new_task_ids = {t .task_id for t in maybe_overridden_result }
601606 for t in task_result .result :
602607 if t .task_id not in new_task_ids :
603- await self ._finish_task (t .task_id , t . node_id )
608+ await self ._finish_task (t .task_id )
604609 self ._handle_execution_request (maybe_overridden_result )
605610
606611 if task_result .source_is_finished :
607- await self ._finish_task (task_result .source .task_id , task_result . source . node_id )
612+ await self ._finish_task (task_result .source .task_id )
608613
609614 if not self .active_tasks :
610615 # if there are no active tasks, we'll be waiting forever for the next result..
@@ -677,7 +682,7 @@ async def iter_graph( # noqa C901
677682 # Same note as above about how this is theoretically reachable but we should
678683 # just get coverage by unifying the code paths
679684 if t .task_id not in new_task_ids : # pragma: no cover
680- await self ._finish_task (t .task_id , t . node_id )
685+ await self ._finish_task (t .task_id )
681686 self ._handle_execution_request (maybe_overridden_result )
682687 except GeneratorExit :
683688 self .task_group .cancel_scope .cancel ()
@@ -687,7 +692,7 @@ async def iter_graph( # noqa C901
687692 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
688693 )
689694
690- async def _finish_task (self , task_id : TaskID , node_id : str ) -> None :
695+ async def _finish_task (self , task_id : TaskID ) -> None :
691696 # node_id is just included for debugging right now
692697 scope = self .cancel_scopes .pop (task_id , None )
693698 if scope is not None :
@@ -905,7 +910,7 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
905910 else :
906911 pass
907912 for task_id in task_ids_to_cancel :
908- await self ._finish_task (task_id , 'sibling' )
913+ await self ._finish_task (task_id )
909914
910915
911916def _is_any_iterable (x : Any ) -> TypeGuard [Iterable [Any ]]:
0 commit comments