2424from ._client_tool_call import ClientToolCallState
2525from ._constants import CHATKIT_THREAD_METADTA_KEY , CLIENT_TOOL_KEY_IN_TOOL_RESPONSE , WIDGET_KEY_IN_TOOL_RESPONSE
2626from ._context import ADKContext
27- from ._thread_utils import serialize_thread_metadata
27+ from ._thread_utils import (
28+ add_client_tool_status ,
29+ get_client_tool_status ,
30+ get_thread_metadata_from_state ,
31+ serialize_thread_metadata ,
32+ )
2833
2934
3035def _to_user_message_content (event : Event ) -> list [UserMessageContent ]:
@@ -67,7 +72,7 @@ async def load_thread(self, thread_id: str, context: ADKContext) -> ThreadMetada
6772 f"Session with id { thread_id } not found for user { context ['user_id' ]} in app { context ['app_name' ]} "
6873 )
6974
70- return ThreadMetadata . model_validate (session .state [ CHATKIT_THREAD_METADTA_KEY ] )
75+ return get_thread_metadata_from_state (session .state )
7176
7277 async def save_thread (self , thread : ThreadMetadata , context : ADKContext ) -> None :
7378 session = await self ._session_service .get_session (
@@ -159,15 +164,20 @@ async def load_thread_items(
159164 adk_client_tool = fn_response .response .get (CLIENT_TOOL_KEY_IN_TOOL_RESPONSE , None )
160165 if adk_client_tool :
161166 adk_client_tool = ClientToolCallState .model_validate (adk_client_tool )
162- an_item = ClientToolCallItem (
163- id = event .id ,
164- thread_id = thread_id ,
165- name = adk_client_tool .name ,
166- arguments = adk_client_tool .arguments ,
167- status = adk_client_tool .status ,
168- created_at = datetime .fromtimestamp (event .timestamp ),
169- call_id = fn_response .id ,
167+ status = get_client_tool_status (
168+ session .state ,
169+ adk_client_tool .id ,
170170 )
171+ if status :
172+ an_item = ClientToolCallItem (
173+ id = event .id ,
174+ thread_id = thread_id ,
175+ name = adk_client_tool .name ,
176+ arguments = adk_client_tool .arguments ,
177+ status = status , # type: ignore
178+ created_at = datetime .fromtimestamp (event .timestamp ),
179+ call_id = adk_client_tool .id ,
180+ )
171181
172182 if an_item :
173183 thread_items .append (an_item )
@@ -188,8 +198,7 @@ async def delete_attachment(self, attachment_id: str, context: ADKContext) -> No
188198 raise NotImplementedError ()
189199
190200 async def delete_thread_item (self , thread_id : str , item_id : str , context : ADKContext ) -> None :
191- # deletion is called primarily to remove the ClientToolCallItem calls
192- # we simply ignore them here as they are not stored separately
201+ # simply ignoring it for now (ClientToolCallItem is typically not deleted because of this)
193202 pass
194203
195204 async def delete_thread (self , thread_id : str , context : ADKContext ) -> None :
@@ -198,7 +207,32 @@ async def delete_thread(self, thread_id: str, context: ADKContext) -> None:
198207 async def save_item (self , thread_id : str , item : ThreadItem , context : ADKContext ) -> None :
199208 # we will only handle specify types of items here
200209 # as quite many are automatically handled by runner
201- pass
210+ if isinstance (item , ClientToolCallItem ):
211+ session = await self ._session_service .get_session (
212+ app_name = context ["app_name" ],
213+ user_id = context ["user_id" ],
214+ session_id = thread_id ,
215+ )
216+
217+ if not session :
218+ raise ValueError (
219+ f"Session with id { thread_id } not found for user { context ['user_id' ]} in app { context ['app_name' ]} "
220+ )
221+
222+ thread_metadata = add_client_tool_status (session .state , item .call_id , item .status )
223+
224+ state_delta = {
225+ CHATKIT_THREAD_METADTA_KEY : serialize_thread_metadata (thread_metadata ),
226+ }
227+
228+ actions_with_update = EventActions (state_delta = state_delta )
229+ system_event = Event (
230+ invocation_id = uuid4 ().hex ,
231+ author = "system" ,
232+ actions = actions_with_update ,
233+ timestamp = datetime .now ().timestamp (),
234+ )
235+ await self ._session_service .append_event (session , system_event )
202236
203237 async def load_item (self , thread_id : str , item_id : str , context : ADKContext ) -> ThreadItem :
204238 raise NotImplementedError ()
@@ -218,7 +252,7 @@ async def load_threads(
218252 items : list [ThreadMetadata ] = []
219253
220254 for session in sessions_response .sessions :
221- thread_metatdata_item = ThreadMetadata . model_validate (session .state [ CHATKIT_THREAD_METADTA_KEY ] )
222- items .append (thread_metatdata_item )
255+ thread_metadata = get_thread_metadata_from_state (session .state )
256+ items .append (thread_metadata )
223257
224258 return Page (data = items )
0 commit comments