Skip to content

Commit 8eed645

Browse files
committed
test: update test case to handle new output format
Signed-off-by: Casper Nielsen <[email protected]>
1 parent 2479e64 commit 8eed645

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

ext/dapr-ext-langgraph/tests/test_checkpointer.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import json
44
import unittest
5+
import msgpack
6+
import base64
57
from datetime import datetime
68
from unittest import mock
79

@@ -61,17 +63,33 @@ def test_put_saves_checkpoint_and_registry(self, mock_client_cls):
6163
mock_client.get_state.return_value.data = json.dumps([])
6264

6365
cp = DaprCheckpointer(self.store, self.prefix)
64-
cp.put(self.config, self.checkpoint, None, {'step': 10})
65-
66-
first_call = mock_client.save_state.call_args_list[0][0]
67-
assert first_call[0] == 'statestore'
68-
assert first_call[1] == 'lg:t1'
69-
saved_payload = json.loads(first_call[2])
66+
cp.put(self.config, self.checkpoint, {'step': 10}, None)
67+
68+
first_call = mock_client.save_state.call_args_list[0]
69+
first_call_kwargs = first_call.kwargs
70+
assert first_call_kwargs['store_name'] == 'statestore'
71+
assert first_call_kwargs['key'] == 'checkpoint:t1::cp1'
72+
unpacked = msgpack.unpackb(first_call_kwargs['value']) # We're packing bytes
73+
saved_payload = {}
74+
for k, v in unpacked.items():
75+
k = k.decode() if isinstance(k, bytes) else k
76+
if k == 'checkpoint' or k == 'metadata': # Need to convert b'' on checkpoint/metadata dict key/values
77+
if k == 'metadata':
78+
v = msgpack.unpackb(v) # Metadata value is packed
79+
val = {}
80+
for sk, sv in v.items():
81+
sk = sk.decode() if isinstance(sk, bytes) else sk
82+
sv = sv.decode() if isinstance(sv, bytes) else sv
83+
val[sk] = sv
84+
else:
85+
val = v.decode() if isinstance(v, bytes) else v
86+
saved_payload[k] = val
7087
assert saved_payload['metadata']['step'] == 10
7188

72-
second_call = mock_client.save_state.call_args_list[1][0]
73-
assert second_call[0] == 'statestore'
74-
assert second_call[1] == DaprCheckpointer.REGISTRY_KEY
89+
second_call = mock_client.save_state.call_args_list[1]
90+
second_call_kwargs = second_call.kwargs
91+
assert second_call_kwargs['store_name'] == 'statestore'
92+
assert second_call_kwargs['value'] == 'checkpoint:t1::cp1' # Here we're testing if the last checkpoint is the first_call above
7593

7694
def test_put_writes_updates_channel_values(self, mock_client_cls):
7795
mock_client = mock_client_cls.return_value
@@ -93,9 +111,12 @@ def test_put_writes_updates_channel_values(self, mock_client_cls):
93111
cp.put_writes(self.config, writes=[('a', 99)], task_id='task1')
94112

95113
# save_state is called with updated checkpoint
96-
call = mock_client.save_state.call_args[0]
97-
saved = json.loads(call[2])
98-
assert saved['checkpoint']['channel_values']['a'] == 99
114+
call = mock_client.save_state.call_args_list[0]
115+
# As we're using named input params we've got to fetch through kwargs
116+
kwargs = call.kwargs
117+
saved = json.loads(kwargs['value'])
118+
# As the value obj is base64 encoded in 'blob' we got to unpack it
119+
assert msgpack.unpackb(base64.b64decode(saved['blob'])) == 99
99120

100121
def test_list_returns_all_checkpoints(self, mock_client_cls):
101122
mock_client = mock_client_cls.return_value

0 commit comments

Comments
 (0)