Skip to content

Commit a067def

Browse files
Ensure graph persistence snapshots are not mutated when run is resumed (#3077)
1 parent 9b1913e commit a067def

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

pydantic_graph/pydantic_graph/persistence/in_mem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
7676
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
7777
if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created':
7878
self.last_snapshot.status = 'pending'
79-
return self.last_snapshot
79+
return copy.deepcopy(self.last_snapshot)
8080

8181
async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
8282
raise NotImplementedError('load is not supported for SimpleStatePersistence')
@@ -143,7 +143,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
143143
async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
144144
if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
145145
snapshot.status = 'pending'
146-
return snapshot
146+
return copy.deepcopy(snapshot)
147147

148148
async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
149149
return self.history

tests/graph/test_persistence.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations as _annotations
33

44
import json
5+
from copy import deepcopy
56
from dataclasses import dataclass
67
from datetime import datetime, timezone
78

@@ -345,3 +346,59 @@ async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]):
345346
def test_snapshot_type_adapter_error():
346347
with pytest.raises(RuntimeError, match='Unable to build a Pydantic schema for `BaseNode` without setting'):
347348
build_snapshot_list_type_adapter(int, int)
349+
350+
351+
async def test_full_state_persistence_snapshot_state_stability():
352+
@dataclass
353+
class CountDownState:
354+
counter: int
355+
356+
@dataclass
357+
class CountDown(BaseNode[CountDownState, None, int]):
358+
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
359+
if ctx.state.counter <= 0:
360+
return End(ctx.state.counter)
361+
ctx.state.counter -= 1
362+
return CountDown()
363+
364+
persistence = FullStatePersistence()
365+
state = CountDownState(counter=3)
366+
count_down_graph = Graph(nodes=[CountDown])
367+
368+
await count_down_graph.initialize(CountDown(), state=state, persistence=persistence)
369+
370+
done = False
371+
while not done:
372+
history = deepcopy(persistence.history)
373+
async with count_down_graph.iter_from_persistence(persistence) as run:
374+
result = await run.next()
375+
done = isinstance(result, End)
376+
377+
for i in range(len(history)):
378+
assert history[i].id == persistence.history[i].id
379+
assert history[i].state == persistence.history[i].state, 'State should not change'
380+
381+
382+
async def test_simple_state_persistence_snapshot_state_stability():
383+
@dataclass
384+
class CountDownState:
385+
counter: int
386+
387+
@dataclass
388+
class CountDown(BaseNode[CountDownState, None, int]):
389+
async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]:
390+
ctx.state.counter -= 1
391+
return CountDown()
392+
393+
persistence = SimpleStatePersistence()
394+
state = CountDownState(counter=3)
395+
count_down_graph = Graph(nodes=[CountDown])
396+
397+
await count_down_graph.initialize(CountDown(), state=state, persistence=persistence)
398+
399+
last_snapshot = persistence.last_snapshot
400+
async with count_down_graph.iter_from_persistence(persistence) as run:
401+
await run.next()
402+
403+
assert last_snapshot and last_snapshot.state.counter == 3
404+
assert persistence.last_snapshot and persistence.last_snapshot.state.counter == 2

0 commit comments

Comments
 (0)