Skip to content

Commit a281075

Browse files
committed
linter
Signed-off-by: yaron2 <[email protected]>
1 parent 06c2401 commit a281075

File tree

1 file changed

+50
-56
lines changed

1 file changed

+50
-56
lines changed
Lines changed: 50 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,52 @@
11
# -*- coding: utf-8 -*-
22

3-
import unittest
4-
from unittest import mock
53
import json
4+
import unittest
65
from datetime import datetime
6+
from unittest import mock
77

8-
from langgraph.checkpoint.base import Checkpoint
98
from dapr.ext.langgraph.dapr_checkpointer import DaprCheckpointer
9+
from langgraph.checkpoint.base import Checkpoint
1010

1111

12-
@mock.patch("dapr.ext.langgraph.dapr_checkpointer.DaprClient")
12+
@mock.patch('dapr.ext.langgraph.dapr_checkpointer.DaprClient')
1313
class DaprCheckpointerTest(unittest.TestCase):
14-
1514
def setUp(self):
16-
self.store = "statestore"
17-
self.prefix = "lg"
18-
self.config = {"configurable": {"thread_id": "t1"}}
15+
self.store = 'statestore'
16+
self.prefix = 'lg'
17+
self.config = {'configurable': {'thread_id': 't1'}}
1918

2019
self.checkpoint = Checkpoint(
2120
v=1,
22-
id="cp1",
21+
id='cp1',
2322
ts=datetime.now().timestamp(),
24-
channel_values={"a": 1},
23+
channel_values={'a': 1},
2524
channel_versions={},
2625
versions_seen={},
2726
)
2827

29-
3028
def test_get_tuple_returns_checkpoint(self, mock_client_cls):
3129
mock_client = mock_client_cls.return_value
3230

