11# -*- coding: utf-8 -*-
22
3- import unittest
4- from unittest import mock
53import json
4+ import unittest
65from datetime import datetime
6+ from unittest import mock
77
8- from langgraph .checkpoint .base import Checkpoint
98from dapr .ext .langgraph .dapr_checkpointer import DaprCheckpointer
9+ from langgraph .checkpoint .base import Checkpoint
1010
1111
12- @mock .patch (" dapr.ext.langgraph.dapr_checkpointer.DaprClient" )
12+ @mock .patch (' dapr.ext.langgraph.dapr_checkpointer.DaprClient' )
1313class DaprCheckpointerTest (unittest .TestCase ):
14-
1514 def setUp (self ):
16- self .store = " statestore"
17- self .prefix = "lg"
18- self .config = {" configurable" : {" thread_id" : "t1" }}
15+ self .store = ' statestore'
16+ self .prefix = 'lg'
17+ self .config = {' configurable' : {' thread_id' : 't1' }}
1918
2019 self .checkpoint = Checkpoint (
2120 v = 1 ,
22- id = " cp1" ,
21+ id = ' cp1' ,
2322 ts = datetime .now ().timestamp (),
24- channel_values = {"a" : 1 },
23+ channel_values = {'a' : 1 },
2524 channel_versions = {},
2625 versions_seen = {},
2726 )
2827
29-
3028 def test_get_tuple_returns_checkpoint (self , mock_client_cls ):
3129 mock_client = mock_client_cls .return_value
3230
3331 wrapper = {
34- " checkpoint" : {
35- "v" : self .checkpoint ["v" ],
36- "id" : self .checkpoint ["id" ],
37- "ts" : self .checkpoint ["ts" ],
38- " channel_values" : self .checkpoint [" channel_values" ],
39- " channel_versions" : self .checkpoint [" channel_versions" ],
40- " versions_seen" : self .checkpoint [" versions_seen" ],
32+ ' checkpoint' : {
33+ 'v' : self .checkpoint ['v' ],
34+ 'id' : self .checkpoint ['id' ],
35+ 'ts' : self .checkpoint ['ts' ],
36+ ' channel_values' : self .checkpoint [' channel_values' ],
37+ ' channel_versions' : self .checkpoint [' channel_versions' ],
38+ ' versions_seen' : self .checkpoint [' versions_seen' ],
4139 },
42- " metadata" : {" step" : 3 },
40+ ' metadata' : {' step' : 3 },
4341 }
4442 mock_client .get_state .return_value .data = json .dumps (wrapper )
4543
4644 cp = DaprCheckpointer (self .store , self .prefix )
4745 tup = cp .get_tuple (self .config )
4846
4947 assert tup is not None
50- assert tup .checkpoint ["id" ] == " cp1"
51- assert tup .metadata [" step" ] == 3
48+ assert tup .checkpoint ['id' ] == ' cp1'
49+ assert tup .metadata [' step' ] == 3
5250
5351 def test_get_tuple_none_when_missing (self , mock_client_cls ):
5452 mock_client = mock_client_cls .return_value
@@ -57,65 +55,62 @@ def test_get_tuple_none_when_missing(self, mock_client_cls):
5755 cp = DaprCheckpointer (self .store , self .prefix )
5856 assert cp .get_tuple (self .config ) is None
5957
60-
6158 def test_put_saves_checkpoint_and_registry (self , mock_client_cls ):
6259 mock_client = mock_client_cls .return_value
6360
6461 mock_client .get_state .return_value .data = json .dumps ([])
6562
6663 cp = DaprCheckpointer (self .store , self .prefix )
67- cp .put (self .config , self .checkpoint , None , {" step" : 10 })
64+ cp .put (self .config , self .checkpoint , None , {' step' : 10 })
6865
6966 first_call = mock_client .save_state .call_args_list [0 ][0 ]
70- assert first_call [0 ] == " statestore"
71- assert first_call [1 ] == " lg:t1"
67+ assert first_call [0 ] == ' statestore'
68+ assert first_call [1 ] == ' lg:t1'
7269 saved_payload = json .loads (first_call [2 ])
73- assert saved_payload [" metadata" ][ " step" ] == 10
70+ assert saved_payload [' metadata' ][ ' step' ] == 10
7471
7572 second_call = mock_client .save_state .call_args_list [1 ][0 ]
76- assert second_call [0 ] == " statestore"
73+ assert second_call [0 ] == ' statestore'
7774 assert second_call [1 ] == DaprCheckpointer .REGISTRY_KEY
7875
79-
8076 def test_put_writes_updates_channel_values (self , mock_client_cls ):
8177 mock_client = mock_client_cls .return_value
8278
8379 wrapper = {
84- " checkpoint" : {
85- "v" : 1 ,
86- "id" : " cp1" ,
87- "ts" : 1000 ,
88- " channel_values" : {"a" : 10 },
89- " channel_versions" : {},
90- " versions_seen" : {},
80+ ' checkpoint' : {
81+ 'v' : 1 ,
82+ 'id' : ' cp1' ,
83+ 'ts' : 1000 ,
84+ ' channel_values' : {'a' : 10 },
85+ ' channel_versions' : {},
86+ ' versions_seen' : {},
9187 },
92- " metadata" : {},
88+ ' metadata' : {},
9389 }
9490 mock_client .get_state .return_value .data = json .dumps (wrapper )
9591
9692 cp = DaprCheckpointer (self .store , self .prefix )
97- cp .put_writes (self .config , writes = [("a" , 99 )], task_id = " task1" )
93+ cp .put_writes (self .config , writes = [('a' , 99 )], task_id = ' task1' )
9894
9995 # save_state is called with updated checkpoint
10096 call = mock_client .save_state .call_args [0 ]
10197 saved = json .loads (call [2 ])
102- assert saved ["checkpoint" ]["channel_values" ]["a" ] == 99
103-
98+ assert saved ['checkpoint' ]['channel_values' ]['a' ] == 99
10499
105100 def test_list_returns_all_checkpoints (self , mock_client_cls ):
106101 mock_client = mock_client_cls .return_value
107102
108- registry = [" lg:t1" ]
103+ registry = [' lg:t1' ]
109104 cp_wrapper = {
110- " checkpoint" : {
111- "v" : 1 ,
112- "id" : " cp1" ,
113- "ts" : 1000 ,
114- " channel_values" : {"x" : 1 },
115- " channel_versions" : {},
116- " versions_seen" : {},
105+ ' checkpoint' : {
106+ 'v' : 1 ,
107+ 'id' : ' cp1' ,
108+ 'ts' : 1000 ,
109+ ' channel_values' : {'x' : 1 },
110+ ' channel_versions' : {},
111+ ' versions_seen' : {},
117112 },
118- " metadata" : {" step" : 5 },
113+ ' metadata' : {' step' : 5 },
119114 }
120115
121116 mock_client .get_state .side_effect = [
@@ -127,30 +122,29 @@ def test_list_returns_all_checkpoints(self, mock_client_cls):
127122 lst = cp .list (self .config )
128123
129124 assert len (lst ) == 1
130- assert lst [0 ].checkpoint ["id" ] == "cp1"
131- assert lst [0 ].metadata ["step" ] == 5
132-
125+ assert lst [0 ].checkpoint ['id' ] == 'cp1'
126+ assert lst [0 ].metadata ['step' ] == 5
133127
134128 def test_delete_thread_removes_key_and_updates_registry (self , mock_client_cls ):
135129 mock_client = mock_client_cls .return_value
136130
137- registry = [" lg:t1" ]
131+ registry = [' lg:t1' ]
138132 mock_client .get_state .return_value .data = json .dumps (registry )
139133
140134 cp = DaprCheckpointer (self .store , self .prefix )
141135 cp .delete_thread (self .config )
142136
143137 mock_client .delete_state .assert_called_once_with (
144- store_name = " statestore" ,
145- key = " lg:t1" ,
138+ store_name = ' statestore' ,
139+ key = ' lg:t1' ,
146140 )
147141
148142 mock_client .save_state .assert_called_with (
149- store_name = " statestore" ,
143+ store_name = ' statestore' ,
150144 key = DaprCheckpointer .REGISTRY_KEY ,
151145 value = json .dumps ([]),
152146 )
153147
154148
155- if __name__ == " __main__" :
149+ if __name__ == ' __main__' :
156150 unittest .main ()
0 commit comments