Skip to content

Commit 932e6f8

Browse files
committed
feat: rewrite the client tool call support to remove the need for callbacks to process tool response
1 parent ab4c4b7 commit 932e6f8

File tree

14 files changed

+81
-169
lines changed

14 files changed

+81
-169
lines changed

adk-chatkit/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ authors = [
88
]
99
requires-python = ">=3.11"
1010
dependencies = [
11-
"openai-chatkit>=1.0.2",
11+
"openai-chatkit>=1.1.0",
1212
"google-adk>=1.16.0",
1313
"pydantic>=2.11.7",
1414
]

adk-chatkit/src/adk_chatkit/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .__about__ import __application__, __author__, __version__
2-
from ._callbacks import remove_client_tool_calls
3-
from ._client_tool_call import ClientToolCallState, add_client_tool_call_to_tool_response
2+
from ._client_tool_call import ClientToolCallState
43
from ._context import ADKAgentContext, ADKContext, ChatkitRunConfig
54
from ._response import stream_agent_response
65
from ._store import ADKStore
@@ -15,8 +14,6 @@
1514
"ADKStore",
1615
"stream_agent_response",
1716
"ClientToolCallState",
18-
"add_client_tool_call_to_tool_response",
19-
"remove_client_tool_calls",
2017
"ChatkitRunConfig",
2118
"serialize_widget_item",
2219
]

adk-chatkit/src/adk_chatkit/_callbacks.py

Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1+
import json
12
from typing import Any, Literal
23
from uuid import uuid4
34

4-
from google.adk.tools import ToolContext
5+
from chatkit.types import ClientToolCallItem
56
from pydantic import BaseModel, Field
67

7-
from ._constants import CHATKIT_THREAD_METADTA_KEY, CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
8-
from ._thread_utils import (
9-
add_client_tool_status,
10-
serialize_thread_metadata,
11-
)
12-
138