3331
wrapper = {
34-
"checkpoint": {
35-
"v": self.checkpoint["v"],
36-
"id": self.checkpoint["id"],
37-
"ts": self.checkpoint["ts"],
38-
"channel_values": self.checkpoint["channel_values"],
39-
"channel_versions": self.checkpoint["channel_versions"],
40-
"versions_seen": self.checkpoint["versions_seen"],
32+
'checkpoint': {
33+
'v': self.checkpoint['v'],
34+
'id': self.checkpoint['id'],
35+
'ts': self.checkpoint['ts'],
36+
'channel_values': self.checkpoint['channel_values'],
37+
'channel_versions': self.checkpoint['channel_versions'],
38+
'versions_seen': self.checkpoint['versions_seen'],
4139
},
42-
"metadata": {"step": 3},
40+
'metadata': {'step': 3},
4341
}
4442
mock_client.get_state.return_value.data = json.dumps(wrapper)
4543

4644
cp = DaprCheckpointer(self.store, self.prefix)
4745
tup = cp.get_tuple(self.config)
4846

4947
assert tup is not None
50-
assert tup.checkpoint["id"] == "cp1"
51-
assert tup.metadata["step"] == 3
48+
assert tup.checkpoint['id'] == 'cp1'
49+
assert tup.metadata['step'] == 3
5250

5351
def test_get_tuple_none_when_missing(self, mock_client_cls):
5452
mock_client = mock_client_cls.return_value
@@ -57,65 +55,62 @@ def test_get_tuple_none_when_missing(self, mock_client_cls):
5755
cp = DaprCheckpointer(self.store, self.prefix)
5856
assert cp.get_tuple(self.config) is None
5957

60-
6158
def test_put_saves_checkpoint_and_registry(self, mock_client_cls):
6259
mock_client = mock_client_cls.return_value
6360

6461
mock_client.get_state.return_value.data = json.dumps([])
6562

6663
cp = DaprCheckpointer(self.store, self.prefix)
67-
cp.put(self.config, self.checkpoint, None, {"step": 10})
64+
cp.put(self.config, self.checkpoint, None, {'step': 10})
6865

6966
first_call = mock_client.save_state.call_args_list[0][0]
70-
assert first_call[0] == "statestore"
71-
assert first_call[1] == "lg:t1"
67+
assert first_call[0] == 'statestore'
68+
assert first_call[1] == 'lg:t1'
7269
saved_payload = json.loads(first_call[2])
73-
assert saved_payload["metadata"]["step"] == 10
70+
assert saved_payload['metadata']['step'] == 10
7471

7572
second_call = mock_client.save_state.call_args_list[1][0]
76-
assert second_call[0] == "statestore"
73+
assert second_call[0] == 'statestore'
7774
assert second_call[1] == DaprCheckpointer.REGISTRY_KEY
7875

79-
8076
def test_put_writes_updates_channel_values(self, mock_client_cls):
8177
mock_client = mock_client_cls.return_value
8278

8379
wrapper = {
84-
"checkpoint": {
85-
"v": 1,
86-
"id": "cp1",
87-
"ts": 1000,
88-
"channel_values": {"a": 10},
89-
"channel_versions": {},
90-
"versions_seen": {},
80+
'checkpoint': {
81+
'v': 1,
82+
'id': 'cp1',
83+
'ts': 1000,
84+
'channel_values': {'a': 10},
85+
'channel_versions': {},
86+
'versions_seen': {},
9187
},
92-
"metadata": {},
88+
'metadata': {},
9389
}
9490
mock_client.get_state.return_value.data = json.dumps(wrapper)
9591

9692
cp = DaprCheckpointer(self.store, self.prefix)
97-
cp.put_writes(self.config, writes=[("a", 99)], task_id="task1")
93+
cp.put_writes(self.config, writes=[('a', 99)], task_id='task1')
9894

9995
# save_state is called with updated checkpoint
10096
call = mock_client.save_state.call_args[0]
10197
saved = json.loads(call[2])
102-
assert saved["checkpoint"]["channel_values"]["a"] == 99
103-
98+
assert saved['checkpoint']['channel_values']['a'] == 99
10499

105100
def test_list_returns_all_checkpoints(self, mock_client_cls):
106101
mock_client = mock_client_cls.return_value
107102

108-
registry = ["lg:t1"]
103+
registry = ['lg:t1']
109104
cp_wrapper = {
110-
"checkpoint": {
111-
"v": 1,
112-
"id": "cp1",
113-
"ts": 1000,
114-
"channel_values": {"x": 1},
115-
"channel_versions": {},
116-
"versions_seen": {},
105+
'checkpoint': {
106+
'v': 1,
107+
'id': 'cp1',
108+
'ts': 1000,
109+
'channel_values': {'x': 1},
110+
'channel_versions': {},
111+
'versions_seen': {},
117112
},
118-
"metadata": {"step": 5},
113+
'metadata': {'step': 5},
119114
}
120115

121116
mock_client.get_state.side_effect = [
@@ -127,30 +122,29 @@ def test_list_returns_all_checkpoints(self, mock_client_cls):
127122
lst = cp.list(self.config)
128123

129124
assert len(lst) == 1
130-
assert lst[0].checkpoint["id"] == "cp1"
131-
assert lst[0].metadata["step"] == 5
132-
125+
assert lst[0].checkpoint['id'] == 'cp1'
126+
assert lst[0].metadata['step'] == 5
133127

134128
def test_delete_thread_removes_key_and_updates_registry(self, mock_client_cls):
135129
mock_client = mock_client_cls.return_value
136130

137-
registry = ["lg:t1"]
131+
registry = ['lg:t1']
138132
mock_client.get_state.return_value.data = json.dumps(registry)
139133

140134
cp = DaprCheckpointer(self.store, self.prefix)
141135
cp.delete_thread(self.config)
142136

143137
mock_client.delete_state.assert_called_once_with(
144-
store_name="statestore",
145-
key="lg:t1",
138+
store_name='statestore',
139+
key='lg:t1',
146140
)
147141

148142
mock_client.save_state.assert_called_with(
149-
store_name="statestore",
143+
store_name='statestore',
150144
key=DaprCheckpointer.REGISTRY_KEY,
151145
value=json.dumps([]),
152146
)
153147

154148

155-
if __name__ == "__main__":
149+
if __name__ == '__main__':
156150
unittest.main()

0 commit comments

Comments
 (0)