|
2 | 2 | from __future__ import annotations as _annotations
|
3 | 3 |
|
4 | 4 | import json
|
| 5 | +from copy import deepcopy |
5 | 6 | from dataclasses import dataclass
|
6 | 7 | from datetime import datetime, timezone
|
7 | 8 |
|
@@ -345,3 +346,59 @@ async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]):
|
345 | 346 | def test_snapshot_type_adapter_error():
|
346 | 347 | with pytest.raises(RuntimeError, match='Unable to build a Pydantic schema for `BaseNode` without setting'):
|
347 | 348 | build_snapshot_list_type_adapter(int, int)
|
| 349 | + |
| 350 | + |
| 351 | +async def test_full_state_persistence_snapshot_state_stability(): |
| 352 | + @dataclass |
| 353 | + class CountDownState: |
| 354 | + counter: int |
| 355 | + |
| 356 | + @dataclass |
| 357 | + class CountDown(BaseNode[CountDownState, None, int]): |
| 358 | + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: |
| 359 | + if ctx.state.counter <= 0: |
| 360 | + return End(ctx.state.counter) |
| 361 | + ctx.state.counter -= 1 |
| 362 | + return CountDown() |
| 363 | + |
| 364 | + persistence = FullStatePersistence() |
| 365 | + state = CountDownState(counter=3) |
| 366 | + count_down_graph = Graph(nodes=[CountDown]) |
| 367 | + |
| 368 | + await count_down_graph.initialize(CountDown(), state=state, persistence=persistence) |
| 369 | + |
| 370 | + done = False |
| 371 | + while not done: |
| 372 | + history = deepcopy(persistence.history) |
| 373 | + async with count_down_graph.iter_from_persistence(persistence) as run: |
| 374 | + result = await run.next() |
| 375 | + done = isinstance(result, End) |
| 376 | + |
| 377 | + for i in range(len(history)): |
| 378 | + assert history[i].id == persistence.history[i].id |
| 379 | + assert history[i].state == persistence.history[i].state, 'State should not change' |
| 380 | + |
| 381 | + |
| 382 | +async def test_simple_state_persistence_snapshot_state_stability(): |
| 383 | + @dataclass |
| 384 | + class CountDownState: |
| 385 | + counter: int |
| 386 | + |
| 387 | + @dataclass |
| 388 | + class CountDown(BaseNode[CountDownState, None, int]): |
| 389 | + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: |
| 390 | + ctx.state.counter -= 1 |
| 391 | + return CountDown() |
| 392 | + |
| 393 | + persistence = SimpleStatePersistence() |
| 394 | + state = CountDownState(counter=3) |
| 395 | + count_down_graph = Graph(nodes=[CountDown]) |
| 396 | + |
| 397 | + await count_down_graph.initialize(CountDown(), state=state, persistence=persistence) |
| 398 | + |
| 399 | + last_snapshot = persistence.last_snapshot |
| 400 | + async with count_down_graph.iter_from_persistence(persistence) as run: |
| 401 | + await run.next() |
| 402 | + |
| 403 | + assert last_snapshot and last_snapshot.state.counter == 3 |
| 404 | + assert persistence.last_snapshot and persistence.last_snapshot.state.counter == 2 |
0 commit comments