Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pydantic_graph/pydantic_graph/persistence/in_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created':
self.last_snapshot.status = 'pending'
return self.last_snapshot
return copy.deepcopy(self.last_snapshot)

async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
raise NotImplementedError('load is not supported for SimpleStatePersistence')
Expand Down Expand Up @@ -143,7 +143,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
snapshot.status = 'pending'
return snapshot
return copy.deepcopy(snapshot)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this in SimpleStatePersistence above as well


async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
return self.history
Expand Down
57 changes: 57 additions & 0 deletions tests/graph/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations as _annotations

import json
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timezone

Expand Down Expand Up @@ -345,3 +346,59 @@ async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]):
def test_snapshot_type_adapter_error():
with pytest.raises(RuntimeError, match='Unable to build a Pydantic schema for `BaseNode` without setting'):
build_snapshot_list_type_adapter(int, int)


async def test_full_state_persistence_snapshot_state_stability():
@dataclass
class CountDownState:
counter: int

@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
if ctx.state.counter <= 0:
return End(ctx.state.counter)
ctx.state.counter -= 1
return CountDown()

persistence = FullStatePersistence()
state = CountDownState(counter=3)
count_down_graph = Graph(nodes=[CountDown])

await count_down_graph.initialize(CountDown(), state=state, persistence=persistence)

done = False
while not done:
history = deepcopy(persistence.history)
async with count_down_graph.iter_from_persistence(persistence) as run:
result = await run.next()
done = isinstance(result, End)

for i in range(len(history)):
assert history[i].id == persistence.history[i].id
assert history[i].state == persistence.history[i].state, 'State should not change'


async def test_simple_state_persistence_snapshot_state_stability():
@dataclass
class CountDownState:
counter: int

@dataclass
class CountDown(BaseNode[CountDownState, None, int]):
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
ctx.state.counter -= 1
return CountDown()

persistence = SimpleStatePersistence()
state = CountDownState(counter=3)
count_down_graph = Graph(nodes=[CountDown])

await count_down_graph.initialize(CountDown(), state=state, persistence=persistence)

last_snapshot = persistence.last_snapshot
async with count_down_graph.iter_from_persistence(persistence) as run:
await run.next()

assert last_snapshot and last_snapshot.state.counter == 3
assert persistence.last_snapshot and persistence.last_snapshot.state.counter == 2