149
class ClientToolCallState(BaseModel):
1510
"""
@@ -23,25 +18,6 @@ class ClientToolCallState(BaseModel):
2318
status: Literal["pending", "completed"] = "pending"
2419

2520

26-
def add_client_tool_call_to_tool_response(
27-
response: dict[str, Any],
28-
client_tool_call: ClientToolCallState,
29-
tool_context: ToolContext,
30-
) -> None:
31-
"""Add a client tool call to a tool response dictionary.
32-
33-
Args:
34-
response: The tool response dictionary to modify.
35-
client_tool_call: The client tool call state to add.
36-
"""
37-
38-
thread_metadata = add_client_tool_status(
39-
tool_context.state,
40-
client_tool_call.id,
41-
client_tool_call.status,
42-
)
43-
44-
# update the state
45-
tool_context.state[CHATKIT_THREAD_METADTA_KEY] = serialize_thread_metadata(thread_metadata)
46-
47-
response[CLIENT_TOOL_KEY_IN_TOOL_RESPONSE] = client_tool_call
21+
def serialize_client_tool_call_item(client_tool_call: ClientToolCallItem) -> dict[str, Any]:
22+
json_dump = client_tool_call.model_dump_json(exclude_none=True)
23+
return json.loads(json_dump) # type: ignore
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Final
22

3-
CLIENT_TOOL_KEY_IN_TOOL_RESPONSE: Final[str] = "adk-client-tool"
43
CHATKIT_THREAD_METADTA_KEY: Final[str] = "adk-chatkit-thread-metadata"
54
CHATKIT_WIDGET_STATE_KEY: Final[str] = "adk-chatkit-widget-state"
5+
CHATKIT_CLIENT_TOOL_CALLS_KEY: Final[str] = "adk-chatkit-client-tool-calls"

adk-chatkit/src/adk_chatkit/_context.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import asyncio
22
from datetime import datetime
33

4-
from chatkit.types import ThreadItemDoneEvent, ThreadMetadata, ThreadStreamEvent, WidgetItem
4+
from chatkit.types import ClientToolCallItem, ThreadItemDoneEvent, ThreadMetadata, ThreadStreamEvent, WidgetItem
55
from chatkit.widgets import WidgetRoot
66
from google.adk.agents.run_config import RunConfig
77
from google.adk.tools import ToolContext
88
from pydantic import BaseModel
99

10+
from ._client_tool_call import ClientToolCallState
1011
from ._event_utils import QueueCompleteSentinel
1112

1213

@@ -17,6 +18,7 @@ class ADKContext(BaseModel):
1718

1819
class ADKAgentContext(ADKContext):
1920
thread: ThreadMetadata
21+
client_tool_call: ClientToolCallItem | None = None
2022

2123
_events: asyncio.Queue[ThreadStreamEvent | QueueCompleteSentinel] = asyncio.Queue()
2224

@@ -37,6 +39,24 @@ async def stream_widget(self, widget: WidgetRoot, tool_context: ToolContext) ->
3739
)
3840
)
3941

42+
async def issue_client_tool_call(
43+
self,
44+
client_tool_call: ClientToolCallState,
45+
tool_context: ToolContext,
46+
) -> None:
47+
if tool_context.function_call_id is None:
48+
raise ValueError("tool_context.function_call_id is None")
49+
50+
self.client_tool_call = ClientToolCallItem(
51+
id=tool_context.function_call_id,
52+
thread_id=self.thread.id,
53+
name=client_tool_call.name,
54+
arguments=client_tool_call.arguments,
55+
status=client_tool_call.status,
56+
created_at=datetime.now(),
57+
call_id=client_tool_call.id,
58+
)
59+
4060
def _complete(self) -> None:
4161
self._events.put_nowait(QueueCompleteSentinel())
4262

adk-chatkit/src/adk_chatkit/_response.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,17 @@
88
AssistantMessageContentPartDone,
99
AssistantMessageContentPartTextDelta,
1010
AssistantMessageItem,
11-
ClientToolCallItem,
1211
ThreadItemAddedEvent,
1312
ThreadItemDoneEvent,
1413
ThreadItemUpdated,
15-
ThreadMetadata,
1614
ThreadStreamEvent,
1715
)
1816
from google.adk.events import Event
1917

20-
from ._client_tool_call import ClientToolCallState
21-
from ._constants import CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
2218
from ._context import ADKAgentContext
2319
from ._event_utils import AsyncQueueIterator, EventWrapper, merge_generators
2420

2521

26-
async def _handle_function_response(
27-
event: Event,
28-
thread: ThreadMetadata,
29-
) -> AsyncGenerator[ThreadItemDoneEvent, None]:
30-
if fn_responses := event.get_function_responses():
31-
for fn_response in fn_responses:
32-
if not fn_response.response:
33-
continue
34-
35-
adk_client_tool: ClientToolCallState | None = fn_response.response.get(
36-
CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, None
37-
)
38-
if adk_client_tool:
39-
yield ThreadItemDoneEvent(
40-
item=ClientToolCallItem(
41-
id=event.id,
42-
thread_id=thread.id,
43-
name=adk_client_tool.name,
44-
arguments=adk_client_tool.arguments,
45-
status=adk_client_tool.status,
46-
created_at=datetime.fromtimestamp(event.timestamp),
47-
call_id=adk_client_tool.id,
48-
),
49-
)
50-
51-
5222
async def stream_agent_response(
5323
context: ADKAgentContext,
5424
adk_response: AsyncGenerator[Event, None],
@@ -87,9 +57,6 @@ async def stream_agent_response(
8757
),
8858
)
8959
else:
90-
async for item in _handle_function_response(event, thread):
91-
yield item
92-
9360
if event.content.parts:
9461
text_from_final_update = ""
9562
for p in event.content.parts:
@@ -126,3 +93,7 @@ async def stream_agent_response(
12693
# Drain remaining events
12794
async for event in queue_iterator:
12895
yield event.event
96+
97+
# the last chatkit event is that of the client call
98+
if context.client_tool_call:
99+
yield ThreadItemDoneEvent(item=context.client_tool_call)

adk-chatkit/src/adk_chatkit/_store.py

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@
2121
from google.adk.sessions import BaseSessionService
2222
from google.adk.sessions.base_session_service import ListSessionsResponse
2323

24-
from ._client_tool_call import ClientToolCallState
25-
from ._constants import CHATKIT_THREAD_METADTA_KEY, CHATKIT_WIDGET_STATE_KEY, CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
24+
from ._client_tool_call import serialize_client_tool_call_item
25+
from ._constants import CHATKIT_CLIENT_TOOL_CALLS_KEY, CHATKIT_THREAD_METADTA_KEY, CHATKIT_WIDGET_STATE_KEY
2626
from ._context import ADKContext
2727
from ._thread_utils import (
28-
add_client_tool_status,
29-
get_client_tool_status,
3028
get_thread_metadata_from_state,
3129
serialize_thread_metadata,
3230
)
@@ -167,36 +165,24 @@ async def load_thread_items(
167165
an_item = WidgetItem.model_validate(widget_data)
168166

169167
# let's check for adk-client-tool in the response
170-
adk_client_tool = fn_response.response.get(CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, None)
171-
if adk_client_tool:
172-
adk_client_tool = ClientToolCallState.model_validate(adk_client_tool)
173-
status = get_client_tool_status(
174-
session.state,
175-
adk_client_tool.id,
176-
)
177-
if status:
178-
an_item = ClientToolCallItem(
179-
id=event.id,
180-
thread_id=thread_id,
181-
name=adk_client_tool.name,
182-
arguments=adk_client_tool.arguments,
183-
status=status, # type: ignore
184-
created_at=datetime.fromtimestamp(event.timestamp),
185-
call_id=adk_client_tool.id,
186-
)
168+
adk_client_tool = session.state.get(CHATKIT_CLIENT_TOOL_CALLS_KEY, {})
169+
if fn_response.id in adk_client_tool:
170+
client_tool_data = adk_client_tool[fn_response.id]
171+
an_item = ClientToolCallItem.model_validate(client_tool_data)
187172

188173
if an_item:
189174
thread_items.append(an_item)
190175

191176
return Page(data=thread_items)
192177

193178
async def add_thread_item(self, thread_id: str, item: ThreadItem, context: ADKContext) -> None:
194-
# items are added to the session by runner except for WidgetItem
195-
if not isinstance(item, WidgetItem):
179+
if not isinstance(item, (ClientToolCallItem, WidgetItem)):
196180
return
197181

198182
_LOGGER.info(f"Adding thread item to thread {thread_id} for user {context.user_id} in app {context.app_name}")
199183

184+
print(item)
185+
200186
# the widget item is added in a function call so it's ID has the function call id
201187
# we issue a system event to add the widget item in the State keeping the info about which function call added it
202188
# so that it is able to be retrieved later and sequenced
@@ -212,9 +198,14 @@ async def add_thread_item(self, thread_id: str, item: ThreadItem, context: ADKCo
212198
f"Session with id {thread_id} not found for user {context.user_id} in app {context.app_name}"
213199
)
214200

215-
state_delta = {
216-
CHATKIT_WIDGET_STATE_KEY: {item.id: serialize_widget_item(item)},
217-
}
201+
if isinstance(item, ClientToolCallItem):
202+
state_delta = {
203+
CHATKIT_CLIENT_TOOL_CALLS_KEY: {item.id: serialize_client_tool_call_item(item)},
204+
}
205+
elif isinstance(item, WidgetItem):
206+
state_delta = {
207+
CHATKIT_WIDGET_STATE_KEY: {item.id: serialize_widget_item(item)},
208+
}
218209

219210
actions_with_update = EventActions(state_delta=state_delta)
220211
system_event = Event(
@@ -265,38 +256,24 @@ async def save_item(self, thread_id: str, item: ThreadItem, context: ADKContext)
265256
f"Session with id {thread_id} not found for user {context.user_id} in app {context.app_name}"
266257
)
267258

268-
# we will only handle specify types of items here
269-
# as quite many are automatically handled by runner
270259
if isinstance(item, ClientToolCallItem):
271-
thread_metadata = add_client_tool_status(session.state, item.call_id, item.status)
272-
273260
state_delta = {
274-
CHATKIT_THREAD_METADTA_KEY: serialize_thread_metadata(thread_metadata),
261+
CHATKIT_CLIENT_TOOL_CALLS_KEY: {item.id: serialize_client_tool_call_item(item)},
275262
}
276-
277-
actions_with_update = EventActions(state_delta=state_delta)
278-
system_event = Event(
279-
invocation_id=uuid4().hex,
280-
author="system",
281-
actions=actions_with_update,
282-
timestamp=datetime.now().timestamp(),
283-
)
284-
await self._session_service.append_event(session, system_event)
285-
286263
elif isinstance(item, WidgetItem):
287-
# we should update the widget stored state
288264
state_delta = {
289265
CHATKIT_WIDGET_STATE_KEY: {item.id: serialize_widget_item(item)},
290266
}
291267

292-
actions_with_update = EventActions(state_delta=state_delta)
293-
system_event = Event(
294-
invocation_id=uuid4().hex,
295-
author="system",
296-
actions=actions_with_update,
297-
timestamp=datetime.now().timestamp(),
298-
)
299-
await self._session_service.append_event(session, system_event)
268+
actions_with_update = EventActions(state_delta=state_delta)
269+
system_event = Event(
270+
invocation_id=uuid4().hex,
271+
author="system",
272+
actions=actions_with_update,
273+
timestamp=datetime.now().timestamp(),
274+
)
275+
276+
await self._session_service.append_event(session, system_event)
300277

301278
async def load_item(self, thread_id: str, item_id: str, context: ADKContext) -> ThreadItem:
302279
_LOGGER.info(

adk-chatkit/src/adk_chatkit/_thread_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from ._constants import CHATKIT_THREAD_METADTA_KEY
88

9-
_CLIENT_TOOL_PREFIX = "client-tool"
10-
119

1210
def serialize_thread_metadata(thread: ThreadMetadata) -> dict[str, Any]:
1311
json_dump = thread.model_dump_json(exclude_none=True, exclude={"items"})
@@ -17,21 +15,3 @@ def serialize_thread_metadata(thread: ThreadMetadata) -> dict[str, Any]:
1715
def get_thread_metadata_from_state(state: State | dict[str, Any]) -> ThreadMetadata:
1816
thread_metadata_dict = state[CHATKIT_THREAD_METADTA_KEY]
1917
return ThreadMetadata.model_validate(thread_metadata_dict)
20-
21-
22-
def add_client_tool_status(
23-
state: State | dict[str, Any],
24-
client_tool_id: str,
25-
status: str,
26-
) -> ThreadMetadata:
27-
thread_metadata = get_thread_metadata_from_state(state)
28-
thread_metadata.metadata[f"{_CLIENT_TOOL_PREFIX}-{client_tool_id}"] = status
29-
return thread_metadata
30-
31-
32-
def get_client_tool_status(
33-
state: State | dict[str, Any],
34-
client_tool_id: str,
35-
) -> str | None:
36-
thread_metadata = get_thread_metadata_from_state(state)
37-
return thread_metadata.metadata.get(f"{_CLIENT_TOOL_PREFIX}-{client_tool_id}", None)

examples/backend/src/backend/_refreshed_session_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22

33
from google.adk.events import Event
4-
from google.adk.sessions import BaseSessionService, Session
4+
from google.adk.sessions import Session
55
from google.adk.sessions.database_session_service import DatabaseSessionService
66

77

0 commit comments

Comments
 (0)