From 8fdc82ecb00703cfdebab38fd6e723826e0fa5af Mon Sep 17 00:00:00 2001 From: quanglm Date: Fri, 3 Oct 2025 22:12:28 +0700 Subject: [PATCH 1/3] fix(persistence): deep-copy snapshot in load_next to keep history immutable --- .../pydantic_graph/persistence/in_mem.py | 2 +- tests/graph/test_persistence.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index 85e1e7e03d..5ed3022195 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -143,7 +143,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None): snapshot.status = 'pending' - return snapshot + return copy.deepcopy(snapshot) async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return self.history diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 182fb20f15..efd4a01b3a 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -2,6 +2,7 @@ from __future__ import annotations as _annotations import json +from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timezone @@ -345,3 +346,34 @@ async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]): def test_snapshot_type_adapter_error(): with pytest.raises(RuntimeError, match='Unable to build a Pydantic schema for `BaseNode` without setting'): build_snapshot_list_type_adapter(int, int) + + +async def test_snapshot_state_stability(): + @dataclass + class CountDownState: + counter: int + + @dataclass + class CountDown(BaseNode[CountDownState, None, int]): + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: + if ctx.state.counter <= 0: + return End(ctx.state.counter) + ctx.state.counter -= 1 + return CountDown() + + persistence = FullStatePersistence() + state = CountDownState(counter=3) + count_down_graph = Graph(nodes=[CountDown]) + + await count_down_graph.initialize(CountDown(), state=state, persistence=persistence) + + done = False + while not done: + history = deepcopy(persistence.history) + async with count_down_graph.iter_from_persistence(persistence) as run: + result = await run.next() + done = isinstance(result, End) + + for i in range(len(history)): + assert history[i].id == persistence.history[i].id + assert history[i].state == persistence.history[i].state, 'State should not change' From d67b96079a877fdcdd3088890408f3ff48a11252 Mon Sep 17 00:00:00 2001 From: quanglm Date: Sat, 4 Oct 2025 10:31:25 +0700 Subject: [PATCH 2/3] fix(persistence): deepcopy snapshot in SimpleStatePersistence.load_next --- .../pydantic_graph/persistence/in_mem.py | 2 +- tests/graph/test_persistence.py | 29 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index 5ed3022195..828efba809 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -76,7 +76,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created': self.last_snapshot.status = 'pending' - return self.last_snapshot + return copy.deepcopy(self.last_snapshot) async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: raise NotImplementedError('load is not supported for SimpleStatePersistence') diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index efd4a01b3a..1cf752afda 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -348,7 +348,7 @@ def test_snapshot_type_adapter_error(): build_snapshot_list_type_adapter(int, int) -async def test_snapshot_state_stability(): +async def test_full_state_persistence_snapshot_state_stability(): @dataclass class CountDownState: counter: int @@ -377,3 +377,30 @@ async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int for i in range(len(history)): assert history[i].id == persistence.history[i].id assert history[i].state == persistence.history[i].state, 'State should not change' + + +async def test_simple_state_persistence_snapshot_state_stability(): + @dataclass + class CountDownState: + counter: int + + @dataclass + class CountDown(BaseNode[CountDownState, None, int]): + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: + if ctx.state.counter <= 0: + return End(ctx.state.counter) + ctx.state.counter -= 1 + return CountDown() + + persistence = SimpleStatePersistence() + state = CountDownState(counter=3) + count_down_graph = Graph(nodes=[CountDown]) + + await count_down_graph.initialize(CountDown(), state=state, persistence=persistence) + + last_snapshot = persistence.last_snapshot + async with count_down_graph.iter_from_persistence(persistence) as run: + await run.next() + + assert last_snapshot and last_snapshot.state.counter == 3 + assert persistence.last_snapshot and persistence.last_snapshot.state.counter == 2 From 573c808ba44fbb9b57fa37b3b206b0fe72fdee71 Mon Sep 17 00:00:00 2001 From: quanglm Date: Mon, 6 Oct 2025 14:17:37 +0700 Subject: [PATCH 3/3] fix coverage, remove unnecessary condition --- tests/graph/test_persistence.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 1cf752afda..5527aa8115 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -387,8 +387,6 @@ class CountDownState: @dataclass class CountDown(BaseNode[CountDownState, None, int]): async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: - if ctx.state.counter <= 0: - return End(ctx.state.counter) ctx.state.counter -= 1 return CountDown()