@@ -19,39 +19,44 @@ class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]):
1919 def __init__ (self , store_name : str , key_prefix : str ):
2020 self .store_name = store_name
2121 self .key_prefix = key_prefix
22+ self .client = DaprClient ()
2223
2324 # helper: construct Dapr key for a thread
2425 def _get_key (self , config : RunnableConfig ) -> str :
2526 thread_id = None
27+
2628 if isinstance (config , dict ):
27- thread_id = config .get ('configurable' , {}).get ('thread_id' )
29+ thread_id = config .get ("configurable" , {}).get ("thread_id" )
30+
2831 if not thread_id :
29- thread_id = config .get ('thread_id' )
32+ thread_id = config .get ("thread_id" )
33+
3034 if not thread_id :
31- thread_id = 'default'
32- return f'{ self .key_prefix } :{ thread_id } '
35+ thread_id = "default"
36+
37+ return f"{ self .key_prefix } :{ thread_id } "
3338
3439 # restore a checkpoint
3540 def get_tuple (self , config : RunnableConfig ) -> CheckpointTuple | None :
3641 key = self ._get_key (config )
37- with DaprClient () as client :
38- resp = client .get_state (store_name = self .store_name , key = key )
39- if not resp .data :
40- return None
4142
42- wrapper = json .loads (resp .data )
43- cp_data = wrapper .get ('checkpoint' , wrapper )
44- metadata = wrapper .get ('metadata' , {'step' : 0 })
45- if 'step' not in metadata :
46- metadata ['step' ] = 0
43+ resp = self .client .get_state (store_name = self .store_name , key = key )
44+ if not resp .data :
45+ return None
4746
48- cp = Checkpoint (** cp_data )
49- return CheckpointTuple (
50- config = config ,
51- checkpoint = cp ,
52- parent_config = None ,
53- metadata = metadata ,
54- )
47+ wrapper = json .loads (resp .data )
48+ cp_data = wrapper .get ('checkpoint' , wrapper )
49+ metadata = wrapper .get ('metadata' , {'step' : 0 })
50+ if 'step' not in metadata :
51+ metadata ['step' ] = 0
52+
53+ cp = Checkpoint (** cp_data )
54+ return CheckpointTuple (
55+ config = config ,
56+ checkpoint = cp ,
57+ parent_config = None ,
58+ metadata = metadata ,
59+ )
5560
5661 # save a full checkpoint snapshot
5762 def put (
@@ -61,27 +66,30 @@ def put(
6166 parent_config : RunnableConfig | None ,
6267 metadata : dict [str , Any ],
6368 ) -> None :
69+
6470 key = self ._get_key (config )
65- with DaprClient () as client :
66- checkpoint_serializable = {
67- 'v' : checkpoint ['v' ],
68- 'id' : checkpoint ['id' ],
69- 'ts' : checkpoint ['ts' ],
70- 'channel_values' : checkpoint ['channel_values' ],
71- 'channel_versions' : checkpoint ['channel_versions' ],
72- 'versions_seen' : checkpoint ['versions_seen' ],
73- }
74- wrapper = {'checkpoint' : checkpoint_serializable , 'metadata' : metadata }
75-
76- # Save checkpoint to Dapr
77- client .save_state (self .store_name , key , dumps (wrapper ))
78-
79- # Maintain registry of all checkpoint keys
80- reg_resp = client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
81- registry = json .loads (reg_resp .data ) if reg_resp .data else []
82- if key not in registry :
83- registry .append (key )
84- client .save_state (self .store_name , self .REGISTRY_KEY , json .dumps (registry ))
71+
72+ checkpoint_serializable = {
73+ 'v' : checkpoint ['v' ],
74+ 'id' : checkpoint ['id' ],
75+ 'ts' : checkpoint ['ts' ],
76+ 'channel_values' : checkpoint ['channel_values' ],
77+ 'channel_versions' : checkpoint ['channel_versions' ],
78+ 'versions_seen' : checkpoint ['versions_seen' ],
79+ }
80+
81+ wrapper = {'checkpoint' : checkpoint_serializable , 'metadata' : metadata }
82+
83+ self .client .save_state (self .store_name , key , dumps (wrapper ))
84+
85+ reg_resp = self .client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
86+ registry = json .loads (reg_resp .data ) if reg_resp .data else []
87+
88+ if key not in registry :
89+ registry .append (key )
90+ self .client .save_state (
91+ self .store_name , self .REGISTRY_KEY , json .dumps (registry )
92+ )
8593
8694 # incremental persistence (for streamed runs)
8795 def put_writes (
@@ -91,68 +99,70 @@ def put_writes(
9199 task_id : str ,
92100 task_path : str = '' ,
93101 ) -> None :
94- """Persist incremental updates for streaming or async workflows."""
102+
103+ _ = task_id , task_path
104+
95105 key = self ._get_key (config )
96- with DaprClient () as client :
97- resp = client .get_state (store_name = self .store_name , key = key )
98- if not resp .data :
99- return
100106
101- wrapper = json .loads (resp .data )
102- cp = wrapper .get ('checkpoint' , {})
107+ resp = self .client .get_state (store_name = self .store_name , key = key )
108+ if not resp .data :
109+ return
110+
111+ wrapper = json .loads (resp .data )
112+ cp = wrapper .get ('checkpoint' , {})
103113
104- for field , value in writes :
105- cp ['channel_values' ][field ] = value
114+ for field , value in writes :
115+ cp ['channel_values' ][field ] = value
106116
107- wrapper ['checkpoint' ] = cp
108- client .save_state (self .store_name , key , json .dumps (wrapper ))
117+ wrapper ['checkpoint' ] = cp
118+ self . client .save_state (self .store_name , key , json .dumps (wrapper ))
109119
110120 # enumerate all saved checkpoints
111121 def list (self , config : RunnableConfig ) -> list [CheckpointTuple ]:
112- with DaprClient () as client :
113- reg_resp = client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
114- if not reg_resp .data :
115- return []
116-
117- keys = json .loads (reg_resp .data )
118- checkpoints : list [CheckpointTuple ] = []
119-
120- for key in keys :
121- cp_resp = client .get_state (store_name = self .store_name , key = key )
122- if not cp_resp .data :
123- continue
124-
125- wrapper = json .loads (cp_resp .data )
126- cp_data = wrapper .get ('checkpoint' , {})
127- metadata = wrapper .get ('metadata' , {})
128- cp = Checkpoint (** cp_data )
129-
130- checkpoints .append (
131- CheckpointTuple (
132- config = config ,
133- checkpoint = cp ,
134- parent_config = None ,
135- metadata = metadata ,
136- )
122+ reg_resp = self .client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
123+ if not reg_resp .data :
124+ return []
125+
126+ keys = json .loads (reg_resp .data )
127+ checkpoints : list [CheckpointTuple ] = []
128+
129+ for key in keys :
130+ cp_resp = self .client .get_state (store_name = self .store_name , key = key )
131+ if not cp_resp .data :
132+ continue
133+
134+ wrapper = json .loads (cp_resp .data )
135+ cp_data = wrapper .get ('checkpoint' , {})
136+ metadata = wrapper .get ('metadata' , {})
137+ cp = Checkpoint (** cp_data )
138+
139+ checkpoints .append (
140+ CheckpointTuple (
141+ config = config ,
142+ checkpoint = cp ,
143+ parent_config = None ,
144+ metadata = metadata ,
137145 )
146+ )
138147
139- return checkpoints
148+ return checkpoints
140149
141150 # remove a checkpoint and update the registry
142151 def delete_thread (self , config : RunnableConfig ) -> None :
143152 key = self ._get_key (config )
144- with DaprClient () as client :
145- client .delete_state (store_name = self .store_name , key = key )
146-
147- reg_resp = client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
148- if not reg_resp .data :
149- return
150-
151- registry = json .loads (reg_resp .data )
152- if key in registry :
153- registry .remove (key )
154- client .save_state (
155- store_name = self .store_name ,
156- key = self .REGISTRY_KEY ,
157- value = json .dumps (registry ),
158- )
153+
154+ self .client .delete_state (store_name = self .store_name , key = key )
155+
156+ reg_resp = self .client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
157+ if not reg_resp .data :
158+ return
159+
160+ registry = json .loads (reg_resp .data )
161+
162+ if key in registry :
163+ registry .remove (key )
164+ self .client .save_state (
165+ store_name = self .store_name ,
166+ key = self .REGISTRY_KEY ,
167+ value = json .dumps (registry ),
168+ )
0 commit comments