1- from typing import Sequence , Tuple , Any
21import json
32from typing import Any , Sequence , Tuple
3+
4+ from langchain_core .load import dumps
5+ from langchain_core .runnables import RunnableConfig
6+
47from dapr .clients import DaprClient
58from langgraph .checkpoint .base import BaseCheckpointSaver , Checkpoint , CheckpointTuple
6- from langchain_core .runnables import RunnableConfig
7- from langchain_core .load import dumps
89
910
1011class DaprCheckpointer (BaseCheckpointSaver [Checkpoint ]):
@@ -13,7 +14,7 @@ class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]):
1314 Compatible with LangGraph >= 0.3.6 and LangChain Core >= 1.0.0.
1415 """
1516
16- REGISTRY_KEY = " dapr_checkpoint_registry"
17+ REGISTRY_KEY = ' dapr_checkpoint_registry'
1718
1819 def __init__ (self , store_name : str , key_prefix : str ):
1920 self .store_name = store_name
@@ -23,12 +24,12 @@ def __init__(self, store_name: str, key_prefix: str):
2324 def _get_key (self , config : RunnableConfig ) -> str :
2425 thread_id = None
2526 if isinstance (config , dict ):
26- thread_id = config .get (" configurable" , {}).get (" thread_id" )
27+ thread_id = config .get (' configurable' , {}).get (' thread_id' )
2728 if not thread_id :
28- thread_id = config .get (" thread_id" )
29+ thread_id = config .get (' thread_id' )
2930 if not thread_id :
30- thread_id = " default"
31- return f" { self .key_prefix } :{ thread_id } "
31+ thread_id = ' default'
32+ return f' { self .key_prefix } :{ thread_id } '
3233
3334 # restore a checkpoint
3435 def get_tuple (self , config : RunnableConfig ) -> CheckpointTuple | None :
@@ -39,10 +40,10 @@ def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
3940 return None
4041
4142 wrapper = json .loads (resp .data )
42- cp_data = wrapper .get (" checkpoint" , wrapper )
43- metadata = wrapper .get (" metadata" , {" step" : 0 })
44- if " step" not in metadata :
45- metadata [" step" ] = 0
43+ cp_data = wrapper .get (' checkpoint' , wrapper )
44+ metadata = wrapper .get (' metadata' , {' step' : 0 })
45+ if ' step' not in metadata :
46+ metadata [' step' ] = 0
4647
4748 cp = Checkpoint (** cp_data )
4849 return CheckpointTuple (
@@ -63,34 +64,32 @@ def put(
6364 key = self ._get_key (config )
6465 with DaprClient () as client :
6566 checkpoint_serializable = {
66- "v" : checkpoint ["v" ],
67- "id" : checkpoint ["id" ],
68- "ts" : checkpoint ["ts" ],
69- " channel_values" : checkpoint [" channel_values" ],
70- " channel_versions" : checkpoint [" channel_versions" ],
71- " versions_seen" : checkpoint [" versions_seen" ],
67+ 'v' : checkpoint ['v' ],
68+ 'id' : checkpoint ['id' ],
69+ 'ts' : checkpoint ['ts' ],
70+ ' channel_values' : checkpoint [' channel_values' ],
71+ ' channel_versions' : checkpoint [' channel_versions' ],
72+ ' versions_seen' : checkpoint [' versions_seen' ],
7273 }
73- wrapper = {" checkpoint" : checkpoint_serializable , " metadata" : metadata }
74+ wrapper = {' checkpoint' : checkpoint_serializable , ' metadata' : metadata }
7475
7576 # Save checkpoint to Dapr
7677 client .save_state (self .store_name , key , dumps (wrapper ))
77-
78+
7879 # Maintain registry of all checkpoint keys
7980 reg_resp = client .get_state (store_name = self .store_name , key = self .REGISTRY_KEY )
8081 registry = json .loads (reg_resp .data ) if reg_resp .data else []
8182 if key not in registry :
8283 registry .append (key )
83- client .save_state (
84- self .store_name , self .REGISTRY_KEY , json .dumps (registry )
85- )
84+ client .save_state (self .store_name , self .REGISTRY_KEY , json .dumps (registry ))
8685
8786 # incremental persistence (for streamed runs)
8887 def put_writes (
8988 self ,
9089 config : RunnableConfig ,
9190 writes : Sequence [Tuple [str , Any ]],
9291 task_id : str ,
93- task_path : str = "" ,
92+ task_path : str = '' ,
9493 ) -> None :
9594 """Persist incremental updates for streaming or async workflows."""
9695 key = self ._get_key (config )
@@ -100,12 +99,12 @@ def put_writes(
10099 return
101100
102101 wrapper = json .loads (resp .data )
103- cp = wrapper .get (" checkpoint" , {})
102+ cp = wrapper .get (' checkpoint' , {})
104103
105104 for field , value in writes :
106- cp [" channel_values" ][field ] = value
105+ cp [' channel_values' ][field ] = value
107106
108- wrapper [" checkpoint" ] = cp
107+ wrapper [' checkpoint' ] = cp
109108 client .save_state (self .store_name , key , json .dumps (wrapper ))
110109
111110 # enumerate all saved checkpoints
@@ -124,8 +123,8 @@ def list(self, config: RunnableConfig) -> list[CheckpointTuple]:
124123 continue
125124
126125 wrapper = json .loads (cp_resp .data )
127- cp_data = wrapper .get (" checkpoint" , {})
128- metadata = wrapper .get (" metadata" , {})
126+ cp_data = wrapper .get (' checkpoint' , {})
127+ metadata = wrapper .get (' metadata' , {})
129128 cp = Checkpoint (** cp_data )
130129
131130 checkpoints .append (
@@ -139,7 +138,6 @@ def list(self, config: RunnableConfig) -> list[CheckpointTuple]:
139138
140139 return checkpoints
141140
142-
143141 # remove a checkpoint and update the registry
144142 def delete_thread (self , config : RunnableConfig ) -> None :
145143 key = self ._get_key (config )
@@ -158,4 +156,3 @@ def delete_thread(self, config: RunnableConfig) -> None:
158156 key = self .REGISTRY_KEY ,
159157 value = json .dumps (registry ),
160158 )
161-
0 commit comments