Skip to content

Commit b82e662

Browse files
committed
feat: leverage session state to manage the status of client tool call and response
1 parent 0e4ed5c commit b82e662

File tree

5 files changed

+105
-24
lines changed

5 files changed

+105
-24
lines changed
Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
from typing import Any, Literal
2+
from uuid import uuid4
23

3-
from pydantic import BaseModel
4+
from google.adk.tools import ToolContext
5+
from pydantic import BaseModel, Field
46

5-
from ._constants import CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
7+
from ._constants import CHATKIT_THREAD_METADTA_KEY, CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
8+
from ._thread_utils import (
9+
add_client_tool_status,
10+
serialize_thread_metadata,
11+
)
612

713

814
class ClientToolCallState(BaseModel):
915
"""
1016
Returned from tool methods to indicate a client-side tool call.
1117
"""
1218

19+
id: str = Field(default_factory=lambda: uuid4().hex)
20+
1321
name: str
1422
arguments: dict[str, Any]
1523
status: Literal["pending", "completed"] = "pending"
@@ -18,11 +26,22 @@ class ClientToolCallState(BaseModel):
1826
def add_client_tool_call_to_tool_response(
1927
response: dict[str, Any],
2028
client_tool_call: ClientToolCallState,
29+
tool_context: ToolContext,
2130
) -> None:
2231
"""Add a client tool call to a tool response dictionary.
2332
2433
Args:
2534
response: The tool response dictionary to modify.
2635
client_tool_call: The client tool call state to add.
2736
"""
37+
38+
thread_metadata = add_client_tool_status(
39+
tool_context.state,
40+
client_tool_call.id,
41+
client_tool_call.status,
42+
)
43+
44+
# update the state
45+
tool_context.state[CHATKIT_THREAD_METADTA_KEY] = serialize_thread_metadata(thread_metadata)
46+
2847
response[CLIENT_TOOL_KEY_IN_TOOL_RESPONSE] = client_tool_call

adk-chatkit/src/adk_chatkit/_response.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def stream_agent_response(
4040
id=response_id,
4141
content=[],
4242
thread_id=thread.id,
43-
created_at=datetime.now(),
43+
created_at=datetime.fromtimestamp(event.timestamp),
4444
)
4545
)
4646

@@ -76,13 +76,13 @@ async def stream_agent_response(
7676
if adk_client_tool:
7777
yield ThreadItemDoneEvent(
7878
item=ClientToolCallItem(
79-
id=str(uuid.uuid4()),
79+
id=event.id,
8080
thread_id=thread.id,
8181
name=adk_client_tool.name,
8282
arguments=adk_client_tool.arguments,
8383
status=adk_client_tool.status,
8484
created_at=datetime.fromtimestamp(event.timestamp),
85-
call_id=fn_response.id, # type: ignore
85+
call_id=adk_client_tool.id,
8686
),
8787
)
8888

@@ -113,6 +113,6 @@ async def stream_agent_response(
113113
id=response_id,
114114
content=[AssistantMessageContent(text=text_from_final_update)],
115115
thread_id=thread.id,
116-
created_at=datetime.now(),
116+
created_at=datetime.fromtimestamp(event.timestamp),
117117
)
118118
)

adk-chatkit/src/adk_chatkit/_store.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
from ._client_tool_call import ClientToolCallState
2525
from ._constants import CHATKIT_THREAD_METADTA_KEY, CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, WIDGET_KEY_IN_TOOL_RESPONSE
2626
from ._context import ADKContext
27-
from ._thread_utils import serialize_thread_metadata
27+
from ._thread_utils import (
28+
add_client_tool_status,
29+
get_client_tool_status,
30+
get_thread_metadata_from_state,
31+
serialize_thread_metadata,
32+
)
2833

2934

3035
def _to_user_message_content(event: Event) -> list[UserMessageContent]:
@@ -67,7 +72,7 @@ async def load_thread(self, thread_id: str, context: ADKContext) -> ThreadMetada
6772
f"Session with id {thread_id} not found for user {context['user_id']} in app {context['app_name']}"
6873
)
6974

70-
return ThreadMetadata.model_validate(session.state[CHATKIT_THREAD_METADTA_KEY])
75+
return get_thread_metadata_from_state(session.state)
7176

