Skip to content

Commit 540c666

Browse files
committed
use single client, minor improvements
Signed-off-by: yaron2 <[email protected]>
1 parent f245f6a commit 540c666

File tree

1 file changed

+102
-92
lines changed

1 file changed

+102
-92
lines changed

dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py

Lines changed: 102 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -19,39 +19,44 @@ class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]):
1919
def __init__(self, store_name: str, key_prefix: str):
2020
self.store_name = store_name
2121
self.key_prefix = key_prefix
22+
self.client = DaprClient()
2223

2324
# helper: construct Dapr key for a thread
2425
def _get_key(self, config: RunnableConfig) -> str:
2526
thread_id = None
27+
2628
if isinstance(config, dict):
27-
thread_id = config.get('configurable', {}).get('thread_id')
29+
thread_id = config.get("configurable", {}).get("thread_id")
30+
2831
if not thread_id:
29-
thread_id = config.get('thread_id')
32+
thread_id = config.get("thread_id")
33+
3034
if not thread_id:
31-
thread_id = 'default'
32-
return f'{self.key_prefix}:{thread_id}'
35+
thread_id = "default"
36+
37+
return f"{self.key_prefix}:{thread_id}"
3338

3439
# restore a checkpoint
3540
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
3641
key = self._get_key(config)
37-
with DaprClient() as client:
38-
resp = client.get_state(store_name=self.store_name, key=key)
39-
if not resp.data:
40-
return None
4142

42-
wrapper = json.loads(resp.data)
43-
cp_data = wrapper.get('checkpoint', wrapper)
44-
metadata = wrapper.get('metadata', {'step': 0})
45-
if 'step' not in metadata:
46-
metadata['step'] = 0
43+
resp = self.client.get_state(store_name=self.store_name, key=key)
44+
if not resp.data:
45+
return None
4746

48-
cp = Checkpoint(**cp_data)
49-
return CheckpointTuple(
50-
config=config,
51-
checkpoint=cp,
52-
parent_config=None,
53-
metadata=metadata,
54-
)
47+
wrapper = json.loads(resp.data)
48+
cp_data = wrapper.get('checkpoint', wrapper)
49+
metadata = wrapper.get('metadata', {'step': 0})
50+
if 'step' not in metadata:
51+
metadata['step'] = 0
52+
53+
cp = Checkpoint(**cp_data)
54+
return CheckpointTuple(
55+
config=config,
56+
checkpoint=cp,
57+
parent_config=None,
58+
metadata=metadata,
59+
)
5560

