Skip to content

Commit 06c2401

Browse files
committed
add tests
Signed-off-by: yaron2 <[email protected]>
1 parent 5ffee0f commit 06c2401

File tree

3 files changed

+171
-0
lines changed

3 files changed

+171
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ pip3 install -e .
8686
pip3 install -e ./ext/dapr-ext-grpc/
8787
pip3 install -e ./ext/dapr-ext-fastapi/
8888
pip3 install -e ./ext/dapr-ext-workflow/
89+
pip3 install -e ./ext/dapr-ext-langgraph/
8990
```
9091

9192
3. Install required packages
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Copyright 2025 The Dapr Authors
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import unittest
4+
from unittest import mock
5+
import json
6+
from datetime import datetime
7+
8+
from langgraph.checkpoint.base import Checkpoint
9+
from dapr.ext.langgraph.dapr_checkpointer import DaprCheckpointer
10+
11+
12+
@mock.patch("dapr.ext.langgraph.dapr_checkpointer.DaprClient")
13+
class DaprCheckpointerTest(unittest.TestCase):
14+
15+
def setUp(self):
16+
self.store = "statestore"
17+
self.prefix = "lg"
18+
self.config = {"configurable": {"thread_id": "t1"}}
19+
20+
self.checkpoint = Checkpoint(
21+
v=1,
22+
id="cp1",
23+
ts=datetime.now().timestamp(),
24+
channel_values={"a": 1},
25+
channel_versions={},
26+
versions_seen={},
27+
)
28+
29+
30+
def test_get_tuple_returns_checkpoint(self, mock_client_cls):
31+
mock_client = mock_client_cls.return_value
32+
33+
wrapper = {
34+
"checkpoint": {
35+
"v": self.checkpoint["v"],
36+
"id": self.checkpoint["id"],
37+
"ts": self.checkpoint["ts"],
38+
"channel_values": self.checkpoint["channel_values"],
39+
"channel_versions": self.checkpoint["channel_versions"],
40+
"versions_seen": self.checkpoint["versions_seen"],
41+
},
42+
"metadata": {"step": 3},
43+
}
44+
mock_client.get_state.return_value.data = json.dumps(wrapper)
45+
46+
cp = DaprCheckpointer(self.store, self.prefix)
47+
tup = cp.get_tuple(self.config)
48+
49+
assert tup is not None
50+
assert tup.checkpoint["id"] == "cp1"
51+
assert tup.metadata["step"] == 3
52+
53+
def test_get_tuple_none_when_missing(self, mock_client_cls):
54+
mock_client = mock_client_cls.return_value
55+
mock_client.get_state.return_value.data = None
56+
57+
cp = DaprCheckpointer(self.store, self.prefix)
58+
assert cp.get_tuple(self.config) is None
59+
60+
61+
def test_put_saves_checkpoint_and_registry(self, mock_client_cls):
62+
mock_client = mock_client_cls.return_value
63+
64+
mock_client.get_state.return_value.data = json.dumps([])
65+
66+
cp = DaprCheckpointer(self.store, self.prefix)
67+
cp.put(self.config, self.checkpoint, None, {"step": 10})
68+
69+
first_call = mock_client.save_state.call_args_list[0][0]
70+
assert first_call[0] == "statestore"
71+
assert first_call[1] == "lg:t1"
72+
saved_payload = json.loads(first_call[2])
73+
assert saved_payload["metadata"]["step"] == 10
74+
75+
second_call = mock_client.save_state.call_args_list[1][0]
76+
assert second_call[0] == "statestore"
77+
assert second_call[1] == DaprCheckpointer.REGISTRY_KEY
78+
79+
80+
def test_put_writes_updates_channel_values(self, mock_client_cls):
81+
mock_client = mock_client_cls.return_value
82+
83+
wrapper = {
84+
"checkpoint": {
85+
"v": 1,
86+
"id": "cp1",
87+
"ts": 1000,
88+
"channel_values": {"a": 10},
89+
"channel_versions": {},
90+
"versions_seen": {},
91+
},
92+
"metadata": {},
93+
}
94+
mock_client.get_state.return_value.data = json.dumps(wrapper)
95+
96+
cp = DaprCheckpointer(self.store, self.prefix)
97+
cp.put_writes(self.config, writes=[("a", 99)], task_id="task1")
98+
99+
# save_state is called with updated checkpoint
100+
call = mock_client.save_state.call_args[0]
101+
saved = json.loads(call[2])
102+
assert saved["checkpoint"]["channel_values"]["a"] == 99
103+
104+
105+
def test_list_returns_all_checkpoints(self, mock_client_cls):
106+
mock_client = mock_client_cls.return_value
107+
108+
registry = ["lg:t1"]
109+
cp_wrapper = {
110+
"checkpoint": {
111+
"v": 1,
112+
"id": "cp1",
113+
"ts": 1000,
114+
"channel_values": {"x": 1},
115+
"channel_versions": {},
116+
"versions_seen": {},
117+
},
118+
"metadata": {"step": 5},
119+
}
120+
121+
mock_client.get_state.side_effect = [
122+
mock.Mock(data=json.dumps(registry)),
123+
mock.Mock(data=json.dumps(cp_wrapper)),
124+
]
125+
126+
cp = DaprCheckpointer(self.store, self.prefix)
127+
lst = cp.list(self.config)
128+
129+
assert len(lst) == 1
130+
assert lst[0].checkpoint["id"] == "cp1"
131+
assert lst[0].metadata["step"] == 5
132+
133+
134+
def test_delete_thread_removes_key_and_updates_registry(self, mock_client_cls):
135+
mock_client = mock_client_cls.return_value
136+
137+
registry = ["lg:t1"]
138+
mock_client.get_state.return_value.data = json.dumps(registry)
139+
140+
cp = DaprCheckpointer(self.store, self.prefix)
141+
cp.delete_thread(self.config)
142+
143+
mock_client.delete_state.assert_called_once_with(
144+
store_name="statestore",
145+
key="lg:t1",
146+
)
147+
148+
mock_client.save_state.assert_called_with(
149+
store_name="statestore",
150+
key=DaprCheckpointer.REGISTRY_KEY,
151+
value=json.dumps([]),
152+
)
153+
154+
155+
if __name__ == "__main__":
156+
unittest.main()

0 commit comments

Comments
 (0)