Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pydantic_graph/pydantic_graph/persistence/in_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
snapshot.status = 'pending'
return snapshot
return copy.deepcopy(snapshot)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this in SimpleStatePersistence above as well


async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
return self.history
Expand Down
32 changes: 32 additions & 0 deletions tests/graph/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations as _annotations

import json
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timezone

Expand Down Expand Up @@ -345,3 +346,34 @@ async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]):
def test_snapshot_type_adapter_error():
with pytest.raises(RuntimeError, match='Unable to build a Pydantic schema for `BaseNode` without setting'):
build_snapshot_list_type_adapter(int, int)


async def test_snapshot_state_stability():
@dataclass
class CountDownState:
counter: int

@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
if ctx.state.counter <= 0:
return End(ctx.state.counter)
ctx.state.counter -= 1
return CountDown()

persistence = FullStatePersistence()
state = CountDownState(counter=3)
count_down_graph = Graph(nodes=[CountDown])

await count_down_graph.initialize(CountDown(), state=state, persistence=persistence)

done = False
while not done:
history = deepcopy(persistence.history)
async with count_down_graph.iter_from_persistence(persistence) as run:
result = await run.next()
done = isinstance(result, End)

for i in range(len(history)):
assert history[i].id == persistence.history[i].id
assert history[i].state == persistence.history[i].state, 'State should not change'