5661
# save a full checkpoint snapshot
5762
def put(
@@ -61,27 +66,30 @@ def put(
6166
parent_config: RunnableConfig | None,
6267
metadata: dict[str, Any],
6368
) -> None:
69+
6470
key = self._get_key(config)
65-
with DaprClient() as client:
66-
checkpoint_serializable = {
67-
'v': checkpoint['v'],
68-
'id': checkpoint['id'],
69-
'ts': checkpoint['ts'],
70-
'channel_values': checkpoint['channel_values'],
71-
'channel_versions': checkpoint['channel_versions'],
72-
'versions_seen': checkpoint['versions_seen'],
73-
}
74-
wrapper = {'checkpoint': checkpoint_serializable, 'metadata': metadata}
75-
76-
# Save checkpoint to Dapr
77-
client.save_state(self.store_name, key, dumps(wrapper))
78-
79-
# Maintain registry of all checkpoint keys
80-
reg_resp = client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
81-
registry = json.loads(reg_resp.data) if reg_resp.data else []
82-
if key not in registry:
83-
registry.append(key)
84-
client.save_state(self.store_name, self.REGISTRY_KEY, json.dumps(registry))
71+
72+
checkpoint_serializable = {
73+
'v': checkpoint['v'],
74+
'id': checkpoint['id'],
75+
'ts': checkpoint['ts'],
76+
'channel_values': checkpoint['channel_values'],
77+
'channel_versions': checkpoint['channel_versions'],
78+
'versions_seen': checkpoint['versions_seen'],
79+
}
80+
81+
wrapper = {'checkpoint': checkpoint_serializable, 'metadata': metadata}
82+
83+
self.client.save_state(self.store_name, key, dumps(wrapper))
84+
85+
reg_resp = self.client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
86+
registry = json.loads(reg_resp.data) if reg_resp.data else []
87+
88+
if key not in registry:
89+
registry.append(key)
90+
self.client.save_state(
91+
self.store_name, self.REGISTRY_KEY, json.dumps(registry)
92+
)
8593

8694
# incremental persistence (for streamed runs)
8795
def put_writes(
@@ -91,68 +99,70 @@ def put_writes(
9199
task_id: str,
92100
task_path: str = '',
93101
) -> None:
94-
"""Persist incremental updates for streaming or async workflows."""
102+
103+
_ = task_id, task_path
104+
95105
key = self._get_key(config)
96-
with DaprClient() as client:
97-
resp = client.get_state(store_name=self.store_name, key=key)
98-
if not resp.data:
99-
return
100106

101-
wrapper = json.loads(resp.data)
102-
cp = wrapper.get('checkpoint', {})
107+
resp = self.client.get_state(store_name=self.store_name, key=key)
108+
if not resp.data:
109+
return
110+
111+
wrapper = json.loads(resp.data)
112+
cp = wrapper.get('checkpoint', {})
103113

104-
for field, value in writes:
105-
cp['channel_values'][field] = value
114+
for field, value in writes:
115+
cp['channel_values'][field] = value
106116

107-
wrapper['checkpoint'] = cp
108-
client.save_state(self.store_name, key, json.dumps(wrapper))
117+
wrapper['checkpoint'] = cp
118+
self.client.save_state(self.store_name, key, json.dumps(wrapper))
109119

110120
# enumerate all saved checkpoints
111121
def list(self, config: RunnableConfig) -> list[CheckpointTuple]:
112-
with DaprClient() as client:
113-
reg_resp = client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
114-
if not reg_resp.data:
115-
return []
116-
117-
keys = json.loads(reg_resp.data)
118-
checkpoints: list[CheckpointTuple] = []
119-
120-
for key in keys:
121-
cp_resp = client.get_state(store_name=self.store_name, key=key)
122-
if not cp_resp.data:
123-
continue
124-
125-
wrapper = json.loads(cp_resp.data)
126-
cp_data = wrapper.get('checkpoint', {})
127-
metadata = wrapper.get('metadata', {})
128-
cp = Checkpoint(**cp_data)
129-
130-
checkpoints.append(
131-
CheckpointTuple(
132-
config=config,
133-
checkpoint=cp,
134-
parent_config=None,
135-
metadata=metadata,
136-
)
122+
reg_resp = self.client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
123+
if not reg_resp.data:
124+
return []
125+
126+
keys = json.loads(reg_resp.data)
127+
checkpoints: list[CheckpointTuple] = []
128+
129+
for key in keys:
130+
cp_resp = self.client.get_state(store_name=self.store_name, key=key)
131+
if not cp_resp.data:
132+
continue
133+
134+
wrapper = json.loads(cp_resp.data)
135+
cp_data = wrapper.get('checkpoint', {})
136+
metadata = wrapper.get('metadata', {})
137+
cp = Checkpoint(**cp_data)
138+
139+
checkpoints.append(
140+
CheckpointTuple(
141+
config=config,
142+
checkpoint=cp,
143+
parent_config=None,
144+
metadata=metadata,
137145
)
146+
)
138147

139-
return checkpoints
148+
return checkpoints
140149

141150
# remove a checkpoint and update the registry
142151
def delete_thread(self, config: RunnableConfig) -> None:
143152
key = self._get_key(config)
144-
with DaprClient() as client:
145-
client.delete_state(store_name=self.store_name, key=key)
146-
147-
reg_resp = client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
148-
if not reg_resp.data:
149-
return
150-
151-
registry = json.loads(reg_resp.data)
152-
if key in registry:
153-
registry.remove(key)
154-
client.save_state(
155-
store_name=self.store_name,
156-
key=self.REGISTRY_KEY,
157-
value=json.dumps(registry),
158-
)
153+
154+
self.client.delete_state(store_name=self.store_name, key=key)
155+
156+
reg_resp = self.client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
157+
if not reg_resp.data:
158+
return
159+
160+
registry = json.loads(reg_resp.data)
161+
162+
if key in registry:
163+
registry.remove(key)
164+
self.client.save_state(
165+
store_name=self.store_name,
166+
key=self.REGISTRY_KEY,
167+
value=json.dumps(registry),
168+
)

0 commit comments

Comments
 (0)