22
33import json
44import unittest
5+ import msgpack
6+ import base64
57from datetime import datetime
68from unittest import mock
79
@@ -61,17 +63,33 @@ def test_put_saves_checkpoint_and_registry(self, mock_client_cls):
6163 mock_client .get_state .return_value .data = json .dumps ([])
6264
6365 cp = DaprCheckpointer (self .store , self .prefix )
64- cp .put (self .config , self .checkpoint , None , {'step' : 10 })
65-
66- first_call = mock_client .save_state .call_args_list [0 ][0 ]
67- assert first_call [0 ] == 'statestore'
68- assert first_call [1 ] == 'lg:t1'
69- saved_payload = json .loads (first_call [2 ])
66+ cp .put (self .config , self .checkpoint , {'step' : 10 }, None )
67+
68+ first_call = mock_client .save_state .call_args_list [0 ]
69+ first_call_kwargs = first_call .kwargs
70+ assert first_call_kwargs ['store_name' ] == 'statestore'
71+ assert first_call_kwargs ['key' ] == 'checkpoint:t1::cp1'
72+ unpacked = msgpack .unpackb (first_call_kwargs ['value' ]) # We're packing bytes
73+ saved_payload = {}
74+ for k , v in unpacked .items ():
75+ k = k .decode () if isinstance (k , bytes ) else k
76+ if k == 'checkpoint' or k == 'metadata' : # Need to convert b'' on checkpoint/metadata dict key/values
77+ if k == 'metadata' :
78+ v = msgpack .unpackb (v ) # Metadata value is packed
79+ val = {}
80+ for sk , sv in v .items ():
81+ sk = sk .decode () if isinstance (sk , bytes ) else sk
82+ sv = sv .decode () if isinstance (sv , bytes ) else sv
83+ val [sk ] = sv
84+ else :
85+ val = v .decode () if isinstance (v , bytes ) else v
86+ saved_payload [k ] = val
7087 assert saved_payload ['metadata' ]['step' ] == 10
7188
72- second_call = mock_client .save_state .call_args_list [1 ][0 ]
73- assert second_call [0 ] == 'statestore'
74- assert second_call [1 ] == DaprCheckpointer .REGISTRY_KEY
89+ second_call = mock_client .save_state .call_args_list [1 ]
90+ second_call_kwargs = second_call .kwargs
91+ assert second_call_kwargs ['store_name' ] == 'statestore'
92+ assert second_call_kwargs ['value' ] == 'checkpoint:t1::cp1' # Here we're testing if the last checkpoint is the first_call above
7593
7694 def test_put_writes_updates_channel_values (self , mock_client_cls ):
7795 mock_client = mock_client_cls .return_value
@@ -93,9 +111,12 @@ def test_put_writes_updates_channel_values(self, mock_client_cls):
93111 cp .put_writes (self .config , writes = [('a' , 99 )], task_id = 'task1' )
94112
95113 # save_state is called with updated checkpoint
96- call = mock_client .save_state .call_args [0 ]
97- saved = json .loads (call [2 ])
98- assert saved ['checkpoint' ]['channel_values' ]['a' ] == 99
114+ call = mock_client .save_state .call_args_list [0 ]
115+ # As we're using named input params we've got to fetch through kwargs
116+ kwargs = call .kwargs
117+ saved = json .loads (kwargs ['value' ])
118+ # As the value obj is base64 encoded in 'blob' we got to unpack it
119+ assert msgpack .unpackb (base64 .b64decode (saved ['blob' ])) == 99
99120
100121 def test_list_returns_all_checkpoints (self , mock_client_cls ):
101122 mock_client = mock_client_cls .return_value
0 commit comments