Skip to content

Commit 96643e9

Browse files
committed
Try fixing issues..?
1 parent af888cc commit 96643e9

File tree

1 file changed

+11
-6
lines changed
  • pydantic_graph/pydantic_graph/beta

1 file changed

+11
-6
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,12 @@ async def iter_graph( # noqa C901
541541
else:
542542
maybe_overridden_result = yield task_result.result
543543
if isinstance(maybe_overridden_result, EndMarker):
544-
self.task_group.cancel_scope.cancel()
544+
# If we got an end marker, this task is definitely done, and we're ready to
545+
# start cleaning everything up
546+
await self._finish_task(task_result.source.task_id)
547+
if self.active_tasks:
548+
# Cancel the remaining tasks
549+
self.task_group.cancel_scope.cancel()
545550
return
546551
elif isinstance(maybe_overridden_result, JoinItem):
547552
result = maybe_overridden_result
@@ -600,11 +605,11 @@ async def iter_graph( # noqa C901
600605
new_task_ids = {t.task_id for t in maybe_overridden_result}
601606
for t in task_result.result:
602607
if t.task_id not in new_task_ids:
603-
await self._finish_task(t.task_id, t.node_id)
608+
await self._finish_task(t.task_id)
604609
self._handle_execution_request(maybe_overridden_result)
605610

606611
if task_result.source_is_finished:
607-
await self._finish_task(task_result.source.task_id, task_result.source.node_id)
612+
await self._finish_task(task_result.source.task_id)
608613

609614
if not self.active_tasks:
610615
# if there are no active tasks, we'll be waiting forever for the next result..
@@ -677,7 +682,7 @@ async def iter_graph( # noqa C901
677682
# Same note as above about how this is theoretically reachable but we should
678683
# just get coverage by unifying the code paths
679684
if t.task_id not in new_task_ids: # pragma: no cover
680-
await self._finish_task(t.task_id, t.node_id)
685+
await self._finish_task(t.task_id)
681686
self._handle_execution_request(maybe_overridden_result)
682687
except GeneratorExit:
683688
self.task_group.cancel_scope.cancel()
@@ -687,7 +692,7 @@ async def iter_graph( # noqa C901
687692
'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
688693
)
689694

690-
async def _finish_task(self, task_id: TaskID, node_id: str) -> None:
695+
async def _finish_task(self, task_id: TaskID) -> None:
691696
# node_id is just included for debugging right now
692697
scope = self.cancel_scopes.pop(task_id, None)
693698
if scope is not None:
@@ -905,7 +910,7 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
905910
else:
906911
pass
907912
for task_id in task_ids_to_cancel:
908-
await self._finish_task(task_id, 'sibling')
913+
await self._finish_task(task_id)
909914

910915

911916
def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:

0 commit comments

Comments
 (0)