Skip to content

Commit 416a0d1

Browse files
State persistence (#955)
Co-authored-by: David Montague <[email protected]>
1 parent 7970e82 commit 416a0d1

File tree

29 files changed

+2051
-983
lines changed

29 files changed

+2051
-983
lines changed

Makefile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ test: ## Run tests and collect coverage data
4949

5050
.PHONY: test-all-python
5151
test-all-python: ## Run tests on Python 3.9 to 3.13
52-
UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 --all-extras coverage run -p -m pytest
53-
UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras coverage run -p -m pytest
54-
UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras coverage run -p -m pytest
55-
UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras coverage run -p -m pytest
56-
UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras coverage run -p -m pytest
52+
UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 --all-extras --all-packages coverage run -p -m pytest
53+
UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras --all-packages coverage run -p -m pytest
54+
UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras --all-packages coverage run -p -m pytest
55+
UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras --all-packages coverage run -p -m pytest
56+
UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras --all-packages coverage run -p -m pytest
5757
@uv run coverage combine
5858
@uv run coverage report
5959

docs/api/pydantic_graph/nodes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
::: pydantic_graph.nodes
44
options:
55
members:
6+
- StateT
67
- GraphRunContext
78
- BaseNode
89
- End
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# `pydantic_graph.persistence`
2+
3+
::: pydantic_graph.persistence
4+
5+
::: pydantic_graph.persistence.in_mem
6+
7+
::: pydantic_graph.persistence.file

docs/api/pydantic_graph/state.md

Lines changed: 0 additions & 3 deletions
This file was deleted.

docs/graph.md

Lines changed: 196 additions & 373 deletions
Large diffs are not rendered by default.

examples/pydantic_ai_examples/question_graph.py

Lines changed: 37 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,16 @@
99

1010
from dataclasses import dataclass, field
1111
from pathlib import Path
12-
from typing import Annotated
1312

1413
import logfire
15-
from devtools import debug
16-
from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep
14+
from groq import BaseModel
15+
from pydantic_graph import (
16+
BaseNode,
17+
End,
18+
Graph,
19+
GraphRunContext,
20+
)
21+
from pydantic_graph.persistence.file import FileStatePersistence
1722

1823
from pydantic_ai import Agent
1924
from pydantic_ai.format_as_xml import format_as_xml
@@ -41,22 +46,23 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
4146
)
4247
ctx.state.ask_agent_messages += result.all_messages()
4348
ctx.state.question = result.data
44-
return Answer()
49+
return Answer(result.data)
4550

4651

4752
@dataclass
4853
class Answer(BaseNode[QuestionState]):
49-
answer: str | None = None
54+
question: str
5055

5156
async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
52-
assert self.answer is not None
53-
return Evaluate(self.answer)
57+
answer = input(f'{self.question}: ')
58+
return Evaluate(answer)
5459

5560

56-
@dataclass
57-
class EvaluationResult:
61+
class EvaluationResult(BaseModel, use_attribute_docstrings=True):
5862
correct: bool
63+
"""Whether the answer is correct."""
5964
comment: str
65+
"""Comment on the answer, reprimand the user if the answer is wrong."""
6066

6167

6268
evaluate_agent = Agent(
@@ -67,101 +73,76 @@ class EvaluationResult:
6773

6874

6975
@dataclass
70-
class Evaluate(BaseNode[QuestionState]):
76+
class Evaluate(BaseNode[QuestionState, None, str]):
7177
answer: str
7278

7379
async def run(
7480
self,
7581
ctx: GraphRunContext[QuestionState],
76-
) -> Congratulate | Reprimand:
82+
) -> End[str] | Reprimand:
7783
assert ctx.state.question is not None
7884
result = await evaluate_agent.run(
7985
format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
8086
message_history=ctx.state.evaluate_agent_messages,
8187
)
8288
ctx.state.evaluate_agent_messages += result.all_messages()
8389
if result.data.correct:
84-
return Congratulate(result.data.comment)
90+
return End(result.data.comment)
8591
else:
8692
return Reprimand(result.data.comment)
8793

8894

89-
@dataclass
90-
class Congratulate(BaseNode[QuestionState, None, None]):
91-
comment: str
92-
93-
async def run(
94-
self, ctx: GraphRunContext[QuestionState]
95-
) -> Annotated[End, Edge(label='success')]:
96-
print(f'Correct answer! {self.comment}')
97-
return End(None)
98-
99-
10095
@dataclass
10196
class Reprimand(BaseNode[QuestionState]):
10297
comment: str
10398

