
Commit c29d5be

add postgres persistence support

1 parent d578b42

1 file changed: 284 additions, 0 deletions
@@ -0,0 +1,284 @@
from __future__ import annotations

import json
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar

from pydantic import TypeAdapter

from pydantic_graph import End
from pydantic_graph.exceptions import GraphRuntimeError
from pydantic_graph.nodes import BaseNode
from pydantic_graph.persistence import (
    BaseStatePersistence,
    EndSnapshot,
    NodeSnapshot,
    Snapshot,
    build_snapshot_list_type_adapter,
)

StateT = TypeVar('StateT')
RunEndT = TypeVar('RunEndT')


class PostgresStatePersistence(
    BaseStatePersistence[StateT, RunEndT],
    Generic[StateT, RunEndT],
):
    """PostgreSQL-backed implementation of state persistence for graph runs.

    Stores full snapshots (NodeSnapshot / EndSnapshot) as JSONB and tracks status and timings.
    """

    def __init__(self, pool: Any, run_id: str) -> None:
        self.pool = pool
        self.run_id = run_id
        self._snapshot_adapter: TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = None

    def should_set_types(self) -> bool:
        return self._snapshot_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshot_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    def _dump_snapshot(self, snapshot: Snapshot[StateT, RunEndT]) -> dict[str, Any]:
        """Encode a single snapshot to a JSON-serializable dict using the list adapter."""
        assert self._snapshot_adapter is not None, 'Persistence types not set'
        return self._snapshot_adapter.dump_python([snapshot], mode='json')[0]

    def _load_snapshot(self, data: dict[str, Any]) -> Snapshot[StateT, RunEndT]:
        """Decode a single snapshot from a JSON-compatible dictionary."""
        assert self._snapshot_adapter is not None, 'Persistence types not set'

        if isinstance(data, str):
            data = json.loads(data)

        return self._snapshot_adapter.validate_python([data])[0]

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Snapshot a node when it is scheduled by the graph."""
        snapshot = NodeSnapshot(state=state, node=next_node)
        payload = self._dump_snapshot(snapshot)
        node_id = next_node.get_node_def(local_ns=None).node_id

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO graph_snapshots (
                    run_id, snapshot_id, kind, status, node_id, snapshot
                )
                VALUES ($1, $2, 'node', 'queued', $3, $4::jsonb)
                ON CONFLICT (run_id, snapshot_id) DO UPDATE
                SET snapshot = EXCLUDED.snapshot,
                    node_id = EXCLUDED.node_id,
                    status = EXCLUDED.status
                """,
                self.run_id,
                snapshot.id,
                node_id,
                json.dumps(payload),
            )

    async def snapshot_node_if_new(
        self,
        snapshot_id: str,
        state: StateT,
        next_node: BaseNode[StateT, Any, RunEndT],
    ) -> None:
        """Snapshot a node only if the given snapshot_id does not already exist."""
        snapshot = NodeSnapshot(state=state, node=next_node, id=snapshot_id)
        payload = self._dump_snapshot(snapshot)
        node_id = next_node.get_node_def(local_ns=None).node_id

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO graph_snapshots (
                    run_id, snapshot_id, kind, status, node_id, snapshot
                )
                VALUES ($1, $2, 'node', 'queued', $3, $4::jsonb)
                ON CONFLICT (run_id, snapshot_id) DO NOTHING
                """,
                self.run_id,
                snapshot_id,
                node_id,
                json.dumps(payload),
            )

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Snapshot the graph end state and update the run result."""
        snapshot = EndSnapshot(state=state, result=end)
        payload = self._dump_snapshot(snapshot)

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO graph_snapshots (
                    run_id, snapshot_id, kind, status, node_id, snapshot
                )
                VALUES ($1, $2, 'end', NULL, 'End', $3::jsonb)
                ON CONFLICT (run_id, snapshot_id) DO UPDATE
                SET snapshot = EXCLUDED.snapshot
                """,
                self.run_id,
                snapshot.id,
                json.dumps(payload),
            )

            await conn.execute(
                """
                UPDATE graph_runs
                SET finished_at = now(),
                    status = 'success',
                    result = $2::jsonb
                WHERE id = $1
                """,
                self.run_id,
                json.dumps(payload),
            )

    @asynccontextmanager
    async def record_run(self, snapshot_id: str):
        """Record execution status and timing for a single node run.

        Called by Graph around the actual execution of a node.

        We:
        - assert the node is not already running or finished
        - mark it as running
        - on success or error, update status and timing
        """
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT status
                FROM graph_snapshots
                WHERE run_id = $1
                  AND snapshot_id = $2
                  AND kind = 'node'
                """,
                self.run_id,
                snapshot_id,
            )
            if row is None:
                raise LookupError(f'Snapshot {snapshot_id!r} not found for run {self.run_id}')

            current_status = row['status']
            if current_status not in ('queued', 'pending'):
                raise GraphRuntimeError(
                    f'Snapshot {snapshot_id!r} already in status {current_status!r}',
                )

            now = datetime.now(timezone.utc)
            await conn.execute(
                """
                UPDATE graph_snapshots
                SET status = 'running',
                    started_at = $3
                WHERE run_id = $1 AND snapshot_id = $2
                """,
                self.run_id,
                snapshot_id,
                now,
            )

        start = datetime.now(timezone.utc)

        try:
            yield
        except Exception:
            duration = (datetime.now(timezone.utc) - start).total_seconds()
            async with self.pool.acquire() as conn:
                await conn.execute(
                    """
                    UPDATE graph_snapshots
                    SET status = 'error',
                        finished_at = now(),
                        duration_secs = $3
                    WHERE run_id = $1 AND snapshot_id = $2
                    """,
                    self.run_id,
                    snapshot_id,
                    duration,
                )
                await conn.execute(
                    """
                    UPDATE graph_runs
                    SET status = 'error',
                        finished_at = COALESCE(finished_at, now())
                    WHERE id = $1
                    """,
                    self.run_id,
                )
            raise
        else:
            duration = (datetime.now(timezone.utc) - start).total_seconds()
            async with self.pool.acquire() as conn:
                await conn.execute(
                    """
                    UPDATE graph_snapshots
                    SET status = 'success',
                        finished_at = now(),
                        duration_secs = $3
                    WHERE run_id = $1 AND snapshot_id = $2
                    """,
                    self.run_id,
                    snapshot_id,
                    duration,
                )

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Pop the next queued or pending node snapshot for this run."""
        async with self.pool.acquire() as conn:
            async with conn.transaction():
                row = await conn.fetchrow(
                    """
                    SELECT snapshot_id, snapshot
                    FROM graph_snapshots
                    WHERE run_id = $1
                      AND kind = 'node'
                      AND status IN ('queued', 'pending')
                    ORDER BY created_at
                    FOR UPDATE SKIP LOCKED
                    LIMIT 1
                    """,
                    self.run_id,
                )
                if row is None:
                    return None

                snapshot_id = row['snapshot_id']
                await conn.execute(
                    """
                    UPDATE graph_snapshots
                    SET status = 'pending'
                    WHERE run_id = $1
                      AND snapshot_id = $2
                    """,
                    self.run_id,
                    snapshot_id,
                )

        snapshot = self._load_snapshot(row['snapshot'])
        if not isinstance(snapshot, NodeSnapshot):
            raise TypeError(f'Expected NodeSnapshot, got {type(snapshot)}')
        return snapshot

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Load all snapshots for this run in creation order."""
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT snapshot
                FROM graph_snapshots
                WHERE run_id = $1
                ORDER BY created_at
                """,
                self.run_id,
            )

        raw_payloads = [r['snapshot'] for r in rows]
        payloads = [json.loads(p) if isinstance(p, str) else p for p in raw_payloads]
        assert self._snapshot_adapter is not None, 'Persistence types not set'
        return self._snapshot_adapter.validate_python(payloads)
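
The queries above reference two tables, graph_runs and graph_snapshots, whose DDL is not included in this commit. The following is a minimal sketch of a compatible schema: the column names are inferred from the queries, while the types, defaults, and constraints are assumptions.

import asyncpg

# Assumed DDL, inferred from the columns referenced by PostgresStatePersistence.
# Not part of this commit; adjust types and constraints as needed.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS graph_runs (
    id          text PRIMARY KEY,
    status      text,
    result      jsonb,
    started_at  timestamptz NOT NULL DEFAULT now(),
    finished_at timestamptz
);

CREATE TABLE IF NOT EXISTS graph_snapshots (
    run_id        text NOT NULL REFERENCES graph_runs (id),
    snapshot_id   text NOT NULL,
    kind          text NOT NULL,  -- 'node' or 'end'
    status        text,           -- 'queued' | 'pending' | 'running' | 'success' | 'error'
    node_id       text,
    snapshot      jsonb NOT NULL,
    created_at    timestamptz NOT NULL DEFAULT now(),
    started_at    timestamptz,
    finished_at   timestamptz,
    duration_secs double precision,
    PRIMARY KEY (run_id, snapshot_id)
);
"""


async def create_schema(pool: asyncpg.Pool) -> None:
    """Create the assumed tables if they do not already exist."""
    async with pool.acquire() as conn:
        await conn.execute(SCHEMA_SQL)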
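
For orientation, a hypothetical way to wire the class up with an asyncpg pool and an existing pydantic_graph graph is sketched below. The DSN, my_graph, StartNode, and MyState are placeholders, and the persistence= keyword of Graph.run is assumed to match the installed pydantic_graph version; this is not part of the commit.

import uuid

import asyncpg


async def run_with_postgres_persistence() -> None:
    # Placeholder DSN; point this at a database containing the tables sketched above.
    pool = await asyncpg.create_pool('postgresql://localhost/graphdb')
    run_id = str(uuid.uuid4())

    # Register the run first so the `UPDATE graph_runs ... WHERE id = $1`
    # statements in PostgresStatePersistence have a row to update.
    async with pool.acquire() as conn:
        await conn.execute(
            "INSERT INTO graph_runs (id, status) VALUES ($1, 'running')",
            run_id,
        )

    persistence = PostgresStatePersistence(pool, run_id)
    # `my_graph`, `StartNode`, and `MyState` stand in for an existing
    # pydantic_graph Graph, its start node, and its state type.
    result = await my_graph.run(StartNode(), state=MyState(), persistence=persistence)
    print(result)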
