from __future__ import annotations

import json
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar

from pydantic import TypeAdapter

from pydantic_graph import End
from pydantic_graph.exceptions import GraphRuntimeError
from pydantic_graph.nodes import BaseNode
from pydantic_graph.persistence import (
    BaseStatePersistence,
    EndSnapshot,
    NodeSnapshot,
    Snapshot,
    build_snapshot_list_type_adapter,
)

StateT = TypeVar('StateT')
RunEndT = TypeVar('RunEndT')

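# The queries below assume two tables that are not created by this module; the
# authoritative DDL lives in the application's migrations. A minimal sketch of
# what the statements in this class expect (column names are inferred from the
# queries, the types are assumptions):
#
#   CREATE TABLE graph_runs (
#       id          text PRIMARY KEY,
#       status      text,
#       result      jsonb,
#       finished_at timestamptz
#   );
#
#   CREATE TABLE graph_snapshots (
#       run_id        text REFERENCES graph_runs (id),
#       snapshot_id   text,
#       kind          text,  -- 'node' | 'end'
#       status        text,  -- 'queued' | 'pending' | 'running' | 'success' | 'error'
#       node_id       text,
#       snapshot      jsonb,
#       created_at    timestamptz DEFAULT now(),
#       started_at    timestamptz,
#       finished_at   timestamptz,
#       duration_secs double precision,
#       PRIMARY KEY (run_id, snapshot_id)
#   );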
class PostgresStatePersistence(
    BaseStatePersistence[StateT, RunEndT],
    Generic[StateT, RunEndT],
):
    """PostgreSQL-backed implementation of state persistence for graph runs.

    Stores full snapshots (NodeSnapshot / EndSnapshot) as JSONB and tracks status and timings.
    """

    def __init__(self, pool: Any, run_id: str) -> None:
        self.pool = pool
        self.run_id = run_id
        self._snapshot_adapter: TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = None

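    # should_set_types()/set_types() let the snapshot adapter be built lazily:
    # the graph runner is expected to call set_types() with the concrete state
    # and run-end types before the first snapshot is taken, so they do not need
    # to be known at construction time.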
    def should_set_types(self) -> bool:
        return self._snapshot_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshot_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    def _dump_snapshot(self, snapshot: Snapshot[StateT, RunEndT]) -> dict[str, Any]:
        """Encode a single snapshot to a JSON-serializable dict using the list adapter."""
        assert self._snapshot_adapter is not None, 'Persistence types not set'
        return self._snapshot_adapter.dump_python([snapshot], mode='json')[0]

    def _load_snapshot(self, data: dict[str, Any] | str) -> Snapshot[StateT, RunEndT]:
        """Decode a single snapshot from a JSON-compatible dict or raw JSON string."""
        assert self._snapshot_adapter is not None, 'Persistence types not set'

        # asyncpg returns json/jsonb columns as str unless a type codec is
        # registered on the connection, so accept both forms here.
        if isinstance(data, str):
            data = json.loads(data)

        return self._snapshot_adapter.validate_python([data])[0]

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Snapshot a node when it is scheduled by the graph."""
        snapshot = NodeSnapshot(state=state, node=next_node)
        payload = self._dump_snapshot(snapshot)
        node_id = next_node.get_node_def(local_ns=None).node_id

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO graph_snapshots (
                    run_id, snapshot_id, kind, status, node_id, snapshot
                )
                VALUES ($1, $2, 'node', 'queued', $3, $4::jsonb)
                ON CONFLICT (run_id, snapshot_id) DO UPDATE
                SET snapshot = EXCLUDED.snapshot,
                    node_id = EXCLUDED.node_id,
                    status = EXCLUDED.status
                """,
                self.run_id,
                snapshot.id,
                node_id,
                json.dumps(payload),
            )

    async def snapshot_node_if_new(
        self,
        snapshot_id: str,
        state: StateT,
        next_node: BaseNode[StateT, Any, RunEndT],
    ) -> None:
        """Snapshot a node only if the given snapshot_id does not already exist."""
        snapshot = NodeSnapshot(state=state, node=next_node, id=snapshot_id)
        payload = self._dump_snapshot(snapshot)
        node_id = next_node.get_node_def(local_ns=None).node_id

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO graph_snapshots (
                    run_id, snapshot_id, kind, status, node_id, snapshot
                )
                VALUES ($1, $2, 'node', 'queued', $3, $4::jsonb)
                ON CONFLICT (run_id, snapshot_id) DO NOTHING
                """,
                self.run_id,
                snapshot_id,
                node_id,
                json.dumps(payload),
            )

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Snapshot the graph end state and update the run result."""
        snapshot = EndSnapshot(state=state, result=end)
        payload = self._dump_snapshot(snapshot)

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO graph_snapshots (
                    run_id, snapshot_id, kind, status, node_id, snapshot
                )
                VALUES ($1, $2, 'end', NULL, 'End', $3::jsonb)
                ON CONFLICT (run_id, snapshot_id) DO UPDATE
                SET snapshot = EXCLUDED.snapshot
                """,
                self.run_id,
                snapshot.id,
                json.dumps(payload),
            )

            await conn.execute(
                """
                UPDATE graph_runs
                SET finished_at = now(),
                    status = 'success',
                    result = $2::jsonb
                WHERE id = $1
                """,
                self.run_id,
                json.dumps(payload),
            )

    @asynccontextmanager
    async def record_run(self, snapshot_id: str):
        """Record execution status and timing for a single node run.

        Called by the graph around the actual execution of a node.

        We:
        - check the node is not already running or finished
        - mark it as running
        - on success, record the final status and timing for the snapshot
        - on error, mark both the snapshot and the run as errored, then re-raise
        """
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT status
                FROM graph_snapshots
                WHERE run_id = $1
                  AND snapshot_id = $2
                  AND kind = 'node'
                """,
                self.run_id,
                snapshot_id,
            )
            if row is None:
                raise LookupError(f'Snapshot {snapshot_id!r} not found for run {self.run_id}')

            current_status = row['status']
            if current_status not in ('queued', 'pending'):
                raise GraphRuntimeError(
                    f'Snapshot {snapshot_id!r} already in status {current_status!r}',
                )

            now = datetime.now(timezone.utc)
            await conn.execute(
                """
                UPDATE graph_snapshots
                SET status = 'running',
                    started_at = $3
                WHERE run_id = $1 AND snapshot_id = $2
                """,
                self.run_id,
                snapshot_id,
                now,
            )

        start = datetime.now(timezone.utc)

        try:
            yield
        except Exception:
            duration = (datetime.now(timezone.utc) - start).total_seconds()
            async with self.pool.acquire() as conn:
                await conn.execute(
                    """
                    UPDATE graph_snapshots
                    SET status = 'error',
                        finished_at = now(),
                        duration_secs = $3
                    WHERE run_id = $1 AND snapshot_id = $2
                    """,
                    self.run_id,
                    snapshot_id,
                    duration,
                )
                await conn.execute(
                    """
                    UPDATE graph_runs
                    SET status = 'error',
                        finished_at = COALESCE(finished_at, now())
                    WHERE id = $1
                    """,
                    self.run_id,
                )
            raise
        else:
            duration = (datetime.now(timezone.utc) - start).total_seconds()
            async with self.pool.acquire() as conn:
                await conn.execute(
                    """
                    UPDATE graph_snapshots
                    SET status = 'success',
                        finished_at = now(),
                        duration_secs = $3
                    WHERE run_id = $1 AND snapshot_id = $2
                    """,
                    self.run_id,
                    snapshot_id,
                    duration,
                )

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Pop the next queued or pending node snapshot for this run."""
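        # The row is locked with FOR UPDATE SKIP LOCKED so concurrent pollers in
        # their own transactions skip it rather than block on it, then its status
        # is switched to 'pending' before the transaction commits; record_run()
        # later moves it from 'queued'/'pending' to 'running'.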
        async with self.pool.acquire() as conn:
            async with conn.transaction():
                row = await conn.fetchrow(
                    """
                    SELECT snapshot_id, snapshot
                    FROM graph_snapshots
                    WHERE run_id = $1
                      AND kind = 'node'
                      AND status IN ('queued', 'pending')
                    ORDER BY created_at
                    FOR UPDATE SKIP LOCKED
                    LIMIT 1
                    """,
                    self.run_id,
                )
                if row is None:
                    return None

                snapshot_id = row['snapshot_id']
                await conn.execute(
                    """
                    UPDATE graph_snapshots
                    SET status = 'pending'
                    WHERE run_id = $1
                      AND snapshot_id = $2
                    """,
                    self.run_id,
                    snapshot_id,
                )

            snapshot = self._load_snapshot(row['snapshot'])
            if not isinstance(snapshot, NodeSnapshot):
                raise TypeError(f'Expected NodeSnapshot, got {type(snapshot)}')
            return snapshot

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Load all snapshots for this run in creation order."""
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT snapshot
                FROM graph_snapshots
                WHERE run_id = $1
                ORDER BY created_at
                """,
                self.run_id,
            )

        raw_payloads = [r['snapshot'] for r in rows]
        payloads = [json.loads(p) if isinstance(p, str) else p for p in raw_payloads]
        assert self._snapshot_adapter is not None, 'Persistence types not set'
        return self._snapshot_adapter.validate_python(payloads)
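

# Usage sketch (not executed here): wiring this persistence into a graph run.
# Assumes an asyncpg pool and an existing pydantic_graph Graph; `my_graph`,
# `StartNode` and `MyState` are placeholders, and the matching `graph_runs` row
# is assumed to be inserted by the caller before the run starts.
#
#   import asyncpg
#
#   async def run_with_postgres(dsn: str, run_id: str) -> None:
#       pool = await asyncpg.create_pool(dsn)
#       persistence = PostgresStatePersistence(pool, run_id)
#       result = await my_graph.run(
#           StartNode(),
#           state=MyState(),
#           persistence=persistence,
#       )
#       print(result.output)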