Skip to content

Commit 11b29bd

Browse files
committed
feat: major rewrite of widget management / introduce Widgets backend & frontend that shows interactive gallery of widgets
1 parent 24729b6 commit 11b29bd

File tree

35 files changed

+1006
-166
lines changed

35 files changed

+1006
-166
lines changed

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ It uses and extends `openai/chatkit-python` (https://github.com/openai/chatkit-p
88
- A function (`stream_agent_response`) that translate ADK events into chatkit events
99
- Provides support to render `widgets`
1010
* See examples/backend/src/backend/agents/facts/_tools.py::get_weather
11-
* Use `add_widget_to_tool_response` in your tool and `widget` will be sent to client
1211
- Provides support for making calls to client tools.
1312
* Client tools typically run in browser
1413
* See examples/backend/src/backend/agents/facts/_tools.py::switch_theme
@@ -51,7 +50,7 @@ See `examples` for full usage
5150

5251
```python
5352

54-
from adk_chatkit import ADKContext, ADKStore, stream_agent_response
53+
from adk_chatkit import ADKAgentContext, ADKContext, ADKStore, ChatkitRunConfig, stream_agent_response
5554

5655
class FactsChatkitServer(ChatKitServer[ADKContext]):
5756
def __init__(
@@ -80,19 +79,20 @@ class FactsChatkitServer(ChatKitServer[ADKContext]):
8079
if not message_text:
8180
return
8281

83-
content = genai_types.Content(
84-
role="user",
85-
parts=[genai_types.Part.from_text(text=message_text)],
82+
agent_context = ADKAgentContext(
83+
app_name=context.app_name,
84+
user_id=context.user_id,
85+
thread=thread,
8686
)
8787

8888
event_stream = self._runner.run_async(
89-
user_id=context["user_id"],
89+
user_id=context.user_id,
9090
session_id=thread.id,
9191
new_message=content,
92-
run_config=RunConfig(streaming_mode=StreamingMode.SSE),
92+
run_config=ChatkitRunConfig(streaming_mode=StreamingMode.SSE, context=agent_context),
9393
)
9494

95-
async for event in stream_agent_response(thread, event_stream):
95+
async for event in stream_agent_response(agent_context, event_stream):
9696
yield event
9797

9898
```
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from .__about__ import __application__, __author__, __version__
2-
from ._callbacks import remove_widgets_and_client_tool_calls
2+
from ._callbacks import remove_client_tool_calls
33
from ._client_tool_call import ClientToolCallState, add_client_tool_call_to_tool_response
4-
from ._context import ADKContext
4+
from ._context import ADKAgentContext, ADKContext, ChatkitRunConfig
55
from ._response import stream_agent_response
66
from ._store import ADKStore
7-
from ._widgets import add_widget_to_tool_response
7+
from ._widgets import serialize_widget_item
88

99
__all__ = [
1010
"__version__",
1111
"__application__",
1212
"__author__",
1313
"ADKContext",
14+
"ADKAgentContext",
1415
"ADKStore",
1516
"stream_agent_response",
1617
"ClientToolCallState",
1718
"add_client_tool_call_to_tool_response",
18-
"remove_widgets_and_client_tool_calls",
19-
"add_widget_to_tool_response",
19+
"remove_client_tool_calls",
20+
"ChatkitRunConfig",
21+
"serialize_widget_item",
2022
]

adk-chatkit/src/adk_chatkit/_callbacks.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,17 @@
22
from google.adk.models.llm_request import LlmRequest
33
from google.adk.models.llm_response import LlmResponse
44

5-
from ._constants import CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, WIDGET_KEY_IN_TOOL_RESPONSE
5+
from ._constants import CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
66

77

8-
def remove_widgets_and_client_tool_calls(
9-
callback_context: CallbackContext, llm_request: LlmRequest
10-
) -> LlmResponse | None:
8+
def remove_client_tool_calls(callback_context: CallbackContext, llm_request: LlmRequest) -> LlmResponse | None:
119
for c in llm_request.contents:
1210
if c.parts is None:
1311
continue
1412
for p in c.parts:
1513
if not p.function_response:
1614
continue
1715
if p.function_response.response:
18-
p.function_response.response.pop(WIDGET_KEY_IN_TOOL_RESPONSE, None)
1916
p.function_response.response.pop(CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, None)
2017

2118
return None
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Final
22

3-
WIDGET_KEY_IN_TOOL_RESPONSE: Final[str] = "adk-widget"
43
CLIENT_TOOL_KEY_IN_TOOL_RESPONSE: Final[str] = "adk-client-tool"
54
CHATKIT_THREAD_METADTA_KEY: Final[str] = "adk-chatkit-thread-metadata"
5+
CHATKIT_WIDGET_STATE_KEY: Final[str] = "adk-chatkit-widget-state"
Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
1-
from typing import TypedDict
1+
import asyncio
22

3+
from chatkit.types import ThreadMetadata, ThreadStreamEvent
4+
from google.adk.runners import RunConfig
5+
from pydantic import BaseModel
36

4-
class ADKContext(TypedDict):
7+
from ._event_utils import QueueCompleteSentinel
8+
9+
10+
class ADKContext(BaseModel):
511
app_name: str
612
user_id: str
13+
14+
15+
class ADKAgentContext(ADKContext):
16+
thread: ThreadMetadata
17+
18+
_events: asyncio.Queue[ThreadStreamEvent | QueueCompleteSentinel] = asyncio.Queue()
19+
20+
async def stream(self, event: ThreadStreamEvent) -> None:
21+
await self._events.put(event)
22+
23+
def _complete(self) -> None:
24+
self._events.put_nowait(QueueCompleteSentinel())
25+
26+
27+
class ChatkitRunConfig(RunConfig):
28+
context: ADKAgentContext
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import asyncio
2+
from collections.abc import AsyncIterator
3+
from typing import TypeVar
4+
5+
from chatkit.types import ThreadStreamEvent
6+
7+
T1 = TypeVar("T1")
8+
T2 = TypeVar("T2")
9+
10+
11+
class QueueCompleteSentinel: ...
12+
13+
14+
async def merge_generators(
15+
a: AsyncIterator[T1],
16+
b: AsyncIterator[T2],
17+
) -> AsyncIterator[T1 | T2]:
18+
pending: list[AsyncIterator[T1 | T2]] = [a, b]
19+
pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
20+
asyncio.ensure_future(g.__anext__()): g for g in pending
21+
}
22+
while len(pending_tasks) > 0:
23+
done, _ = await asyncio.wait(pending_tasks.keys(), return_when="FIRST_COMPLETED")
24+
stop = False
25+
for d in done:
26+
try:
27+
result = d.result()
28+
yield result
29+
dg = pending_tasks[d]
30+
pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
31+
except StopAsyncIteration:
32+
stop = True
33+
finally:
34+
del pending_tasks[d]
35+
if stop:
36+
for task in pending_tasks.keys():
37+
if not task.cancel():
38+
try:
39+
yield task.result()
40+
except asyncio.CancelledError:
41+
pass
42+
except asyncio.InvalidStateError:
43+
pass
44+
break
45+
46+
47+
class EventWrapper:
48+
def __init__(self, event: ThreadStreamEvent):
49+
self.event = event
50+
51+
52+
class AsyncQueueIterator(AsyncIterator[EventWrapper]):
53+
def __init__(self, queue: asyncio.Queue[ThreadStreamEvent | QueueCompleteSentinel]):
54+
self.queue = queue
55+
self.completed = False
56+
57+
def __aiter__(self) -> AsyncIterator[EventWrapper]:
58+
return self
59+
60+
async def __anext__(self) -> EventWrapper:
61+
if self.completed:
62+
raise StopAsyncIteration
63+
64+
item = await self.queue.get()
65+
if isinstance(item, QueueCompleteSentinel):
66+
self.completed = True
67+
raise StopAsyncIteration
68+
return EventWrapper(item)
69+
70+
def drain_and_complete(self) -> None:
71+
"""Empty the underlying queue without awaiting and mark this iterator completed.
72+
73+
This is intended for cleanup paths where we must guarantee no awaits
74+
occur. All queued items, including any completion sentinel, are
75+
discarded.
76+
"""
77+
while True:
78+
try:
79+
self.queue.get_nowait()
80+
except asyncio.QueueEmpty:
81+
break
82+
self.completed = True

