Skip to content

Commit d3148da

Browse files
xuanyang15copybara-github
authored andcommitted
ADK changes
PiperOrigin-RevId: 814319961
1 parent 2e2d61b commit d3148da

File tree

6 files changed

+622
-6
lines changed

6 files changed

+622
-6
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,22 @@
112112

113113

114114
async def _convert_tool_union_to_tools(
115-
tool_union: ToolUnion, ctx: ReadonlyContext
115+
tool_union: ToolUnion,
116+
ctx: ReadonlyContext,
117+
model: Union[str, BaseLlm],
118+
multiple_tools: bool = False,
116119
) -> list[BaseTool]:
120+
from ..tools.google_search_tool import google_search
121+
122+
# Wrap google_search tool with AgentTool if there are multiple tools because
123+
# the built-in tools cannot be used together with other tools.
124+
# TODO(b/448114567): Remove once the workaround is no longer needed.
125+
if multiple_tools and tool_union is google_search:
126+
from ..tools.google_search_agent_tool import create_google_search_agent
127+
from ..tools.google_search_agent_tool import GoogleSearchAgentTool
128+
129+
return [GoogleSearchAgentTool(create_google_search_agent(model))]
130+
117131
if isinstance(tool_union, BaseTool):
118132
return [tool_union]
119133
if callable(tool_union):
@@ -462,8 +476,16 @@ async def canonical_tools(
462476
This method is only for use by Agent Development Kit.
463477
"""
464478
resolved_tools = []
479+
# We may need to wrap some built-in tools if there are other tools
480+
# because the built-in tools cannot be used together with other tools.
481+
# TODO(b/448114567): Remove once the workaround is no longer needed.
482+
multiple_tools = len(self.tools) > 1
465483
for tool_union in self.tools:
466-
resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx))
484+
resolved_tools.extend(
485+
await _convert_tool_union_to_tools(
486+
tool_union, ctx, self.model, multiple_tools
487+
)
488+
)
467489
return resolved_tools
468490

469491
@property

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ...telemetry.tracing import trace_send_data
4646
from ...telemetry.tracing import tracer
4747
from ...tools.base_toolset import BaseToolset
48+
from ...tools.google_search_tool import google_search
4849
from ...tools.tool_context import ToolContext
4950
from ...utils.context_utils import Aclosing
5051
from .audio_cache_manager import AudioCacheManager
@@ -442,6 +443,11 @@ async def _preprocess_async(
442443
yield event
443444

444445
# Run processors for tools.
446+
447+
# We may need to wrap some built-in tools if there are other tools
448+
# because the built-in tools cannot be used together with other tools.
449+
# TODO(b/448114567): Remove once the workaround is no longer needed.
450+
multiple_tools = len(agent.tools) > 1
445451
for tool_union in agent.tools:
446452
tool_context = ToolContext(invocation_context)
447453

@@ -455,7 +461,10 @@ async def _preprocess_async(
455461

456462
# Then process all tools from this tool union
457463
tools = await _convert_tool_union_to_tools(
458-
tool_union, ReadonlyContext(invocation_context)
464+
tool_union,
465+
ReadonlyContext(invocation_context),
466+
llm_request.model,
467+
multiple_tools,
459468
)
460469
for tool in tools:
461470
await tool.process_llm_request(
@@ -818,6 +827,26 @@ async def _handle_after_model_callback(
818827

819828
agent = invocation_context.agent
820829

830+
# Add grounding metadata to the response if needed.
831+
# TODO(b/448114567): Remove this function once the workaround is no longer needed.
832+
async def _maybe_add_grounding_metadata(
833+
response: Optional[LlmResponse] = None,
834+
) -> Optional[LlmResponse]:
835+
readonly_context = ReadonlyContext(invocation_context)
836+
tools = await agent.canonical_tools(readonly_context)
837+
if not any(tool.name == 'google_search_agent' for tool in tools):
838+
return response
839+
ground_metadata = invocation_context.session.state.get(
840+
'temp:_adk_grounding_metadata', None
841+
)
842+
if not ground_metadata:
843+
return response
844+
845+
if not response:
846+
response = llm_response
847+
response.grounding_metadata = ground_metadata
848+
return response
849+
821850
callback_context = CallbackContext(
822851
invocation_context, event_actions=model_response_event.actions
823852
)
@@ -830,20 +859,21 @@ async def _handle_after_model_callback(
830859
)
831860
)
832861
if callback_response:
833-
return callback_response
862+
return await _maybe_add_grounding_metadata(callback_response)
834863

835864
# If no overrides are provided from the plugins, further run the canonical
836865
# callbacks.
837866
if not agent.canonical_after_model_callbacks:
838-
return
867+
return await _maybe_add_grounding_metadata()
839868
for callback in agent.canonical_after_model_callbacks:
840869
callback_response = callback(
841870
callback_context=callback_context, llm_response=llm_response
842871
)
843872
if inspect.isawaitable(callback_response):
844873
callback_response = await callback_response
845874
if callback_response:
846-
return callback_response
875+
return await _maybe_add_grounding_metadata(callback_response)
876+
return await _maybe_add_grounding_metadata()
847877

848878
def _finalize_model_response_event(
849879
self,
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
from typing import Union
19+
20+
from google.genai import types
21+
from typing_extensions import override
22+
23+
from ..agents.llm_agent import LlmAgent
24+
from ..memory.in_memory_memory_service import InMemoryMemoryService
25+
from ..models.base_llm import BaseLlm
26+
from ..utils.context_utils import Aclosing
27+
from ._forwarding_artifact_service import ForwardingArtifactService
28+
from .agent_tool import AgentTool
29+
from .google_search_tool import google_search
30+
from .tool_context import ToolContext
31+
32+
33+
def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
34+
"""Create a sub-agent that only uses google_search tool."""
35+
return LlmAgent(
36+
name='google_search_agent',
37+
model=model,
38+
description=(
39+
'An agent for performing Google search using the `google_search` tool'
40+
),
41+
instruction="""
42+
You are a specialized Google search agent.
43+
44+
When given a search query, use the `google_search` tool to find the related information.
45+
""",
46+
tools=[google_search],
47+
)
48+
49+
50+
class GoogleSearchAgentTool(AgentTool):
51+
"""A tool that wraps a sub-agent that only uses google_search tool.
52+
53+
This is a workaround to support using google_search tool with other tools.
54+
TODO(b/448114567): Remove once the workaround is no longer needed.
55+
56+
Attributes:
57+
model: The model to use for the sub-agent.
58+
"""
59+
60+
def __init__(self, agent: LlmAgent):
61+
self.agent = agent
62+
super().__init__(agent=self.agent)
63+
64+
@override
65+
async def run_async(
66+
self,
67+
*,
68+
args: dict[str, Any],
69+
tool_context: ToolContext,
70+
) -> Any:
71+
from ..agents.llm_agent import LlmAgent
72+
from ..runners import Runner
73+
from ..sessions.in_memory_session_service import InMemorySessionService
74+
75+
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
76+
input_value = self.agent.input_schema.model_validate(args)
77+
content = types.Content(
78+
role='user',
79+
parts=[
80+
types.Part.from_text(
81+
text=input_value.model_dump_json(exclude_none=True)
82+
)
83+
],
84+
)
85+
else:
86+
content = types.Content(
87+
role='user',
88+
parts=[types.Part.from_text(text=args['request'])],
89+
)
90+
runner = Runner(
91+
app_name=self.agent.name,
92+
agent=self.agent,
93+
artifact_service=ForwardingArtifactService(tool_context),
94+
session_service=InMemorySessionService(),
95+
memory_service=InMemoryMemoryService(),
96+
credential_service=tool_context._invocation_context.credential_service,
97+
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
98+
)
99+
100+
state_dict = {
101+
k: v
102+
for k, v in tool_context.state.to_dict().items()
103+
if not k.startswith('_adk') # Filter out adk internal states
104+
}
105+
session = await runner.session_service.create_session(
106+
app_name=self.agent.name,
107+
user_id=tool_context._invocation_context.user_id,
108+
state=state_dict,
109+
)
110+
111+
last_content = None
112+
last_grounding_metadata = None
113+
async with Aclosing(
114+
runner.run_async(
115+
user_id=session.user_id, session_id=session.id, new_message=content
116+
)
117+
) as agen:
118+
async for event in agen:
119+
# Forward state delta to parent session.
120+
if event.actions.state_delta:
121+
tool_context.state.update(event.actions.state_delta)
122+
if event.content:
123+
last_content = event.content
124+
last_grounding_metadata = event.grounding_metadata
125+
126+
if not last_content:
127+
return ''
128+
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
129+
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
130+
tool_result = self.agent.output_schema.model_validate_json(
131+
merged_text
132+
).model_dump(exclude_none=True)
133+
else:
134+
tool_result = merged_text
135+
136+
if last_grounding_metadata:
137+
tool_context.state['temp:_adk_grounding_metadata'] = (
138+
last_grounding_metadata
139+
)
140+
return tool_result

tests/unittests/agents/test_llm_agent_fields.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.adk.models.llm_request import LlmRequest
2727
from google.adk.models.registry import LLMRegistry
2828
from google.adk.sessions.in_memory_session_service import InMemorySessionService
29+
from google.adk.tools.google_search_tool import google_search
2930
from google.genai import types
3031
from pydantic import BaseModel
3132
import pytest
@@ -279,3 +280,63 @@ def test_allow_transfer_by_default():
279280

280281
assert not agent.disallow_transfer_to_parent
281282
assert not agent.disallow_transfer_to_peers
283+
284+
285+
# TODO(b/448114567): Remove TestCanonicalTools once the workaround
286+
# is no longer needed.
287+
class TestCanonicalTools:
288+
"""Unit tests for canonical_tools in LlmAgent."""
289+
290+
@staticmethod
291+
def _my_tool(sides: int) -> int:
292+
return sides
293+
294+
async def test_handle_google_search_with_other_tools(self):
295+
"""Test that google_search is wrapped into an agent."""
296+
agent = LlmAgent(
297+
name='test_agent',
298+
model='gemini-pro',
299+
tools=[
300+
self._my_tool,
301+
google_search,
302+
],
303+
)
304+
ctx = await _create_readonly_context(agent)
305+
tools = await agent.canonical_tools(ctx)
306+
307+
assert len(tools) == 2
308+
assert tools[0].name == '_my_tool'
309+
assert tools[1].name == 'google_search_agent'
310+
assert tools[1].__class__.__name__ == 'GoogleSearchAgentTool'
311+
312+
async def test_handle_google_search_only(self):
313+
"""Test that google_search is not wrapped into an agent."""
314+
agent = LlmAgent(
315+
name='test_agent',
316+
model='gemini-pro',
317+
tools=[
318+
google_search,
319+
],
320+
)
321+
ctx = await _create_readonly_context(agent)
322+
tools = await agent.canonical_tools(ctx)
323+
324+
assert len(tools) == 1
325+
assert tools[0].name == 'google_search'
326+
assert tools[0].__class__.__name__ == 'GoogleSearchTool'
327+
328+
async def test_no_google_search(self):
329+
"""Test other tools are not affected."""
330+
agent = LlmAgent(
331+
name='test_agent',
332+
model='gemini-pro',
333+
tools=[
334+
self._my_tool,
335+
],
336+
)
337+
ctx = await _create_readonly_context(agent)
338+
tools = await agent.canonical_tools(ctx)
339+
340+
assert len(tools) == 1
341+
assert tools[0].name == '_my_tool'
342+
assert tools[0].__class__.__name__ == 'FunctionTool'

0 commit comments

Comments
 (0)