7277
async def save_thread(self, thread: ThreadMetadata, context: ADKContext) -> None:
7378
session = await self._session_service.get_session(
@@ -159,15 +164,20 @@ async def load_thread_items(
159164
adk_client_tool = fn_response.response.get(CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, None)
160165
if adk_client_tool:
161166
adk_client_tool = ClientToolCallState.model_validate(adk_client_tool)
162-
an_item = ClientToolCallItem(
163-
id=event.id,
164-
thread_id=thread_id,
165-
name=adk_client_tool.name,
166-
arguments=adk_client_tool.arguments,
167-
status=adk_client_tool.status,
168-
created_at=datetime.fromtimestamp(event.timestamp),
169-
call_id=fn_response.id,
167+
status = get_client_tool_status(
168+
session.state,
169+
adk_client_tool.id,
170170
)
171+
if status:
172+
an_item = ClientToolCallItem(
173+
id=event.id,
174+
thread_id=thread_id,
175+
name=adk_client_tool.name,
176+
arguments=adk_client_tool.arguments,
177+
status=status, # type: ignore
178+
created_at=datetime.fromtimestamp(event.timestamp),
179+
call_id=adk_client_tool.id,
180+
)
171181

172182
if an_item:
173183
thread_items.append(an_item)
@@ -188,8 +198,7 @@ async def delete_attachment(self, attachment_id: str, context: ADKContext) -> No
188198
raise NotImplementedError()
189199

190200
async def delete_thread_item(self, thread_id: str, item_id: str, context: ADKContext) -> None:
191-
# deletion is called primarily to remove the ClientToolCallItem calls
192-
# we simply ignore them here as they are not stored separately
201+
# simply ignoring it for now (ClientToolCallItem is typically not deleted because of this)
193202
pass
194203

195204
async def delete_thread(self, thread_id: str, context: ADKContext) -> None:
@@ -198,7 +207,32 @@ async def delete_thread(self, thread_id: str, context: ADKContext) -> None:
198207
async def save_item(self, thread_id: str, item: ThreadItem, context: ADKContext) -> None:
199208
# we will only handle specify types of items here
200209
# as quite many are automatically handled by runner
201-
pass
210+
if isinstance(item, ClientToolCallItem):
211+
session = await self._session_service.get_session(
212+
app_name=context["app_name"],
213+
user_id=context["user_id"],
214+
session_id=thread_id,
215+
)
216+
217+
if not session:
218+
raise ValueError(
219+
f"Session with id {thread_id} not found for user {context['user_id']} in app {context['app_name']}"
220+
)
221+
222+
thread_metadata = add_client_tool_status(session.state, item.call_id, item.status)
223+
224+
state_delta = {
225+
CHATKIT_THREAD_METADTA_KEY: serialize_thread_metadata(thread_metadata),
226+
}
227+
228+
actions_with_update = EventActions(state_delta=state_delta)
229+
system_event = Event(
230+
invocation_id=uuid4().hex,
231+
author="system",
232+
actions=actions_with_update,
233+
timestamp=datetime.now().timestamp(),
234+
)
235+
await self._session_service.append_event(session, system_event)
202236

203237
async def load_item(self, thread_id: str, item_id: str, context: ADKContext) -> ThreadItem:
204238
raise NotImplementedError()
@@ -218,7 +252,7 @@ async def load_threads(
218252
items: list[ThreadMetadata] = []
219253

220254
for session in sessions_response.sessions:
221-
thread_metatdata_item = ThreadMetadata.model_validate(session.state[CHATKIT_THREAD_METADTA_KEY])
222-
items.append(thread_metatdata_item)
255+
thread_metadata = get_thread_metadata_from_state(session.state)
256+
items.append(thread_metadata)
223257

224258
return Page(data=items)

adk-chatkit/src/adk_chatkit/_thread_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,36 @@
22
from typing import Any
33

44
from chatkit.types import ThreadMetadata
5+
from google.adk.sessions.state import State
6+
7+
from ._constants import CHATKIT_THREAD_METADTA_KEY
8+
9+
_CLIENT_TOOL_PREFIX = "client-tool"
510

611

712
def serialize_thread_metadata(thread: ThreadMetadata) -> dict[str, Any]:
813
json_dump = thread.model_dump_json(exclude_none=True, exclude={"items"})
914
return json.loads(json_dump) # type: ignore
15+
16+
17+
def get_thread_metadata_from_state(state: State | dict[str, Any]) -> ThreadMetadata:
18+
thread_metadata_dict = state[CHATKIT_THREAD_METADTA_KEY]
19+
return ThreadMetadata.model_validate(thread_metadata_dict)
20+
21+
22+
def add_client_tool_status(
23+
state: State | dict[str, Any],
24+
client_tool_id: str,
25+
status: str,
26+
) -> ThreadMetadata:
27+
thread_metadata = get_thread_metadata_from_state(state)
28+
thread_metadata.metadata[f"{_CLIENT_TOOL_PREFIX}-{client_tool_id}"] = status
29+
return thread_metadata
30+
31+
32+
def get_client_tool_status(
33+
state: State | dict[str, Any],
34+
client_tool_id: str,
35+
) -> str | None:
36+
thread_metadata = get_thread_metadata_from_state(state)
37+
return thread_metadata.metadata.get(f"{_CLIENT_TOOL_PREFIX}-{client_tool_id}", None)

examples/backend/src/backend/agents/facts/_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ async def save_fact(
5050

5151
result = {"fact_id": confirmed.id, "status": "saved"}
5252

53-
add_client_tool_call_to_tool_response(result, client_tool_call)
53+
add_client_tool_call_to_tool_response(result, client_tool_call, tool_context)
5454

5555
return result
5656

5757

58-
async def switch_theme(theme: str) -> dict[str, str]:
58+
async def switch_theme(theme: str, tool_context: ToolContext) -> dict[str, str]:
5959
"""Switch the chat interface between light and dark color schemes.
6060
6161
Args:
@@ -73,7 +73,7 @@ async def switch_theme(theme: str) -> dict[str, str]:
7373

7474
result = {"theme": requested}
7575

76-
add_client_tool_call_to_tool_response(result, client_tool_call)
76+
add_client_tool_call_to_tool_response(result, client_tool_call, tool_context)
7777

7878
return result
7979

0 commit comments

Comments
 (0)