|
14 | 14 |
|
15 | 15 | """Unit tests for BaseLlmFlow toolset integration.""" |
16 | 16 |
|
17 | | -from typing import Optional |
| 17 | +from unittest import mock |
18 | 18 | from unittest.mock import AsyncMock |
19 | 19 |
|
20 | | -from google.adk.agents.callback_context import CallbackContext |
21 | 20 | from google.adk.agents.llm_agent import Agent |
22 | 21 | from google.adk.events.event import Event |
23 | 22 | from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow |
| 23 | +from google.adk.models.google_llm import Gemini |
24 | 24 | from google.adk.models.llm_request import LlmRequest |
25 | 25 | from google.adk.models.llm_response import LlmResponse |
26 | 26 | from google.adk.plugins.base_plugin import BasePlugin |
@@ -95,7 +95,6 @@ async def close(self): |
95 | 95 | async def test_preprocess_handles_mixed_tools_and_toolsets(): |
96 | 96 | """Test that _preprocess_async properly handles both tools and toolsets.""" |
97 | 97 | from google.adk.tools.base_tool import BaseTool |
98 | | - from google.adk.tools.function_tool import FunctionTool |
99 | 98 |
|
100 | 99 | # Create a mock tool |
101 | 100 | class _MockTool(BaseTool): |
@@ -200,6 +199,46 @@ def _my_tool(sides: int) -> int: |
200 | 199 | assert {d.name for d in declarations} == {'_my_tool', 'google_search_agent'} |
201 | 200 |
|
202 | 201 |
|
| 202 | +@pytest.mark.asyncio |
| 203 | +async def test_preprocess_calls_convert_tool_union_to_tools(): |
| 204 | + """Test that _preprocess_async calls _convert_tool_union_to_tools.""" |
| 205 | + |
| 206 | + class _MockTool: |
| 207 | + process_llm_request = AsyncMock() |
| 208 | + |
| 209 | + mock_tool_instance = _MockTool() |
| 210 | + |
| 211 | + def _my_tool(sides: int) -> int: |
| 212 | + """A simple tool.""" |
| 213 | + return sides |
| 214 | + |
| 215 | + with mock.patch( |
| 216 | + 'google.adk.agents.llm_agent._convert_tool_union_to_tools', |
| 217 | + new_callable=AsyncMock, |
| 218 | + ) as mock_convert: |
| 219 | + mock_convert.return_value = [mock_tool_instance] |
| 220 | + |
| 221 | + model = Gemini(model='gemini-2') |
| 222 | + agent = Agent( |
| 223 | + name='test_agent', model=model, tools=[_my_tool, google_search] |
| 224 | + ) |
| 225 | + invocation_context = await testing_utils.create_invocation_context( |
| 226 | + agent=agent, user_content='test message' |
| 227 | + ) |
| 228 | + flow = BaseLlmFlowForTesting() |
| 229 | + llm_request = LlmRequest(model='gemini-2') |
| 230 | + |
| 231 | + async for _ in flow._preprocess_async(invocation_context, llm_request): |
| 232 | + pass |
| 233 | + |
| 234 | + mock_convert.assert_called_with( |
| 235 | + google_search, |
| 236 | + mock.ANY, # ReadonlyContext(invocation_context) |
| 237 | + model, |
| 238 | + True, # multiple_tools |
| 239 | + ) |
| 240 | + |
| 241 | + |
203 | 242 | # TODO(b/448114567): Remove the following |
204 | 243 | # test_handle_after_model_callback_grounding tests once the workaround |
205 | 244 | # is no longer needed. |
|
0 commit comments