Skip to content

Commit 8fa1521

Browse files
committed
linter
Signed-off-by: yaron2 <[email protected]>
1 parent b62991d commit 8fa1521

File tree

3 files changed

+32
-33
lines changed

3 files changed

+32
-33
lines changed

dapr-ext-langgraph/dapr/ext/langgraph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818

1919
__all__ = [
2020
'DaprCheckpointer',
21-
]
21+
]

dapr-ext-langgraph/dapr/ext/langgraph/dapr_checkpointer.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import Sequence, Tuple, Any
21
import json
32
from typing import Any, Sequence, Tuple
3+
4+
from langchain_core.load import dumps
5+
from langchain_core.runnables import RunnableConfig
6+
47
from dapr.clients import DaprClient
58
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointTuple
6-
from langchain_core.runnables import RunnableConfig
7-
from langchain_core.load import dumps
89

910

1011
class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]):
@@ -13,7 +14,7 @@ class DaprCheckpointer(BaseCheckpointSaver[Checkpoint]):
1314
Compatible with LangGraph >= 0.3.6 and LangChain Core >= 1.0.0.
1415
"""
1516

16-
REGISTRY_KEY = "dapr_checkpoint_registry"
17+
REGISTRY_KEY = 'dapr_checkpoint_registry'
1718

1819
def __init__(self, store_name: str, key_prefix: str):
1920
self.store_name = store_name
@@ -23,12 +24,12 @@ def __init__(self, store_name: str, key_prefix: str):
2324
def _get_key(self, config: RunnableConfig) -> str:
2425
thread_id = None
2526
if isinstance(config, dict):
26-
thread_id = config.get("configurable", {}).get("thread_id")
27+
thread_id = config.get('configurable', {}).get('thread_id')
2728
if not thread_id:
28-
thread_id = config.get("thread_id")
29+
thread_id = config.get('thread_id')
2930
if not thread_id:
30-
thread_id = "default"
31-
return f"{self.key_prefix}:{thread_id}"
31+
thread_id = 'default'
32+
return f'{self.key_prefix}:{thread_id}'
3233

3334
# restore a checkpoint
3435
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
@@ -39,10 +40,10 @@ def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
3940
return None
4041

4142
wrapper = json.loads(resp.data)
42-
cp_data = wrapper.get("checkpoint", wrapper)
43-
metadata = wrapper.get("metadata", {"step": 0})
44-
if "step" not in metadata:
45-
metadata["step"] = 0
43+
cp_data = wrapper.get('checkpoint', wrapper)
44+
metadata = wrapper.get('metadata', {'step': 0})
45+
if 'step' not in metadata:
46+
metadata['step'] = 0
4647

4748
cp = Checkpoint(**cp_data)
4849
return CheckpointTuple(
@@ -63,34 +64,32 @@ def put(
6364
key = self._get_key(config)
6465
with DaprClient() as client:
6566
checkpoint_serializable = {
66-
"v": checkpoint["v"],
67-
"id": checkpoint["id"],
68-
"ts": checkpoint["ts"],
69-
"channel_values": checkpoint["channel_values"],
70-
"channel_versions": checkpoint["channel_versions"],
71-
"versions_seen": checkpoint["versions_seen"],
67+
'v': checkpoint['v'],
68+
'id': checkpoint['id'],
69+
'ts': checkpoint['ts'],
70+
'channel_values': checkpoint['channel_values'],
71+
'channel_versions': checkpoint['channel_versions'],
72+
'versions_seen': checkpoint['versions_seen'],
7273
}
73-
wrapper = {"checkpoint": checkpoint_serializable, "metadata": metadata}
74+
wrapper = {'checkpoint': checkpoint_serializable, 'metadata': metadata}
7475

7576
# Save checkpoint to Dapr
7677
client.save_state(self.store_name, key, dumps(wrapper))
77-
78+
7879
# Maintain registry of all checkpoint keys
7980
reg_resp = client.get_state(store_name=self.store_name, key=self.REGISTRY_KEY)
8081
registry = json.loads(reg_resp.data) if reg_resp.data else []
8182
if key not in registry:
8283
registry.append(key)
83-
client.save_state(
84-
self.store_name, self.REGISTRY_KEY, json.dumps(registry)
85-
)
84+
client.save_state(self.store_name, self.REGISTRY_KEY, json.dumps(registry))
8685

8786
# incremental persistence (for streamed runs)
8887
def put_writes(
8988
self,
9089
config: RunnableConfig,
9190
writes: Sequence[Tuple[str, Any]],
9291
task_id: str,
93-
task_path: str = "",
92+
task_path: str = '',
9493
) -> None:
9594
"""Persist incremental updates for streaming or async workflows."""
9695
key = self._get_key(config)
@@ -100,12 +99,12 @@ def put_writes(
10099
return
101100

102101
wrapper = json.loads(resp.data)
103-
cp = wrapper.get("checkpoint", {})
102+
cp = wrapper.get('checkpoint', {})
104103

105104
for field, value in writes:
106-
cp["channel_values"][field] = value
105+
cp['channel_values'][field] = value
107106

108-
wrapper["checkpoint"] = cp
107+
wrapper['checkpoint'] = cp
109108
client.save_state(self.store_name, key, json.dumps(wrapper))
110109

111110
# enumerate all saved checkpoints
@@ -124,8 +123,8 @@ def list(self, config: RunnableConfig) -> list[CheckpointTuple]:
124123
continue
125124

126125
wrapper = json.loads(cp_resp.data)
127-
cp_data = wrapper.get("checkpoint", {})
128-
metadata = wrapper.get("metadata", {})
126+
cp_data = wrapper.get('checkpoint', {})
127+
metadata = wrapper.get('metadata', {})
129128
cp = Checkpoint(**cp_data)
130129

131130
checkpoints.append(
@@ -139,7 +138,6 @@ def list(self, config: RunnableConfig) -> list[CheckpointTuple]:
139138

140139
return checkpoints
141140

142-
143141
# remove a checkpoint and update the registry
144142
def delete_thread(self, config: RunnableConfig) -> None:
145143
key = self._get_key(config)
@@ -158,4 +156,3 @@ def delete_thread(self, config: RunnableConfig) -> None:
158156
key=self.REGISTRY_KEY,
159157
value=json.dumps(registry),
160158
)
161-

dapr-ext-langgraph/setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def is_release():
5151
name += '-dev'
5252
version = f'{__version__}{build_number}'
5353
description = 'The developmental release for the Dapr Checkpointer extension for LangGraph'
54-
long_description = 'This is the developmental release for the Dapr Checkpointer extension for LangGraph'
54+
long_description = (
55+
'This is the developmental release for the Dapr Checkpointer extension for LangGraph'
56+
)
5557

5658
print(f'package name: {name}, version: {version}', flush=True)
5759

0 commit comments

Comments
 (0)