adk-chatkit/src/adk_chatkit/_response.py

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,59 @@
1414
ThreadItemUpdated,
1515
ThreadMetadata,
1616
ThreadStreamEvent,
17-
WidgetItem,
1817
)
1918
from google.adk.events import Event
2019

2120
from ._client_tool_call import ClientToolCallState
22-
from ._constants import CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, WIDGET_KEY_IN_TOOL_RESPONSE
21+
from ._constants import CLIENT_TOOL_KEY_IN_TOOL_RESPONSE
22+
from ._context import ADKAgentContext
23+
from ._event_utils import AsyncQueueIterator, EventWrapper, merge_generators
2324

2425

25-
async def stream_agent_response(
26+
async def _handle_function_response(
27+
event: Event,
2628
thread: ThreadMetadata,
29+
) -> AsyncGenerator[ThreadItemDoneEvent, None]:
30+
if fn_responses := event.get_function_responses():
31+
for fn_response in fn_responses:
32+
if not fn_response.response:
33+
continue
34+
35+
adk_client_tool: ClientToolCallState | None = fn_response.response.get(
36+
CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, None
37+
)
38+
if adk_client_tool:
39+
yield ThreadItemDoneEvent(
40+
item=ClientToolCallItem(
41+
id=event.id,
42+
thread_id=thread.id,
43+
name=adk_client_tool.name,
44+
arguments=adk_client_tool.arguments,
45+
status=adk_client_tool.status,
46+
created_at=datetime.fromtimestamp(event.timestamp),
47+
call_id=adk_client_tool.id,
48+
),
49+
)
50+
51+
52+
async def stream_agent_response(
53+
context: ADKAgentContext,
2754
adk_response: AsyncGenerator[Event, None],
2855
) -> AsyncIterator[ThreadStreamEvent]:
29-
if adk_response is None:
30-
return
31-
56+
queue_iterator = AsyncQueueIterator(context._events)
3257
response_id = str(uuid.uuid4())
3358

59+
thread = context.thread
60+
3461
content_index = 0
35-
async for event in adk_response:
62+
async for event in merge_generators(adk_response, queue_iterator):
63+
if event is None:
64+
continue
65+
66+
if isinstance(event, EventWrapper):
67+
yield event.event
68+
continue
69+
3670
if event.content is None:
3771
# we need to throw item added event first
3872
yield ThreadItemAddedEvent(
@@ -53,38 +87,8 @@ async def stream_agent_response(
5387
),
5488
)
5589
else:
56-
# Since Widgets are recorded in the function responses
57-
# they are handled here
58-
if fn_responses := event.get_function_responses():
59-
for fn_response in fn_responses:
60-
if not fn_response.response:
61-
continue
62-
widget = fn_response.response.get(WIDGET_KEY_IN_TOOL_RESPONSE, None)
63-
if widget:
64-
# No Streaming for Widgets for now
65-
yield ThreadItemDoneEvent(
66-
item=WidgetItem(
67-
id=str(uuid.uuid4()),
68-
thread_id=thread.id,
69-
created_at=datetime.fromtimestamp(event.timestamp),
70-
widget=widget,
71-
)
72-
)
73-
adk_client_tool: ClientToolCallState | None = fn_response.response.get(
74-
CLIENT_TOOL_KEY_IN_TOOL_RESPONSE, None
75-
)
76-
if adk_client_tool:
77-
yield ThreadItemDoneEvent(
78-
item=ClientToolCallItem(
79-
id=event.id,
80-
thread_id=thread.id,
81-
name=adk_client_tool.name,
82-
arguments=adk_client_tool.arguments,
83-
status=adk_client_tool.status,
84-
created_at=datetime.fromtimestamp(event.timestamp),
85-
call_id=adk_client_tool.id,
86-
),
87-
)
90+
async for item in _handle_function_response(event, thread):
91+
yield item
8892

8993
if event.content.parts:
9094
text_from_final_update = ""
@@ -116,3 +120,9 @@ async def stream_agent_response(
116120
created_at=datetime.fromtimestamp(event.timestamp),
117121
)
118122
)
123+
124+
context._complete()
125+
126+
# Drain remaining events
127+
async for event in queue_iterator:
128+
yield event.event

0 commit comments

Comments
 (0)