10499
async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
105100
print(f'Comment: {self.comment}')
106-
# > Comment: Vichy is no longer the capital of France.
107101
ctx.state.question = None
108102
return Ask()
109103

110104

111105
question_graph = Graph(
112-
nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState
106+
nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
113107
)
114108

115109

116110
async def run_as_continuous():
117111
state = QuestionState()
118112
node = Ask()
119-
history: list[HistoryStep[QuestionState, None]] = []
120-
with logfire.span('run questions graph'):
121-
while True:
122-
node = await question_graph.next(node, history, state=state)
123-
if isinstance(node, End):
124-
debug([e.data_snapshot() for e in history])
125-
break
126-
elif isinstance(node, Answer):
127-
assert state.question
128-
node.answer = input(f'{state.question} ')
129-
# otherwise just continue
113+
end = await question_graph.run(node, state=state)
114+
print('END:', end.output)
130115

131116

132117
async def run_as_cli(answer: str | None):
133-
history_file = Path('question_graph_history.json')
134-
history = (
135-
question_graph.load_history(history_file.read_bytes())
136-
if history_file.exists()
137-
else []
138-
)
139-
140-
if history:
141-
last = history[-1]
142-
assert last.kind == 'node', 'expected last step to be a node'
143-
state = last.state
144-
assert answer is not None, 'answer is required to continue from history'
145-
node = Answer(answer)
118+
persistence = FileStatePersistence(Path('question_graph.json'))
119+
persistence.set_graph_types(question_graph)
120+
121+
if snapshot := await persistence.load_next():
122+
state = snapshot.state
123+
assert answer is not None, (
124+
'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
125+
)
126+
node = Evaluate(answer)
146127
else:
147128
state = QuestionState()
148129
node = Ask()
149-
debug(state, node)
130+
# debug(state, node)
150131

151-
with logfire.span('run questions graph'):
132+
async with question_graph.iter(node, state=state, persistence=persistence) as run:
152133
while True:
153-
node = await question_graph.next(node, history, state=state)
134+
node = await run.next()
154135
if isinstance(node, End):
155-
debug([e.data_snapshot() for e in history])
136+
print('END:', node.data)
137+
history = await persistence.load_all()
138+
print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
156139
print('Finished!')
157140
break
158141
elif isinstance(node, Answer):
159-
print(state.question)
142+
print(node.question)
160143
break
161144
# otherwise just continue
162145

163-
history_file.write_bytes(question_graph.dump_history(history, indent=2))
164-
165146

166147
if __name__ == '__main__':
167148
import asyncio

mkdocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ nav:
6666
- api/providers.md
6767
- api/pydantic_graph/graph.md
6868
- api/pydantic_graph/nodes.md
69-
- api/pydantic_graph/state.md
69+
- api/pydantic_graph/persistence.md
7070
- api/pydantic_graph/mermaid.md
7171
- api/pydantic_graph/exceptions.md
7272

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,8 @@ async def main():
475475
start_node,
476476
state=state,
477477
deps=graph_deps,
478-
infer_name=False,
479478
span=use_span(run_span, end_on_exit=True),
479+
infer_name=False,
480480
) as graph_run:
481481
yield AgentRun(graph_run)
482482

pydantic_graph/README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,4 @@ fives_graph = Graph(nodes=[DivisibleBy5, Increment])
5353
result = fives_graph.run_sync(DivisibleBy5(4))
5454
print(result.output)
5555
#> 5
56-
# the full history is quite verbose (see below), so we'll just print the summary
57-
print([item.data_snapshot() for item in result.history])
58-
#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)]
5956
```
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from .exceptions import GraphRuntimeError, GraphSetupError
22
from .graph import Graph, GraphRun, GraphRunResult
33
from .nodes import BaseNode, Edge, End, GraphRunContext
4-
from .state import EndStep, HistoryStep, NodeStep
4+
from .persistence import EndSnapshot, NodeSnapshot, Snapshot
5+
from .persistence.in_mem import FullStatePersistence, SimpleStatePersistence
56

67
__all__ = (
78
'Graph',
@@ -11,9 +12,11 @@
1112
'End',
1213
'GraphRunContext',
1314
'Edge',
14-
'EndStep',
15-
'HistoryStep',
16-
'NodeStep',
15+
'EndSnapshot',
16+
'Snapshot',
17+
'NodeSnapshot',
1718
'GraphSetupError',
1819
'GraphRuntimeError',
20+
'SimpleStatePersistence',
21+
'FullStatePersistence',
1922
)

0 commit comments

Comments
 (0)