Skip to content

Commit 7b10ea7

Browse files
committed
Fix bug with running graphs in temporal workflows
1 parent e72452f commit 7b10ea7

File tree

6 files changed

+496
-27
lines changed

6 files changed

+496
-27
lines changed
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
"""Example demonstrating pydantic-graph integration with Temporal workflows.
2+
3+
This example shows how pydantic-graph graphs "just work" inside Temporal workflows,
4+
with TemporalAgent handling model requests and tool calls as durable activities.
5+
6+
The example implements a research workflow that:
7+
1. Breaks down a complex question into simpler sub-questions
8+
2. Researches each sub-question in parallel
9+
3. Synthesizes the results into a final answer
10+
11+
To run this example:
12+
1. Start Temporal server locally:
13+
```sh
14+
brew install temporal
15+
temporal server start-dev
16+
```
17+
18+
2. Run this script:
19+
```sh
20+
uv run python examples/pydantic_ai_examples/temporal_graph.py
21+
```
22+
"""
23+
24+
from __future__ import annotations
25+
26+
import asyncio
27+
import uuid
28+
from dataclasses import dataclass
29+
30+
from pydantic import BaseModel
31+
from temporalio import workflow
32+
from temporalio.client import Client
33+
from temporalio.worker import Worker
34+
35+
from pydantic_ai import Agent
36+
from pydantic_ai.durable_exec.temporal import (
37+
AgentPlugin,
38+
PydanticAIPlugin,
39+
TemporalAgent,
40+
)
41+
from pydantic_graph.beta import GraphBuilder, StepContext
42+
from pydantic_graph.beta.join import reduce_list_extend
43+
44+
# ============================================================================
45+
# State and Dependencies
46+
# ============================================================================
47+
48+
49+
@dataclass
50+
class ResearchState:
51+
"""State that flows through the research graph."""
52+
53+
original_question: str
54+
sub_questions: list[str] | None = None
55+
sub_answers: list[str] | None = None
56+
final_answer: str | None = None
57+
58+
59+
@dataclass
60+
class ResearchDeps:
61+
"""Dependencies for the research workflow (must be serializable for Temporal)."""
62+
63+
max_sub_questions: int = 3
64+
65+
66+
# ============================================================================
67+
# Output Models
68+
# ============================================================================
69+
70+
71+
class SubQuestions(BaseModel):
72+
"""Model for breaking down a question into sub-questions."""
73+
74+
sub_questions: list[str]
75+
76+
77+
class Answer(BaseModel):
78+
"""Model for a research answer."""
79+
80+
answer: str
81+
confidence: float
82+
83+
84+
# ============================================================================
85+
# Agents
86+
# ============================================================================
87+
88+
# Agent that breaks down complex questions into simpler sub-questions
89+
question_breaker_agent = Agent(
90+
'openai:gpt-5-mini',
91+
name='question_breaker',
92+
instructions=(
93+
'You are an expert at breaking down complex questions into simpler, '
94+
'more focused sub-questions that can be researched independently. '
95+
'Create questions that cover different aspects of the original question.'
96+
),
97+
output_type=SubQuestions,
98+
)
99+
100+
# Agent that researches individual questions
101+
researcher_agent = Agent(
102+
'openai:gpt-5-mini',
103+
name='researcher',
104+
instructions=(
105+
'You are a research assistant. Provide clear, accurate, and concise answers '
106+
'to questions based on your knowledge. Include confidence level in your response.'
107+
),
108+
output_type=Answer,
109+
)
110+
111+
# Agent that synthesizes multiple answers into a comprehensive final answer
112+
synthesizer_agent = Agent(
113+
'openai:gpt-5-mini',
114+
name='synthesizer',
115+
instructions=(
116+
'You are an expert at synthesizing multiple pieces of information into '
117+
'a coherent, comprehensive answer. Combine the provided answers while '
118+
'maintaining accuracy and clarity.'
119+
),
120+
)
121+
122+
# Wrap all agents with TemporalAgent for durable execution
123+
temporal_question_breaker = TemporalAgent(question_breaker_agent)
124+
temporal_researcher = TemporalAgent(researcher_agent)
125+
temporal_synthesizer = TemporalAgent(synthesizer_agent)
126+
127+
128+
# ============================================================================
129+
# Graph Definition using Beta API
130+
# ============================================================================
131+
132+
# Create the graph builder
133+
g = GraphBuilder(
134+
name='research_workflow',
135+
state_type=ResearchState,
136+
deps_type=ResearchDeps,
137+
input_type=str, # Takes a question string as input
138+
output_type=str, # Returns final answer as string
139+
auto_instrument=True,
140+
)
141+
142+
143+
# Step 1: Break down the question into sub-questions
144+
@g.step(node_id='break_down_question', label='Break Down Question')
145+
async def break_down_question(
146+
ctx: StepContext[ResearchState, ResearchDeps, str],
147+
) -> ResearchState:
148+
"""Break down the original question into sub-questions using an agent."""
149+
question = ctx.inputs
150+
151+
# Use the TemporalAgent to break down the question
152+
result = await temporal_question_breaker.run(
153+
f'Break down this question into {ctx.deps.max_sub_questions} simpler sub-questions: {question}',
154+
)
155+
156+
# Update state with sub-questions
157+
return ResearchState(
158+
original_question=question,
159+
sub_questions=result.output.sub_questions,
160+
)
161+
162+
163+
# Step 2: Research each sub-question (will run in parallel via map)
164+
@g.step(node_id='research_sub_question', label='Research Sub-Question')
165+
async def research_sub_question(
166+
ctx: StepContext[ResearchState, ResearchDeps, str],
167+
) -> str:
168+
"""Research a single sub-question using an agent."""
169+
sub_question = ctx.inputs
170+
171+
# Use the TemporalAgent to research the sub-question
172+
result = await temporal_researcher.run(sub_question)
173+
174+
# Return the answer as a formatted string
175+
return f'**Q: {sub_question}**\nA: {result.output.answer} (Confidence: {result.output.confidence:.0%})'
176+
177+
178+
# Step 3: Join all research results
179+
research_join = g.join(
180+
reducer=reduce_list_extend,
181+
initial=list[str](),
182+
)
183+
184+
185+
# Step 4: Synthesize all answers into a final answer
186+
@g.step(node_id='synthesize_answer', label='Synthesize Answer')
187+
async def synthesize_answer(
188+
ctx: StepContext[ResearchState, ResearchDeps, list[str]],
189+
) -> ResearchState:
190+
"""Synthesize all research results into a final comprehensive answer."""
191+
research_results = ctx.inputs
192+
193+
# Format the research results for the synthesizer
194+
research_summary = '\n\n'.join(research_results)
195+
196+
# Use the TemporalAgent to synthesize the final answer
197+
result = await temporal_synthesizer.run(
198+
f'Original question: {ctx.state.original_question}\n\n'
199+
f'Research findings:\n{research_summary}\n\n'
200+
'Please synthesize these findings into a comprehensive answer to the original question.',
201+
)
202+
203+
# Update state with final answer
204+
state = ctx.state
205+
state.sub_answers = research_results
206+
state.final_answer = result.output
207+
208+
return state
209+
210+
211+
# Build the graph with edges
212+
g.add(
213+
# Start -> Break down question
214+
g.edge_from(g.start_node).to(break_down_question),
215+
# Break down -> Map over sub-questions for parallel research
216+
g.edge_from(break_down_question)
217+
.transform(lambda ctx: ctx.inputs.sub_questions or [])
218+
.map()
219+
.to(research_sub_question),
220+
# Research results -> Join
221+
g.edge_from(research_sub_question).to(research_join),
222+
# Join -> Synthesize
223+
g.edge_from(research_join).to(synthesize_answer),
224+
# Synthesize -> End
225+
g.edge_from(synthesize_answer)
226+
.transform(lambda ctx: ctx.inputs.final_answer or '')
227+
.to(g.end_node),
228+
)
229+
230+
# Build the final graph
231+
research_graph = g.build()
232+
233+
234+
# ============================================================================
235+
# Temporal Workflow
236+
# ============================================================================
237+
238+
239+
@workflow.defn
240+
class ResearchWorkflow:
241+
"""Temporal workflow that executes the research graph with durable execution."""
242+
243+
@workflow.run
244+
async def run(self, question: str, deps: ResearchDeps | None = None) -> str:
245+
"""Run the research workflow on a question.
246+
247+
Args:
248+
question: The question to research
249+
deps: Optional dependencies for the workflow
250+
251+
Returns:
252+
The final synthesized answer
253+
"""
254+
if deps is None:
255+
deps = ResearchDeps()
256+
257+
# Execute the pydantic-graph graph - it "just works" in Temporal!
258+
result = await research_graph.run(
259+
state=ResearchState(original_question=question),
260+
deps=deps,
261+
inputs=question,
262+
)
263+
264+
return result
265+
266+
267+
# ============================================================================
268+
# Main Execution
269+
# ============================================================================
270+
271+
272+
async def main():
273+
"""Main function to set up worker and execute the workflow."""
274+
# Monkeypatch uuid.uuid4 to use Temporal's deterministic UUID generation
275+
# This is necessary because pydantic-graph uses uuid.uuid4 internally for task IDs
276+
# Connect to Temporal server
277+
client = await Client.connect(
278+
'localhost:7233',
279+
plugins=[PydanticAIPlugin()],
280+
)
281+
282+
# Create a worker that will execute workflows and activities
283+
async with Worker(
284+
client,
285+
task_queue='research',
286+
workflows=[ResearchWorkflow],
287+
plugins=[
288+
# Register activities for all three temporal agents
289+
AgentPlugin(temporal_question_breaker),
290+
AgentPlugin(temporal_researcher),
291+
AgentPlugin(temporal_synthesizer),
292+
],
293+
):
294+
# Execute the workflow
295+
question = 'What are the key factors that contributed to the success of the Apollo 11 moon landing?'
296+
297+
print(f'\n{"=" * 80}')
298+
print(f'Research Question: {question}')
299+
print(f'{"=" * 80}\n')
300+
301+
output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType]
302+
ResearchWorkflow.run,
303+
args=[question],
304+
id=f'research-{uuid.uuid4()}',
305+
task_queue='research',
306+
)
307+
308+
print(f'\n{"=" * 80}')
309+
print('Final Answer:')
310+
print(f'{"=" * 80}\n')
311+
print(output)
312+
print(f'\n{"=" * 80}\n')
313+
314+
315+
if __name__ == '__main__':
316+
import logfire
317+
318+
logfire.instrument_pydantic_ai()
319+
logfire.configure(send_to_logfire=False)
320+
321+
asyncio.run(main())

pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import TYPE_CHECKING, Any, Generic, Literal, overload
88

99
from pydantic_graph import BaseNode, End, GraphRunContext
10-
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
10+
from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTaskRequest, JoinItem
1111
from pydantic_graph.beta.step import NodeStep
1212

1313
from . import (
@@ -181,7 +181,7 @@ async def __anext__(
181181
return self._task_to_node(task)
182182

183183
def _task_to_node(
184-
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
184+
self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTaskRequest]
185185
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
186186
if isinstance(task, Sequence) and len(task) == 1:
187187
first_task = task[0]
@@ -197,8 +197,8 @@ def _task_to_node(
197197
return End(task.value)
198198
raise exceptions.AgentRunError(f'Unexpected node: {task}') # pragma: no cover
199199

200-
def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
201-
return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
200+
def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTaskRequest:
201+
return GraphTaskRequest(NodeStep(type(node)).id, inputs=node, fork_stack=())
202202

203203
async def next(
204204
self,

0 commit comments

Comments
 (0)