Skip to content

Commit f51380f

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Extend ReflectAndRetryToolPlugin to support hallucinating function calls
PiperOrigin-RevId: 820051762
1 parent 3734cea commit f51380f

File tree

6 files changed

+283
-28
lines changed

6 files changed

+283
-28
lines changed

contributing/samples/plugin_reflect_tool_retry/README.md

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,30 @@ You can run the agent with:
4646
$ adk web contributing/samples/plugin_reflect_tool_retry
4747
```
4848

49-
You can provide the following prompt to see the agent retrying tool calls:
49+
Select "basic" and provide the following prompt to see the agent retrying tool
50+
calls:
5051

5152
```
5253
Please guess a number! Tell me what number you guess and how is it.
5354
```
55+
56+
### Hallucinating tool calls
57+
58+
The "hallucinating_func_name" agent demonstrates that the plugin can recover
59+
from hallucinated tool calls, i.e. calls to a tool name that does not exist.
60+
61+
For example, this sample uses an `after_model_callback` to rewrite a tool call
62+
to use a wrong name, so that the agent has to retry with the right tool name.
63+
64+
You can run the agent with:
65+
66+
```bash
67+
$ adk web contributing/samples/plugin_reflect_tool_retry
68+
```
69+
70+
Select "hallucinating_func_name" and provide the following prompt to see the
71+
agent retrying tool calls:
72+
73+
```
74+
Roll a 6 sided die
75+
```
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
17+
from google.adk.agents import LlmAgent
18+
from google.adk.agents.callback_context import CallbackContext
19+
from google.adk.apps.app import App
20+
from google.adk.models.llm_response import LlmResponse
21+
from google.adk.plugins import ReflectAndRetryToolPlugin
22+
from google.adk.tools.tool_context import ToolContext
23+
24+
# Name passed to the App(...) constructed at the bottom of this module.
APP_NAME = "hallucinating_func_name"
USER_ID = "test_user"

# One-shot flag: after_model_callback sets this to True once it has injected
# its single wrong tool name, and leaves later model responses untouched.
hallucinated = False
28+
29+
30+
def roll_die(sides: int, tool_context: ToolContext) -> int:
  """Roll a die and return the rolled result.

  Args:
    sides: The integer number of sides the die has.

  Returns:
    An integer of the result of rolling the die.
  """
  # NOTE(review): the docstring above is presumably surfaced to the model as
  # the tool description, so its wording is intentionally left unchanged.
  result = random.randint(1, sides)

  # Lazily create the per-session roll history the first time the tool runs.
  if "rolls" not in tool_context.state:
    tool_context.state["rolls"] = []

  # Assign a new list instead of appending in place; presumably the state
  # object only records changes made through item assignment -- confirm.
  tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
  return result
45+
46+
47+
def after_model_callback(
    callback_context: CallbackContext, llm_response: LlmResponse
):
  """After model callback to produce one hallucinating tool call.

  The first `roll_die` function call found in a model response is renamed in
  place to `roll_die_wrong_name`, which is not a registered tool. This happens
  at most once per process (tracked by the module-level `hallucinated` flag),
  so the follow-up call the model makes can go through unmodified.

  Responses without content, without parts, or whose parts carry no function
  call (for example plain text) are left untouched instead of raising.

  Args:
    callback_context: The callback context; not used by this callback.
    llm_response: The model response, possibly mutated in place.

  Returns:
    Always None; any change is made by mutating `llm_response` directly.
  """
  global hallucinated

  if hallucinated:
    return None

  content = llm_response.content
  if not content or not content.parts:
    return None

  # Scan every part rather than only parts[0]: the function call is not
  # guaranteed to be the first part, and non-call parts have no function_call.
  for part in content.parts:
    function_call = part.function_call
    if function_call and function_call.name == "roll_die":
      function_call.name = "roll_die_wrong_name"
      hallucinated = True
      break
  return None
63+
64+
65+
# Agent whose only tool is roll_die. after_model_callback rewrites one of its
# tool calls to a wrong name so the retry plugin below has something to fix.
root_agent = LlmAgent(
    name="hello_world",
    description="Helpful agent",
    # The instruction must name the tool that is actually registered; the
    # previous text referred to guess_number_tool, which this agent lacks.
    instruction="""Use the roll_die tool to roll a die.""",
    model="gemini-2.5-flash",
    tools=[roll_die],
    after_model_callback=after_model_callback,
)


# App wiring: the plugin is configured with max_retries=3.
app = App(
    name=APP_NAME,
    root_agent=root_agent,
    plugins=[
        ReflectAndRetryToolPlugin(max_retries=3),
    ],
)

src/google/adk/flows/llm_flows/functions.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -275,21 +275,37 @@ async def _execute_single_function_call_async(
275275
tool_confirmation: Optional[ToolConfirmation] = None,
276276
) -> Optional[Event]:
277277
"""Execute a single function call with thread safety for state modifications."""
278-
tool, tool_context = _get_tool_and_context(
279-
invocation_context,
280-
function_call,
281-
tools_dict,
282-
tool_confirmation,
278+
# Do not use "args" as the variable name, because it is a reserved keyword
279+
# in python debugger.
280+
# Make a deep copy to avoid being modified.
281+
function_args = (
282+
copy.deepcopy(function_call.args) if function_call.args else {}
283283
)
284284

285-
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
286-
# Do not use "args" as the variable name, because it is a reserved keyword
287-
# in python debugger.
288-
# Make a deep copy to avoid being modified.
289-
function_args = (
290-
copy.deepcopy(function_call.args) if function_call.args else {}
285+
tool_context = _create_tool_context(
286+
invocation_context, function_call, tool_confirmation
287+
)
288+
289+
try:
290+
tool = _get_tool(function_call, tools_dict)
291+
except ValueError as tool_error:
292+
tool = BaseTool(name=function_call.name, description='Tool not found')
293+
error_response = (
294+
await invocation_context.plugin_manager.run_on_tool_error_callback(
295+
tool=tool,
296+
tool_args=function_args,
297+
tool_context=tool_context,
298+
error=tool_error,
299+
)
291300
)
301+
if error_response is not None:
302+
return __build_response_event(
303+
tool, error_response, tool_context, invocation_context
304+
)
305+
else:
306+
raise tool_error
292307

308+
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
293309
# Step 1: Check if plugin before_tool_callback overrides the function
294310
# response.
295311
function_response = (
@@ -639,24 +655,45 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
639655
return function_response
640656

641657

642-
def _get_tool_and_context(
643-
invocation_context: InvocationContext,
644-
function_call: types.FunctionCall,
645-
tools_dict: dict[str, BaseTool],
646-
tool_confirmation: Optional[ToolConfirmation] = None,
658+
def _get_tool(
659+
function_call: types.FunctionCall, tools_dict: dict[str, BaseTool]
647660
):
661+
"""Returns the tool corresponding to the function call."""
648662
if function_call.name not in tools_dict:
649663
raise ValueError(
650-
f'Function {function_call.name} is not found in the tools_dict.'
664+
f'Function {function_call.name} is not found in the tools_dict:'
665+
f' {tools_dict.keys()}.'
651666
)
652667

653-
tool_context = ToolContext(
668+
return tools_dict[function_call.name]
669+
670+
671+
def _create_tool_context(
    invocation_context: InvocationContext,
    function_call: types.FunctionCall,
    tool_confirmation: Optional[ToolConfirmation] = None,
):
  """Builds the ToolContext for one function call.

  Args:
    invocation_context: The invocation context, passed through to the
      ToolContext.
    function_call: The function call; its `id` becomes `function_call_id`.
    tool_confirmation: Optional confirmation passed through to the
      ToolContext.

  Returns:
    The newly constructed ToolContext.
  """
  tool_context = ToolContext(
      invocation_context=invocation_context,
      function_call_id=function_call.id,
      tool_confirmation=tool_confirmation,
  )
  return tool_context
658682

659-
tool = tools_dict[function_call.name]
683+
684+
def _get_tool_and_context(
    invocation_context: InvocationContext,
    function_call: types.FunctionCall,
    tools_dict: dict[str, BaseTool],
    tool_confirmation: Optional[ToolConfirmation] = None,
):
  """Resolves a function call to its tool and a fresh tool context.

  The tool lookup runs first, so an unknown function name raises before any
  ToolContext is created.

  Args:
    invocation_context: The invocation context for the new ToolContext.
    function_call: The function call to resolve.
    tools_dict: Mapping from tool name to tool.
    tool_confirmation: Optional confirmation for the new ToolContext.

  Returns:
    A `(tool, tool_context)` tuple.

  Raises:
    ValueError: If the function name is not found in `tools_dict`.
  """
  # Tuple elements are evaluated left to right, so the lookup still runs first.
  return (
      _get_tool(function_call, tools_dict),
      _create_tool_context(
          invocation_context, function_call, tool_confirmation
      ),
  )
662699

src/google/adk/plugins/reflect_retry_tool_plugin.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,20 @@ async def after_tool_callback(
142142
tool_args: dict[str, Any],
143143
tool_context: ToolContext,
144144
result: Any,
145-
) -> Optional[dict]:
146-
"""Handles successful tool calls or extracts and processes errors."""
145+
) -> Optional[dict[str, Any]]:
146+
"""Handles successful tool calls or extracts and processes errors.
147+
148+
Args:
149+
tool: The tool that was called.
150+
tool_args: The arguments passed to the tool.
151+
tool_context: The context of the tool call.
152+
result: The result of the tool call.
153+
154+
Returns:
155+
An optional dictionary containing reflection guidance if an error is
156+
detected, or None if the tool call was successful or the
157+
response is already a reflection message.
158+
"""
147159
if (
148160
isinstance(result, dict)
149161
and result.get("response_type") == REFLECT_AND_RETRY_RESPONSE_TYPE
@@ -157,7 +169,8 @@ async def after_tool_callback(
157169
if error:
158170
return await self._handle_tool_error(tool, tool_args, tool_context, error)
159171

160-
# On success, reset the failure count for this specific tool within its scope.
172+
# On success, reset the failure count for this specific tool within
173+
# its scope.
161174
await self._reset_failures_for_tool(tool_context, tool.name)
162175
return None
163176

@@ -168,14 +181,23 @@ async def extract_error_from_result(
168181
tool_args: dict[str, Any],
169182
tool_context: ToolContext,
170183
result: Any,
171-
) -> Optional[Any]:
184+
) -> Optional[dict[str, Any]]:
172185
"""Extracts an error from a successful tool result and triggers retry logic.
173186
174187
This is useful when tool call finishes successfully but the result contains
175188
an error object like {"error": ...} that should be handled by the plugin.
176189
177190
By overriding this method, you can trigger retry logic on these successful
178191
results that contain errors.
192+
193+
Args:
194+
tool: The tool that was called.
195+
tool_args: The arguments passed to the tool.
196+
tool_context: The context of the tool call.
197+
result: The result of the tool call.
198+
199+
Returns:
200+
The extracted error if any, or None if no error was detected.
179201
"""
180202
return None
181203

@@ -186,8 +208,18 @@ async def on_tool_error_callback(
186208
tool_args: dict[str, Any],
187209
tool_context: ToolContext,
188210
error: Exception,
189-
) -> Optional[dict]:
190-
"""Handles tool exceptions by providing reflection guidance."""
211+
) -> Optional[dict[str, Any]]:
212+
"""Handles tool exceptions by providing reflection guidance.
213+
214+
Args:
215+
tool: The tool that was called.
216+
tool_args: The arguments passed to the tool.
217+
tool_context: The context of the tool call.
218+
error: The exception raised by the tool.
219+
220+
Returns:
221+
An optional dictionary containing reflection guidance for the error.
222+
"""
191223
return await self._handle_tool_error(tool, tool_args, tool_context, error)
192224

193225
async def _handle_tool_error(
@@ -196,8 +228,18 @@ async def _handle_tool_error(
196228
tool_args: dict[str, Any],
197229
tool_context: ToolContext,
198230
error: Any,
199-
) -> Optional[dict]:
200-
"""Central, thread-safe logic for processing tool errors."""
231+
) -> Optional[dict[str, Any]]:
232+
"""Central, thread-safe logic for processing tool errors.
233+
234+
Args:
235+
tool: The tool that was called.
236+
tool_args: The arguments passed to the tool.
237+
tool_context: The context of the tool call.
238+
error: The error to be handled.
239+
240+
Returns:
241+
An optional dictionary containing reflection guidance for the error.
242+
"""
201243
if self.max_retries == 0:
202244
if self.throw_exception_if_retry_exceeded:
203245
raise error
@@ -285,6 +327,7 @@ def _create_tool_reflection_response(
285327
2. **State or Preconditions**: Did a previous step fail or not produce the necessary state/resource for this tool to succeed?
286328
3. **Alternative Approach**: Is this the right tool for the job? Could another tool or a different sequence of steps achieve the goal?
287329
4. **Simplify the Task**: Can you break the problem down into smaller, simpler steps?
330+
5. **Wrong Function Name**: Does the error indicates the tool is not found? Please check again and only use available tools.
288331
289332
Formulate a new plan based on your analysis and try a corrected or different approach.
290333
"""

0 commit comments

Comments
 (0)