Skip to content

Commit 35047b5

Browse files
authored
Merge pull request #29 from syedfakher27/adk-middleware
Adk middleware
2 parents 8d36475 + 1790c4b commit 35047b5

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

typescript-sdk/integrations/adk-middleware/src/adk_middleware/adk_agent.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,10 @@ async def _handle_tool_result_submission(
406406
"""
407407
thread_id = input.thread_id
408408

409-
# Extract tool results first
409+
# Extract tool results that is send by the frontend
410410
tool_results = await self._extract_tool_results(input)
411411

412+
# if the tool results are not send by the fronted then call the tool function
412413
if not tool_results:
413414
logger.error(f"Tool result submission without tool results for thread {thread_id}")
414415
yield RunErrorEvent(
@@ -714,16 +715,25 @@ async def _start_background_execution(
714715
# Create dynamic toolset if tools provided and prepare tool updates
715716
toolset = None
716717
if input.tools:
717-
toolset = ClientProxyToolset(
718-
ag_ui_tools=input.tools,
719-
event_queue=event_queue
720-
)
721718

722719
# Get existing tools from the agent
723720
existing_tools = []
724721
if hasattr(adk_agent, 'tools') and adk_agent.tools:
725722
existing_tools = list(adk_agent.tools) if isinstance(adk_agent.tools, (list, tuple)) else [adk_agent.tools]
726723

724+
# if same tool is defined in frontend and backend then agent will only use the backend tool
725+
input_tools = []
726+
for input_tool in input.tools:
727+
# Check if this input tool's name matches any existing tool
728+
if not any(hasattr(existing_tool, '__name__') and input_tool.name == existing_tool.__name__
729+
for existing_tool in existing_tools):
730+
input_tools.append(input_tool)
731+
732+
toolset = ClientProxyToolset(
733+
ag_ui_tools=input_tools,
734+
event_queue=event_queue
735+
)
736+
727737
# Combine existing tools with our proxy toolset
728738
combined_tools = existing_tools + [toolset]
729739
agent_updates['tools'] = combined_tools
@@ -859,15 +869,16 @@ async def _run_adk_in_background(
859869
logger.debug(f"Emitting event to queue: {type(ag_ui_event).__name__} (thread {input.thread_id}, queue size before: {event_queue.qsize()})")
860870
await event_queue.put(ag_ui_event)
861871
logger.debug(f"Event queued: {type(ag_ui_event).__name__} (thread {input.thread_id}, queue size after: {event_queue.qsize()})")
862-
else:
863-
final_state = await self._session_manager.get_session_state(input.thread_id,app_name,user_id)
864-
ag_ui_event = event_translator._create_state_snapshot_event(final_state)
865-
await event_queue.put(ag_ui_event)
872+
866873

867874
# Force close any streaming messages
868875
async for ag_ui_event in event_translator.force_close_streaming_message():
869876
await event_queue.put(ag_ui_event)
870-
877+
# moving states snapshot events after the text event clousure to avoid this error https://github.com/Contextable/ag-ui/issues/28
878+
final_state = await self._session_manager.get_session_state(input.thread_id,app_name,user_id)
879+
if final_state:
880+
ag_ui_event = event_translator._create_state_snapshot_event(final_state)
881+
await event_queue.put(ag_ui_event)
871882
# Signal completion - ADK execution is done
872883
logger.debug(f"Background task sending completion signal for thread {input.thread_id}")
873884
await event_queue.put(None)

typescript-sdk/integrations/adk-middleware/src/adk_middleware/event_translator.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
BaseEvent, EventType,
1212
TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent,
1313
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent,
14-
ToolCallChunkEvent,
14+
ToolCallChunkEvent,ToolCallResultEvent,
1515
StateSnapshotEvent, StateDeltaEvent,
1616
MessagesSnapshotEvent,
1717
CustomEvent,
1818
Message, AssistantMessage, UserMessage, ToolMessage
1919
)
20-
20+
import json
2121
from google.adk.events import Event as ADKEvent
2222

2323
import logging
@@ -87,13 +87,6 @@ async def translate(
8787

8888

8989

90-
# Handle function responses
91-
if hasattr(adk_event, 'get_function_responses'):
92-
function_responses = adk_event.get_function_responses()
93-
if function_responses:
94-
# Function responses are typically handled by the agent internally
95-
# We don't need to emit them as AG-UI events
96-
pass
9790

9891
# call _translate_function_calls function to yield Tool Events
9992
if hasattr(adk_event, 'get_function_calls'):
@@ -104,6 +97,15 @@ async def translate(
10497
async for event in self._translate_function_calls(function_calls):
10598
yield event
10699

100+
# Handle function responses and yield the tool response event
101+
# this is essential for scenerios when user has to render function response at frontend
102+
if hasattr(adk_event, 'get_function_responses'):
103+
function_responses = adk_event.get_function_responses()
104+
if function_responses:
105+
# Function responses should be emmitted to frontend so it can render the response as well
106+
async for event in self._translate_function_response(function_responses):
107+
yield event
108+
107109

108110
# Handle state changes
109111
if hasattr(adk_event, 'actions') and adk_event.actions and hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta:
@@ -281,6 +283,32 @@ async def _translate_function_calls(
281283
# Clean up tracking
282284
self._active_tool_calls.pop(tool_call_id, None)
283285

286+
287+
async def _translate_function_response(
288+
self,
289+
function_response: list[types.FunctionResponse],
290+
) -> AsyncGenerator[BaseEvent, None]:
291+
"""Translate function calls from ADK event to AG-UI tool call events.
292+
293+
Args:
294+
adk_event: The ADK event containing function calls
295+
function_response: List of function response from the event
296+
297+
Yields:
298+
Tool result events
299+
"""
300+
301+
for func_response in function_response:
302+
303+
tool_call_id = getattr(func_response, 'id', str(uuid.uuid4()))
304+
305+
yield ToolCallResultEvent(
306+
message_id=str(uuid.uuid4()),
307+
type=EventType.TOOL_CALL_RESULT,
308+
tool_call_id=tool_call_id,
309+
content=json.dumps(func_response.response)
310+
)
311+
284312
def _create_state_delta_event(
285313
self,
286314
state_delta: Dict[str, Any],

0 commit comments

Comments
 (0)