-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Fix bug with running graphs in temporal workflows #3460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
7f6c184
Fix bug with running graphs in temporal workflows
dmontagu 8382fe8
Merge branch 'main' into temporal-graph-fix
dmontagu 99afd08
Satisfy lint
DouweM 40b8f9a
Reuse task IDs when possible
dmontagu 856260b
Fix failing test
dmontagu fcb7f0c
Merge branch 'main' into temporal-graph-fix
dmontagu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,7 @@ | |
| from __future__ import annotations as _annotations | ||
|
|
||
| import sys | ||
| import uuid | ||
| from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence | ||
| from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence | ||
| from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager | ||
| from dataclasses import dataclass, field | ||
| from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload | ||
|
|
@@ -22,7 +21,7 @@ | |
| from pydantic_graph import exceptions | ||
| from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span | ||
| from pydantic_graph.beta.decision import Decision | ||
| from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID | ||
| from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID | ||
| from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext | ||
| from pydantic_graph.beta.node import ( | ||
| EndNode, | ||
|
|
@@ -306,14 +305,13 @@ def __str__(self) -> str: | |
|
|
||
|
|
||
| @dataclass | ||
| class GraphTask: | ||
| """A single task representing the execution of a node in the graph. | ||
| class GraphTaskRequest: | ||
| """A request to run a task representing the execution of a node in the graph. | ||
|
|
||
| GraphTask encapsulates all the information needed to execute a specific | ||
| GraphTaskRequest encapsulates all the information needed to execute a specific | ||
| node, including its inputs and the fork context it's executing within. | ||
| """ | ||
|
|
||
| # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself | ||
| node_id: NodeID | ||
| """The ID of the node to execute.""" | ||
|
|
||
|
|
@@ -326,7 +324,29 @@ class GraphTask: | |
| Used by the GraphRun to decide when to proceed through joins. | ||
| """ | ||
|
|
||
| task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False) | ||
|
|
||
| @dataclass | ||
| class GraphTask(GraphTaskRequest): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't you say you were going to make this not be a subclass of the other? |
||
| """A task representing the execution of a node in the graph. | ||
|
|
||
| GraphTask encapsulates all the information needed to execute a specific | ||
| node, including its inputs and the fork context it's executing within, | ||
| and has a unique ID to identify the task within the graph run. | ||
| """ | ||
|
|
||
| node_id: NodeID | ||
| """The ID of the node to execute.""" | ||
|
|
||
| inputs: Any | ||
| """The input data for the node.""" | ||
|
|
||
| fork_stack: ForkStack = field(repr=False) | ||
| """Stack of forks that have been entered. | ||
|
|
||
| Used by the GraphRun to decide when to proceed through joins. | ||
| """ | ||
|
|
||
| task_id: TaskID | ||
| """Unique identifier for this task.""" | ||
|
|
||
|
|
||
|
|
@@ -378,12 +398,20 @@ def __init__( | |
| self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None | ||
| """The next item to be processed.""" | ||
|
|
||
| run_id = GraphRunID(str(uuid.uuid4())) | ||
| initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),) | ||
| self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) | ||
| self._next_task_id = 0 | ||
| self._next_node_run_id = 0 | ||
| initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),) | ||
| self._first_task = GraphTask( | ||
| node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id() | ||
| ) | ||
| self._iterator_task_group = create_task_group() | ||
| self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT]( | ||
| self.graph, self.state, self.deps, self._iterator_task_group | ||
| self.graph, | ||
| self.state, | ||
| self.deps, | ||
| self._iterator_task_group, | ||
| self._get_next_node_run_id, | ||
| self._get_next_task_id, | ||
| ) | ||
| self._iterator = self._iterator_instance.iter_graph(self._first_task) | ||
|
|
||
|
|
@@ -449,7 +477,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]: | |
| return self._next | ||
|
|
||
| async def next( | ||
| self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None | ||
| self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None | ||
| ) -> EndMarker[OutputT] | Sequence[GraphTask]: | ||
| """Advance the graph execution by one step. | ||
|
|
||
|
|
@@ -467,7 +495,10 @@ async def next( | |
| # if `next` is called before the `first_node` has run. | ||
| await anext(self) | ||
| if value is not None: | ||
| self._next = value | ||
| if isinstance(value, EndMarker): | ||
| self._next = value | ||
| else: | ||
| self._next = [GraphTask(gt.node_id, gt.inputs, gt.fork_stack, self._get_next_task_id()) for gt in value] | ||
| return await anext(self) | ||
|
|
||
| @property | ||
|
|
@@ -490,6 +521,16 @@ def output(self) -> OutputT | None: | |
| return self._next.value | ||
| return None | ||
|
|
||
| def _get_next_task_id(self) -> TaskID: | ||
| next_id = TaskID(f'task:{self._next_task_id}') | ||
| self._next_task_id += 1 | ||
| return next_id | ||
|
|
||
| def _get_next_node_run_id(self) -> NodeRunID: | ||
| next_id = NodeRunID(f'task:{self._next_node_run_id}') | ||
| self._next_node_run_id += 1 | ||
| return next_id | ||
|
|
||
|
|
||
| @dataclass | ||
| class _GraphTaskAsyncIterable: | ||
|
|
@@ -510,6 +551,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]): | |
| state: StateT | ||
| deps: DepsT | ||
| task_group: TaskGroup | ||
| get_next_node_run_id: Callable[[], NodeRunID] | ||
| get_next_task_id: Callable[[], TaskID] | ||
|
|
||
| cancel_scopes: dict[TaskID, CancelScope] = field(init=False) | ||
| active_tasks: dict[TaskID, GraphTask] = field(init=False) | ||
|
|
@@ -522,6 +565,7 @@ def __post_init__(self): | |
| self.active_tasks = {} | ||
| self.active_reducers = {} | ||
| self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]() | ||
| self._next_node_run_id = 1 | ||
|
|
||
| async def iter_graph( # noqa C901 | ||
| self, first_task: GraphTask | ||
|
|
@@ -782,12 +826,12 @@ def _handle_node( | |
| fork_stack: ForkStack, | ||
| ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: | ||
| if isinstance(next_node, StepNode): | ||
| return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] | ||
| return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())] | ||
| elif isinstance(next_node, JoinNode): | ||
| return JoinItem(next_node.join.id, next_node.inputs, fork_stack) | ||
| elif isinstance(next_node, BaseNode): | ||
| node_step = NodeStep(next_node.__class__) | ||
| return [GraphTask(node_step.id, next_node, fork_stack)] | ||
| return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())] | ||
| elif isinstance(next_node, End): | ||
| return EndMarker(next_node.data) | ||
| else: | ||
|
|
@@ -821,7 +865,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen | |
| 'These markers should be removed from paths during graph building' | ||
| ) | ||
| if isinstance(item, DestinationMarker): | ||
| return [GraphTask(item.destination_id, inputs, fork_stack)] | ||
| return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())] | ||
| elif isinstance(item, TransformMarker): | ||
| inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs)) | ||
| return self._handle_path(path.next_path, inputs, fork_stack) | ||
|
|
@@ -853,7 +897,7 @@ def _handle_fork_edges( | |
| ) # this should have already been ensured during graph building | ||
|
|
||
| new_tasks: list[GraphTask] = [] | ||
| node_run_id = NodeRunID(str(uuid.uuid4())) | ||
| node_run_id = self.get_next_node_run_id() | ||
| if node.is_map: | ||
| # If the map specifies a downstream join id, eagerly create a join state for it | ||
| if (join_id := node.downstream_join_id) is not None: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this needed to be changed