diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 6179a31aaf..0ed3e2455d 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, overload from pydantic_graph import BaseNode, End, GraphRunContext -from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem +from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem from pydantic_graph.beta.step import NodeStep from . import ( @@ -181,7 +181,7 @@ async def __anext__( return self._task_to_node(task) def _task_to_node( - self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask] + self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest] ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]: if isinstance(task, Sequence) and len(task) == 1: first_task = task[0] @@ -197,8 +197,8 @@ def _task_to_node( return End(task.value) raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover - def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask: - return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=()) + def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest: + return GraphTaskRequest(NodeStep(type(node)).id, inputs=node, fork_stack=()) async def next( self, diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py index a53705a369..f6aeab9a34 100644 --- a/pydantic_graph/pydantic_graph/beta/graph.py +++ b/pydantic_graph/pydantic_graph/beta/graph.py @@ -8,8 +8,7 @@ from __future__ import annotations as _annotations import sys -import uuid -from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload @@ -22,7 +21,7 @@ from pydantic_graph import exceptions from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span from pydantic_graph.beta.decision import Decision -from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID +from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext from pydantic_graph.beta.node import ( EndNode, @@ -306,14 +305,13 @@ def __str__(self) -> str: @dataclass -class GraphTask: - """A single task representing the execution of a node in the graph. +class GraphTaskRequest: + """A request to run a task representing the execution of a node in the graph. - GraphTask encapsulates all the information needed to execute a specific + GraphTaskRequest encapsulates all the information needed to execute a specific node, including its inputs and the fork context it's executing within. """ - # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself node_id: NodeID """The ID of the node to execute.""" @@ -326,9 +324,26 @@ class GraphTask: Used by the GraphRun to decide when to proceed through joins. """ - task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False) + +@dataclass +class GraphTask(GraphTaskRequest): + """A task representing the execution of a node in the graph. + + GraphTask encapsulates all the information needed to execute a specific + node, including its inputs and the fork context it's executing within, + and has a unique ID to identify the task within the graph run. + """ + + task_id: TaskID = field(repr=False) """Unique identifier for this task.""" + @staticmethod + def from_request(request: GraphTaskRequest, get_task_id: Callable[[], TaskID]) -> GraphTask: + # Don't call the get_task_id callable, this is already a task + if isinstance(request, GraphTask): + return request + return GraphTask(request.node_id, request.inputs, request.fork_stack, get_task_id()) + class GraphRun(Generic[StateT, DepsT, OutputT]): """A single execution instance of a graph. @@ -378,12 +393,20 @@ def __init__( self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None """The next item to be processed.""" - run_id = GraphRunID(str(uuid.uuid4())) - initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),) - self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) + self._next_task_id = 0 + self._next_node_run_id = 0 + initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),) + self._first_task = GraphTask( + node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id() + ) self._iterator_task_group = create_task_group() self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT]( - self.graph, self.state, self.deps, self._iterator_task_group + self.graph, + self.state, + self.deps, + self._iterator_task_group, + self._get_next_node_run_id, + self._get_next_task_id, ) self._iterator = self._iterator_instance.iter_graph(self._first_task) @@ -449,7 +472,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]: return self._next async def next( - self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None + self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None ) -> EndMarker[OutputT] | Sequence[GraphTask]: """Advance the graph execution by one step. @@ -467,7 +490,10 @@ async def next( # if `next` is called before the `first_node` has run. await anext(self) if value is not None: - self._next = value + if isinstance(value, EndMarker): + self._next = value + else: + self._next = [GraphTask.from_request(gtr, self._get_next_task_id) for gtr in value] return await anext(self) @property @@ -490,6 +516,16 @@ def output(self) -> OutputT | None: return self._next.value return None + def _get_next_task_id(self) -> TaskID: + next_id = TaskID(f'task:{self._next_task_id}') + self._next_task_id += 1 + return next_id + + def _get_next_node_run_id(self) -> NodeRunID: + next_id = NodeRunID(f'task:{self._next_node_run_id}') + self._next_node_run_id += 1 + return next_id + @dataclass class _GraphTaskAsyncIterable: @@ -510,6 +546,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]): state: StateT deps: DepsT task_group: TaskGroup + get_next_node_run_id: Callable[[], NodeRunID] + get_next_task_id: Callable[[], TaskID] cancel_scopes: dict[TaskID, CancelScope] = field(init=False) active_tasks: dict[TaskID, GraphTask] = field(init=False) @@ -522,6 +560,7 @@ def __post_init__(self): self.active_tasks = {} self.active_reducers = {} self.iter_stream_sender, self.iter_stream_receiver = create_memory_object_stream[_GraphTaskResult]() + self._next_node_run_id = 1 async def iter_graph( # noqa C901 self, first_task: GraphTask @@ -782,12 +821,12 @@ def _handle_node( fork_stack: ForkStack, ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]: if isinstance(next_node, StepNode): - return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)] + return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())] elif isinstance(next_node, JoinNode): return JoinItem(next_node.join.id, next_node.inputs, fork_stack) elif isinstance(next_node, BaseNode): node_step = NodeStep(next_node.__class__) - return [GraphTask(node_step.id, next_node, fork_stack)] + return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())] elif isinstance(next_node, End): return EndMarker(next_node.data) else: @@ -821,7 +860,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen 'These markers should be removed from paths during graph building' ) if isinstance(item, DestinationMarker): - return [GraphTask(item.destination_id, inputs, fork_stack)] + return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())] elif isinstance(item, TransformMarker): inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs)) return self._handle_path(path.next_path, inputs, fork_stack) @@ -853,7 +892,7 @@ def _handle_fork_edges( ) # this should have already been ensured during graph building new_tasks: list[GraphTask] = [] - node_run_id = NodeRunID(str(uuid.uuid4())) + node_run_id = self.get_next_node_run_id() if node.is_map: # If the map specifies a downstream join id, eagerly create a join state for it if (join_id := node.downstream_join_id) is not None: diff --git a/pydantic_graph/pydantic_graph/beta/id_types.py b/pydantic_graph/pydantic_graph/beta/id_types.py index d777a1846e..9ec9056ed4 100644 --- a/pydantic_graph/pydantic_graph/beta/id_types.py +++ b/pydantic_graph/pydantic_graph/beta/id_types.py @@ -24,9 +24,6 @@ ForkID = NodeID """Alias for NodeId when referring to fork nodes.""" -GraphRunID = NewType('GraphRunID', str) -"""Unique identifier for a complete graph execution run.""" - TaskID = NewType('TaskID', str) """Unique identifier for a task within the graph execution.""" diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py index dcb10267ce..e1bf25fe57 100644 --- a/tests/graph/beta/test_graph_iteration.py +++ b/tests/graph/beta/test_graph_iteration.py @@ -8,7 +8,7 @@ import pytest from pydantic_graph.beta import GraphBuilder, StepContext -from pydantic_graph.beta.graph import EndMarker, GraphTask +from pydantic_graph.beta.graph import EndMarker, GraphTask, GraphTaskRequest from pydantic_graph.beta.id_types import NodeID from pydantic_graph.beta.join import reduce_list_append @@ -400,7 +400,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int: # Get the fork_stack from the EndMarker's source fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else () - new_task = GraphTask( + new_task = GraphTaskRequest( node_id=NodeID('second_step'), inputs=event.value, fork_stack=fork_stack, diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 938e123077..9a315a82b9 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -49,6 +49,8 @@ from pydantic_ai.run import AgentRunResult from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition from pydantic_ai.usage import RequestUsage +from pydantic_graph.beta import GraphBuilder, StepContext +from pydantic_graph.beta.join import reduce_list_append try: from temporalio import workflow @@ -2228,3 +2230,108 @@ async def test_fastmcp_toolset(allow_model_requests: None, client: Client): assert output == snapshot( 'The `pydantic/pydantic-ai` repository is a Python agent framework crafted for developing production-grade Generative AI applications. It emphasizes type safety, model-agnostic design, and extensibility. The framework supports various LLM providers, manages agent workflows using graph-based execution, and ensures structured, reliable LLM outputs. Key packages include core framework components, graph execution engines, evaluation tools, and example applications.' ) + + +# ============================================================================ +# Beta Graph API Tests - Tests for running pydantic-graph beta API in Temporal +# ============================================================================ + + +@dataclass +class GraphState: + """State for the graph execution test.""" + + values: list[int] = field(default_factory=list) + + +# Create a graph with parallel execution using the beta API +graph_builder = GraphBuilder( + name='parallel_test_graph', + state_type=GraphState, + input_type=int, + output_type=list[int], +) + + +@graph_builder.step +async def source(ctx: StepContext[GraphState, None, int]) -> int: + """Source step that passes through the input value.""" + return ctx.inputs + + +@graph_builder.step +async def multiply_by_two(ctx: StepContext[GraphState, None, int]) -> int: + """Multiply input by 2.""" + return ctx.inputs * 2 + + +@graph_builder.step +async def multiply_by_three(ctx: StepContext[GraphState, None, int]) -> int: + """Multiply input by 3.""" + return ctx.inputs * 3 + + +@graph_builder.step +async def multiply_by_four(ctx: StepContext[GraphState, None, int]) -> int: + """Multiply input by 4.""" + return ctx.inputs * 4 + + +# Create a join to collect results +result_collector = graph_builder.join(reduce_list_append, initial_factory=list[int]) + +# Build the graph with parallel edges (broadcast pattern) +graph_builder.add( + graph_builder.edge_from(graph_builder.start_node).to(source), + # Broadcast: send value to all three parallel steps + graph_builder.edge_from(source).to(multiply_by_two, multiply_by_three, multiply_by_four), + # Collect all results + graph_builder.edge_from(multiply_by_two, multiply_by_three, multiply_by_four).to(result_collector), + graph_builder.edge_from(result_collector).to(graph_builder.end_node), +) + +parallel_test_graph = graph_builder.build() + + +@workflow.defn +class ParallelGraphWorkflow: + """Workflow that executes a graph with parallel task execution.""" + + @workflow.run + async def run(self, input_value: int) -> list[int]: + """Run the parallel graph workflow. + + Args: + input_value: The input number to process + + Returns: + List of results from parallel execution + """ + result = await parallel_test_graph.run( + state=GraphState(), + inputs=input_value, + ) + return result + + +async def test_beta_graph_parallel_execution_in_workflow(client: Client): + """Test that beta graph API with parallel execution works in Temporal workflows. + + This test verifies the fix for the bug where parallel task execution in graphs + wasn't working properly with Temporal workflows due to GraphTask/GraphTaskRequest + serialization issues. + """ + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[ParallelGraphWorkflow], + ): + output = await client.execute_workflow( + ParallelGraphWorkflow.run, + args=[10], + id=ParallelGraphWorkflow.__name__, + task_queue=TASK_QUEUE, + ) + # Results can be in any order due to parallel execution + # 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40 + assert sorted(output) == [20, 30, 40]