diff --git a/examples/pydantic_ai_examples/temporal_graph.py b/examples/pydantic_ai_examples/temporal_graph.py
new file mode 100644
index 0000000000..4f80e138b6
--- /dev/null
+++ b/examples/pydantic_ai_examples/temporal_graph.py
@@ -0,0 +1,321 @@
+"""Example demonstrating pydantic-graph integration with Temporal workflows.
+
+This example shows how pydantic-graph graphs "just work" inside Temporal workflows,
+with TemporalAgent handling model requests and tool calls as durable activities.
+
+The example implements a research workflow that:
+1. Breaks down a complex question into simpler sub-questions
+2. Researches each sub-question in parallel
+3. Synthesizes the results into a final answer
+
+To run this example:
+1. Start Temporal server locally:
+    ```sh
+    brew install temporal
+    temporal server start-dev
+    ```
+
+2. Run this script:
+    ```sh
+    uv run python examples/pydantic_ai_examples/temporal_graph.py
+    ```
+"""
+
+from __future__ import annotations
+
+import asyncio
+import uuid
+from dataclasses import dataclass
+
+from pydantic import BaseModel
+from temporalio import workflow
+from temporalio.client import Client
+from temporalio.worker import Worker
+
+from pydantic_ai import Agent
+from pydantic_ai.durable_exec.temporal import (
+    AgentPlugin,
+    PydanticAIPlugin,
+    TemporalAgent,
+)
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append
+
+# ============================================================================
+# State and Dependencies
+# ============================================================================
+
+
+@dataclass
+class ResearchState:
+    """State that flows through the research graph."""
+
+    original_question: str
+    sub_questions: list[str] | None = None
+    sub_answers: list[str] | None = None
+    final_answer: str | None = None
+
+
+@dataclass
+class ResearchDeps:
+    """Dependencies for the research workflow (must be serializable for Temporal)."""
+
+    max_sub_questions: int = 3
+
+
+# ============================================================================
+# Output Models
+# ============================================================================
+
+
+class SubQuestions(BaseModel):
+    """Model for breaking down a question into sub-questions."""
+
+    sub_questions: list[str]
+
+
+class Answer(BaseModel):
+    """Model for a research answer."""
+
+    answer: str
+    confidence: float
+
+
+# ============================================================================
+# Agents
+# ============================================================================
+
+# Agent that breaks down complex questions into simpler sub-questions
+question_breaker_agent = Agent(
+    'openai:gpt-5-mini',
+    name='question_breaker',
+    instructions=(
+        'You are an expert at breaking down complex questions into simpler, '
+        'more focused sub-questions that can be researched independently. '
+        'Create questions that cover different aspects of the original question.'
+    ),
+    output_type=SubQuestions,
+)
+
+# Agent that researches individual questions
+researcher_agent = Agent(
+    'openai:gpt-5-mini',
+    name='researcher',
+    instructions=(
+        'You are a research assistant. Provide clear, accurate, and concise answers '
+        'to questions based on your knowledge. Include a confidence level in your response.'
+    ),
+    output_type=Answer,
+)
+
+# Agent that synthesizes multiple answers into a comprehensive final answer
+synthesizer_agent = Agent(
+    'openai:gpt-5-mini',
+    name='synthesizer',
+    instructions=(
+        'You are an expert at synthesizing multiple pieces of information into '
+        'a coherent, comprehensive answer. Combine the provided answers while '
+        'maintaining accuracy and clarity.'
+    ),
+)
+
+# Wrap all agents with TemporalAgent for durable execution
+temporal_question_breaker = TemporalAgent(question_breaker_agent)
+temporal_researcher = TemporalAgent(researcher_agent)
+temporal_synthesizer = TemporalAgent(synthesizer_agent)
+
+
+# ============================================================================
+# Graph Definition using Beta API
+# ============================================================================
+
+# Create the graph builder
+g = GraphBuilder(
+    name='research_workflow',
+    state_type=ResearchState,
+    deps_type=ResearchDeps,
+    input_type=str,  # Takes a question string as input
+    output_type=str,  # Returns the final answer as a string
+    auto_instrument=True,
+)
+
+
+# Step 1: Break down the question into sub-questions
+@g.step(node_id='break_down_question', label='Break Down Question')
+async def break_down_question(
+    ctx: StepContext[ResearchState, ResearchDeps, str],
+) -> ResearchState:
+    """Break down the original question into sub-questions using an agent."""
+    question = ctx.inputs
+
+    # Use the TemporalAgent to break down the question
+    result = await temporal_question_breaker.run(
+        f'Break down this question into {ctx.deps.max_sub_questions} simpler sub-questions: {question}',
+    )
+
+    # Update state with sub-questions
+    return ResearchState(
+        original_question=question,
+        sub_questions=result.output.sub_questions,
+    )
+
+
+# Step 2: Research each sub-question (will run in parallel via map)
+@g.step(node_id='research_sub_question', label='Research Sub-Question')
+async def research_sub_question(
+    ctx: StepContext[ResearchState, ResearchDeps, str],
+) -> str:
+    """Research a single sub-question using an agent."""
+    sub_question = ctx.inputs
+
+    # Use the TemporalAgent to research the sub-question
+    result = await temporal_researcher.run(sub_question)
+
+    # Return the answer as a formatted string
+    return f'**Q: {sub_question}**\nA: {result.output.answer} (Confidence: {result.output.confidence:.0%})'
+
+
+# Step 3: Join all research results
+research_join = g.join(
+    reducer=reduce_list_append,
+    initial_factory=list[str],
+)
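+
+# A sketch of what the join computes, for orientation only (`branch_outputs` is
+# illustrative, not part of the API): each mapped research_sub_question branch
+# yields one formatted string, and reduce_list_append folds those strings into
+# a single list[str] that is handed to the next step, roughly:
+#
+#     results: list[str] = []
+#     for answer in branch_outputs:  # one str per sub-question
+#         results.append(answer)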
+
+
+# Step 4: Synthesize all answers into a final answer
+@g.step(node_id='synthesize_answer', label='Synthesize Answer')
+async def synthesize_answer(
+    ctx: StepContext[ResearchState, ResearchDeps, list[str]],
+) -> ResearchState:
+    """Synthesize all research results into a final comprehensive answer."""
+    research_results = ctx.inputs
+
+    # Format the research results for the synthesizer
+    research_summary = '\n\n'.join(research_results)
+
+    # Use the TemporalAgent to synthesize the final answer
+    result = await temporal_synthesizer.run(
+        f'Original question: {ctx.state.original_question}\n\n'
+        f'Research findings:\n{research_summary}\n\n'
+        'Please synthesize these findings into a comprehensive answer to the original question.',
+    )
+
+    # Update state with the final answer
+    state = ctx.state
+    state.sub_answers = research_results
+    state.final_answer = result.output
+
+    return state
+
+
+# Build the graph with edges
+g.add(
+    # Start -> Break down question
+    g.edge_from(g.start_node).to(break_down_question),
+    # Break down -> Map over sub-questions for parallel research
+    g.edge_from(break_down_question)
+    .transform(lambda ctx: ctx.inputs.sub_questions or [])
+    .map()
+    .to(research_sub_question),
+    # Research results -> Join
+    g.edge_from(research_sub_question).to(research_join),
+    # Join -> Synthesize
+    g.edge_from(research_join).to(synthesize_answer),
+    # Synthesize -> End
+    g.edge_from(synthesize_answer)
+    .transform(lambda ctx: ctx.inputs.final_answer or '')
+    .to(g.end_node),
+)
+
+# Build the final graph
+research_graph = g.build()
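+
+# The wiring above produces this topology (one research task per sub-question,
+# executed in parallel):
+#
+#     Start -> break_down_question -> (map) -> research_sub_question x N
+#           -> research_join -> synthesize_answer -> End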
+
+
+# ============================================================================
+# Temporal Workflow
+# ============================================================================
+
+
+@workflow.defn
+class ResearchWorkflow:
+    """Temporal workflow that executes the research graph with durable execution."""
+
+    @workflow.run
+    async def run(self, question: str, deps: ResearchDeps | None = None) -> str:
+        """Run the research workflow on a question.
+
+        Args:
+            question: The question to research
+            deps: Optional dependencies for the workflow
+
+        Returns:
+            The final synthesized answer
+        """
+        if deps is None:
+            deps = ResearchDeps()
+
+        # Execute the pydantic-graph graph - it "just works" in Temporal!
+        result = await research_graph.run(
+            state=ResearchState(original_question=question),
+            deps=deps,
+            inputs=question,
+        )
+
+        return result
+
+
+# ============================================================================
+# Main Execution
+# ============================================================================
+
+
+async def main():
+    """Main function to set up worker and execute the workflow."""
+    # Connect to Temporal server
+    client = await Client.connect(
+        'localhost:7233',
+        plugins=[PydanticAIPlugin()],
+    )
+
+    # Create a worker that will execute workflows and activities
+    async with Worker(
+        client,
+        task_queue='research',
+        workflows=[ResearchWorkflow],
+        plugins=[
+            # Register activities for all three temporal agents
+            AgentPlugin(temporal_question_breaker),
+            AgentPlugin(temporal_researcher),
+            AgentPlugin(temporal_synthesizer),
+        ],
+    ):
+        # Execute the workflow
+        question = 'What are the key factors that contributed to the success of the Apollo 11 moon landing?'
+
+        print(f'\n{"=" * 80}')
+        print(f'Research Question: {question}')
+        print(f'{"=" * 80}\n')
+
+        output = await client.execute_workflow(  # pyright: ignore[reportUnknownMemberType]
+            ResearchWorkflow.run,
+            args=[question],
+            id=f'research-{uuid.uuid4()}',
+            task_queue='research',
+        )
+
+        print(f'\n{"=" * 80}')
+        print('Final Answer:')
+        print(f'{"=" * 80}\n')
+        print(output)
+        print(f'\n{"=" * 80}\n')
+
+
+if __name__ == '__main__':
+    import logfire
+
+    logfire.configure(send_to_logfire=False)
+    logfire.instrument_pydantic_ai()
+
+    asyncio.run(main())
diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py
index 6179a31aaf..0ed3e2455d 100644
--- a/pydantic_ai_slim/pydantic_ai/run.py
+++ b/pydantic_ai_slim/pydantic_ai/run.py
@@ -7,7 +7,7 @@
 from typing import TYPE_CHECKING, Any, Generic, Literal, overload
 
 from pydantic_graph import BaseNode, End, GraphRunContext
-from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
+from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem
 from pydantic_graph.beta.step import NodeStep
 
 from . import (
@@ -181,7 +181,7 @@ async def __anext__(
         return self._task_to_node(task)
 
     def _task_to_node(
-        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
+        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest]
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         if isinstance(task, Sequence) and len(task) == 1:
             first_task = task[0]
@@ -197,8 +197,8 @@ def _task_to_node(
             return End(task.value)
         raise exceptions.AgentRunError(f'Unexpected node: {task}')  # pragma: no cover
 
-    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
-        return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
+    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest:
+        return GraphTaskRequest(NodeStep(type(node)).id, inputs=node, fork_stack=())
 
     async def next(
         self,
diff --git a/pydantic_graph/pydantic_graph/beta/graph.py b/pydantic_graph/pydantic_graph/beta/graph.py
index a53705a369..5f83463f34 100644
--- a/pydantic_graph/pydantic_graph/beta/graph.py
+++ b/pydantic_graph/pydantic_graph/beta/graph.py
@@ -8,8 +8,7 @@
 from __future__ import annotations as _annotations
 
 import sys
-import uuid
-from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
+from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Sequence
 from contextlib import AbstractContextManager, AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
@@ -22,7 +21,7 @@
 from pydantic_graph import exceptions
 from pydantic_graph._utils import AbstractSpan, get_traceparent, infer_obj_name, logfire_span
 from pydantic_graph.beta.decision import Decision
-from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, GraphRunID, JoinID, NodeID, NodeRunID, TaskID
+from pydantic_graph.beta.id_types import ForkID, ForkStack, ForkStackItem, JoinID, NodeID, NodeRunID, TaskID
 from pydantic_graph.beta.join import Join, JoinNode, JoinState, ReducerContext
 from pydantic_graph.beta.node import (
     EndNode,
@@ -306,14 +305,13 @@ def __str__(self) -> str:
 
 
 @dataclass
-class GraphTask:
-    """A single task representing the execution of a node in the graph.
+class GraphTaskRequest:
+    """A request to run a task representing the execution of a node in the graph.
 
-    GraphTask encapsulates all the information needed to execute a specific
+    GraphTaskRequest encapsulates all the information needed to execute a specific
     node, including its inputs and the fork context it's executing within.
     """
 
-    # With our current BaseNode thing, next_node_id and next_node_inputs are merged into `next_node` itself
     node_id: NodeID
     """The ID of the node to execute."""
 
@@ -326,7 +324,29 @@ class GraphTask:
     inputs: Any
     """The input data for the node."""
 
     fork_stack: ForkStack = field(repr=False)
     """Stack of forks that have been entered.
 
     Used by the GraphRun to decide when to proceed through joins.
     """
 
-    task_id: TaskID = field(default_factory=lambda: TaskID(str(uuid.uuid4())), repr=False)
+
+@dataclass
+class GraphTask(GraphTaskRequest):
+    """A task representing the execution of a node in the graph.
+
+    GraphTask encapsulates all the information needed to execute a specific
+    node, including its inputs and the fork context it's executing within,
+    and has a unique ID to identify the task within the graph run.
+ """ + + node_id: NodeID + """The ID of the node to execute.""" + + inputs: Any + """The input data for the node.""" + + fork_stack: ForkStack = field(repr=False) + """Stack of forks that have been entered. + + Used by the GraphRun to decide when to proceed through joins. + """ + + task_id: TaskID """Unique identifier for this task.""" @@ -378,12 +398,20 @@ def __init__( self._next: EndMarker[OutputT] | Sequence[GraphTask] | None = None """The next item to be processed.""" - run_id = GraphRunID(str(uuid.uuid4())) - initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, NodeRunID(run_id), 0),) - self._first_task = GraphTask(node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack) + self._next_task_id = 0 + self._next_node_run_id = 0 + initial_fork_stack: ForkStack = (ForkStackItem(StartNode.id, self._get_next_node_run_id(), 0),) + self._first_task = GraphTask( + node_id=StartNode.id, inputs=inputs, fork_stack=initial_fork_stack, task_id=self._get_next_task_id() + ) self._iterator_task_group = create_task_group() self._iterator_instance = _GraphIterator[StateT, DepsT, OutputT]( - self.graph, self.state, self.deps, self._iterator_task_group + self.graph, + self.state, + self.deps, + self._iterator_task_group, + self._get_next_node_run_id, + self._get_next_task_id, ) self._iterator = self._iterator_instance.iter_graph(self._first_task) @@ -449,7 +477,7 @@ async def __anext__(self) -> EndMarker[OutputT] | Sequence[GraphTask]: return self._next async def next( - self, value: EndMarker[OutputT] | Sequence[GraphTask] | None = None + self, value: EndMarker[OutputT] | Sequence[GraphTaskRequest] | None = None ) -> EndMarker[OutputT] | Sequence[GraphTask]: """Advance the graph execution by one step. @@ -467,7 +495,10 @@ async def next( # if `next` is called before the `first_node` has run. 
         await anext(self)
         if value is not None:
-            self._next = value
+            if isinstance(value, EndMarker):
+                self._next = value
+            else:
+                self._next = [GraphTask(gt.node_id, gt.inputs, gt.fork_stack, self._get_next_task_id()) for gt in value]
         return await anext(self)
 
     @property
@@ -490,6 +521,16 @@ def output(self) -> OutputT | None:
                 return self._next.value
         return None
 
+    def _get_next_task_id(self) -> TaskID:
+        next_id = TaskID(f'task:{self._next_task_id}')
+        self._next_task_id += 1
+        return next_id
+
+    def _get_next_node_run_id(self) -> NodeRunID:
+        next_id = NodeRunID(f'node_run:{self._next_node_run_id}')
+        self._next_node_run_id += 1
+        return next_id
+
 
 @dataclass
 class _GraphTaskAsyncIterable:
@@ -510,6 +551,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
     state: StateT
     deps: DepsT
     task_group: TaskGroup
+    get_next_node_run_id: Callable[[], NodeRunID]
+    get_next_task_id: Callable[[], TaskID]
 
     cancel_scopes: dict[TaskID, CancelScope] = field(init=False)
     active_tasks: dict[TaskID, GraphTask] = field(init=False)
@@ -782,12 +826,12 @@ def _handle_node(
         fork_stack: ForkStack,
     ) -> Sequence[GraphTask] | JoinItem | EndMarker[OutputT]:
         if isinstance(next_node, StepNode):
-            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack)]
+            return [GraphTask(next_node.step.id, next_node.inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, JoinNode):
             return JoinItem(next_node.join.id, next_node.inputs, fork_stack)
         elif isinstance(next_node, BaseNode):
             node_step = NodeStep(next_node.__class__)
-            return [GraphTask(node_step.id, next_node, fork_stack)]
+            return [GraphTask(node_step.id, next_node, fork_stack, self.get_next_task_id())]
         elif isinstance(next_node, End):
             return EndMarker(next_node.data)
         else:
@@ -821,7 +865,7 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
             'These markers should be removed from paths during graph building'
         )
         if isinstance(item, DestinationMarker):
-            return [GraphTask(item.destination_id, inputs, fork_stack)]
+            return [GraphTask(item.destination_id, inputs, fork_stack, self.get_next_task_id())]
         elif isinstance(item, TransformMarker):
             inputs = item.transform(StepContext(state=self.state, deps=self.deps, inputs=inputs))
             return self._handle_path(path.next_path, inputs, fork_stack)
@@ -853,7 +897,7 @@ def _handle_fork_edges(
         )  # this should have already been ensured during graph building
 
         new_tasks: list[GraphTask] = []
-        node_run_id = NodeRunID(str(uuid.uuid4()))
+        node_run_id = self.get_next_node_run_id()
        if node.is_map:
             # If the map specifies a downstream join id, eagerly create a join state for it
             if (join_id := node.downstream_join_id) is not None:
diff --git a/pydantic_graph/pydantic_graph/beta/id_types.py b/pydantic_graph/pydantic_graph/beta/id_types.py
index d777a1846e..9ec9056ed4 100644
--- a/pydantic_graph/pydantic_graph/beta/id_types.py
+++ b/pydantic_graph/pydantic_graph/beta/id_types.py
@@ -24,9 +24,6 @@
 ForkID = NodeID
 """Alias for NodeId when referring to fork nodes."""
 
-GraphRunID = NewType('GraphRunID', str)
-"""Unique identifier for a complete graph execution run."""
-
 TaskID = NewType('TaskID', str)
 """Unique identifier for a task within the graph execution."""
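Temporal workflow code must be deterministic across replays, which is why the run above stamps task and node-run IDs from simple per-run counters instead of `uuid.uuid4()`. A minimal sketch of that allocation pattern, with illustrative names rather than the library's internals:

```python
from typing import NewType

TaskID = NewType('TaskID', str)


class TaskIdAllocator:
    """Hands out 'task:0', 'task:1', ... so every replay sees the same sequence."""

    def __init__(self) -> None:
        self._next = 0

    def __call__(self) -> TaskID:
        task_id = TaskID(f'task:{self._next}')
        self._next += 1
        return task_id


allocate = TaskIdAllocator()
assert allocate() == TaskID('task:0')
assert allocate() == TaskID('task:1')
```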
diff --git a/tests/graph/beta/test_graph_iteration.py b/tests/graph/beta/test_graph_iteration.py
index dcb10267ce..e1bf25fe57 100644
--- a/tests/graph/beta/test_graph_iteration.py
+++ b/tests/graph/beta/test_graph_iteration.py
@@ -8,7 +8,7 @@
 import pytest
 
 from pydantic_graph.beta import GraphBuilder, StepContext
-from pydantic_graph.beta.graph import EndMarker, GraphTask
+from pydantic_graph.beta.graph import EndMarker, GraphTask, GraphTaskRequest
 from pydantic_graph.beta.id_types import NodeID
 from pydantic_graph.beta.join import reduce_list_append
 
@@ -400,7 +400,7 @@ async def second_step(ctx: StepContext[IterState, None, int]) -> int:
 
         # Get the fork_stack from the EndMarker's source
         fork_stack = run.next_task[0].fork_stack if isinstance(run.next_task, list) else ()
-        new_task = GraphTask(
+        new_task = GraphTaskRequest(
             node_id=NodeID('second_step'),
             inputs=event.value,
             fork_stack=fork_stack,
diff --git a/tests/test_temporal.py b/tests/test_temporal.py
index e6f655be14..b07dd1ebdf 100644
--- a/tests/test_temporal.py
+++ b/tests/test_temporal.py
@@ -49,6 +49,8 @@
 from pydantic_ai.run import AgentRunResult
 from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition
 from pydantic_ai.usage import RequestUsage
+from pydantic_graph.beta import GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append
 
 try:
     from temporalio import workflow
@@ -2228,3 +2230,108 @@ async def test_fastmcp_toolset(allow_model_requests: None, client: Client):
     assert output == snapshot(
         'The `pydantic/pydantic-ai` repository is a Python agent framework crafted for developing production-grade Generative AI applications. It emphasizes type safety, model-agnostic design, and extensibility. The framework supports various LLM providers, manages agent workflows using graph-based execution, and ensures structured, reliable LLM outputs. Key packages include core framework components, graph execution engines, evaluation tools, and example applications.'
     )
+
+
+# ============================================================================
+# Beta Graph API Tests - Tests for running pydantic-graph beta API in Temporal
+# ============================================================================
+
+
+@dataclass
+class GraphState:
+    """State for the graph execution test."""
+
+    values: list[int] = field(default_factory=list)
+
+
+# Create a graph with parallel execution using the beta API
+graph_builder = GraphBuilder(
+    name='parallel_test_graph',
+    state_type=GraphState,
+    input_type=int,
+    output_type=list[int],
+)
+
+
+@graph_builder.step
+async def source(ctx: StepContext[GraphState, None, int]) -> int:
+    """Source step that passes through the input value."""
+    return ctx.inputs
+
+
+@graph_builder.step
+async def multiply_by_two(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 2."""
+    return ctx.inputs * 2
+
+
+@graph_builder.step
+async def multiply_by_three(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 3."""
+    return ctx.inputs * 3
+
+
+@graph_builder.step
+async def multiply_by_four(ctx: StepContext[GraphState, None, int]) -> int:
+    """Multiply input by 4."""
+    return ctx.inputs * 4
+
+
+# Create a join to collect results
+result_collector = graph_builder.join(reduce_list_append, initial_factory=list[int])
+
+# Build the graph with parallel edges (broadcast pattern)
+graph_builder.add(
+    graph_builder.edge_from(graph_builder.start_node).to(source),
+    # Broadcast: send the value to all three parallel steps
+    graph_builder.edge_from(source).to(multiply_by_two, multiply_by_three, multiply_by_four),
+    # Collect all results
+    graph_builder.edge_from(multiply_by_two, multiply_by_three, multiply_by_four).to(result_collector),
+    graph_builder.edge_from(result_collector).to(graph_builder.end_node),
+)
+
+parallel_test_graph = graph_builder.build()
+
+
+@workflow.defn
+class ParallelGraphWorkflow:
+    """Workflow that executes a graph with parallel task execution."""
+
+    @workflow.run
+    async def run(self, input_value: int) -> list[int]:
+        """Run the parallel graph workflow.
+
+        Args:
+            input_value: The input number to process
+
+        Returns:
+            List of results from the parallel execution
+        """
+        result = await parallel_test_graph.run(
+            state=GraphState(),
+            inputs=input_value,
+        )
+        return result
+
+
+async def test_beta_graph_parallel_execution_in_workflow(client: Client):
+    """Test that the beta graph API with parallel execution works in Temporal workflows.
+
+    This verifies the fix for the bug where parallel task execution in graphs
+    didn't work inside Temporal workflows: task and node-run IDs were generated
+    with uuid.uuid4(), which is not deterministic across workflow replays, so
+    they are now allocated from per-run counters instead.
+    """
+    async with Worker(
+        client,
+        task_queue=TASK_QUEUE,
+        workflows=[ParallelGraphWorkflow],
+    ):
+        output = await client.execute_workflow(  # pyright: ignore[reportUnknownMemberType]
+            ParallelGraphWorkflow.run,
+            args=[10],
+            id=ParallelGraphWorkflow.__name__,
+            task_queue=TASK_QUEUE,
+        )
+        # Results can arrive in any order due to parallel execution:
+        # 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40
+        assert sorted(output) == [20, 30, 40]
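For reference, the beta-API surface exercised above also works standalone, outside Temporal. A minimal, self-contained sketch (the step and graph names here are illustrative) that builds and runs a two-step graph the same way the tests do:

```python
import asyncio
from dataclasses import dataclass

from pydantic_graph.beta import GraphBuilder, StepContext


@dataclass
class MiniState:
    """Trivial state, mirroring how the tests supply a state type."""


g = GraphBuilder(name='mini_graph', state_type=MiniState, input_type=int, output_type=int)


@g.step
async def double(ctx: StepContext[MiniState, None, int]) -> int:
    """Double the input value."""
    return ctx.inputs * 2


@g.step
async def increment(ctx: StepContext[MiniState, None, int]) -> int:
    """Add one to the doubled value."""
    return ctx.inputs + 1


g.add(
    g.edge_from(g.start_node).to(double),
    g.edge_from(double).to(increment),
    g.edge_from(increment).to(g.end_node),
)
mini_graph = g.build()


async def demo() -> None:
    # run() returns the graph output directly, as the tests above rely on
    result = await mini_graph.run(state=MiniState(), inputs=20)
    assert result == 41


asyncio.run(demo())
```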