Skip to content

Commit bec1214

Browse files
committed
Validate tool call arguments
1 parent ab3a99a commit bec1214

File tree

5 files changed

+79
-30
lines changed

5 files changed

+79
-30
lines changed

agents/agents/tools/src/react/tools.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ...tool_base import tool
22

33
@tool(name="answer", description="Give the final answer. The answer should be put inside the \\boxed{} tag.", status="finish")
4-
def answer_math(answer: str):
4+
def answer(answer: str):
55
"""
66
A helper tool to give the final answer. The answer should be put inside the \\boxed{} tag.
77
Args:
@@ -11,8 +11,18 @@ def answer_math(answer: str):
1111
"""
1212
return str(answer)
1313

14+
@tool(name="answer_math", description="Give the final answer. The answer should be put inside the \\boxed{} tag.", status="finish")
15+
def answer_math(answer: str):
16+
"""
17+
A helper tool to give the final answer. The answer should be put inside the \\boxed{} tag.
18+
Args:
19+
answer (str): The final answer to the question.
20+
Returns:
21+
str: The final answer to the question.
22+
"""
23+
return str(answer)
1424

15-
@tool(name="answer", description="Give the final answer. The answer should be a simple, short, and direct.", status="finish")
25+
@tool(name="answer_qa", description="Give the final answer. The answer should be a simple, short, and direct.", status="finish")
1626
def answer_qa(answer: str):
1727
"""
1828
A helper tool to give the final answer. The answer should be a simple, short, and direct.

agents/agents/tools/tool_base.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
name: str | None = None,
4343
description: str | None = None,
4444
schema: dict | None = None,
45+
args: dict | None = None,
4546
max_length: int = 2048,
4647
env_cls: type[BaseEnv] | None = None,
4748
env_kwargs: dict | None = None,
@@ -54,6 +55,7 @@ def __init__(
5455
self.name = name or func.__name__
5556
self.description = description or ""
5657
self.schema = schema
58+
self.args = args
5759
self.max_length = max_length
5860
self.status = status
5961

@@ -87,6 +89,14 @@ def used_env_size(self):
8789
return len(self._envs)
8890
return 0
8991

92+
def _validate_call_args(self, kwargs):
93+
# TODO: raise error, return error message, or filter the invalid arguments, make it configurable. Currently, we just return the error message.
94+
for arg in kwargs:
95+
if arg not in self.args:
96+
# raise ValueError(f"""Invalid argument "{arg}" for tool {self.name}.""")
97+
result = f"""Invalid argument "{arg}" for tool {self.name}."""
98+
return result
99+
return None
90100

91101
async def __call__(self, **kwargs):
92102
"""
@@ -101,30 +111,38 @@ async def __call__(self, **kwargs):
101111
- "status": The status of the tool call.
102112
- "info": The info of the tool call.
103113
"""
104-
try:
105-
if not self.is_stateful:
106-
# For non-stateful tools, directly execute the function
107-
result = await self.user_func(**kwargs) if inspect.iscoroutinefunction(self.user_func) \
108-
else self.user_func(**kwargs)
109-
else:
110-
# For stateful tools, handle environment management
111-
id = kwargs.pop('id', None)
112-
if id is None:
113-
result = "Error: 'id' parameter is required for stateful tools"
114+
# Check arguments before calling the tool
115+
result = self._validate_call_args(kwargs)
116+
117+
# If the arguments are valid, call the tool
118+
if result is None:
119+
try:
120+
if not self.is_stateful:
121+
# For non-stateful tools, directly execute the function
122+
result = await self.user_func(**kwargs) if inspect.iscoroutinefunction(self.user_func) \
123+
else self.user_func(**kwargs)
114124
else:
115-
await self._initialize_envs()
116-
env = await self._acquire_env(id)
117-
118-
async with self._locks[id]:
119-
# token = current_env.set(env)
120-
assert kwargs.get("env", None) is None, "env is not allowed to be passed to stateful tools"
121-
try:
122-
result = await self.user_func(env=env,**kwargs) if inspect.iscoroutinefunction(self.user_func) \
123-
else self.user_func(**kwargs)
124-
finally:
125-
pass
126-
except Exception as e:
127-
result = str(e)
125+
# For stateful tools, handle environment management
126+
id = kwargs.pop('id', None)
127+
if id is None:
128+
result = "Error: 'id' parameter is required for stateful tools"
129+
else:
130+
await self._initialize_envs()
131+
env = await self._acquire_env(id)
132+
133+
async with self._locks[id]:
134+
# token = current_env.set(env)
135+
assert kwargs.get("env", None) is None, "env is not allowed to be passed to stateful tools"
136+
try:
137+
result = await self.user_func(env=env,**kwargs) if inspect.iscoroutinefunction(self.user_func) \
138+
else self.user_func(**kwargs)
139+
finally:
140+
pass
141+
except Exception as e:
142+
result = str(e)
143+
# If the arguments are invalid, simply use the result from the validation
144+
else:
145+
pass
128146

129147
# Result must be a string or a dict
130148
if isinstance(result, str):
@@ -258,15 +276,16 @@ def decorator(func):
258276
signature = extract_signatures(func)
259277
docs = parse_docstring(inspect.getdoc(func))
260278
final_desc = description or docs.get("summary", "")
261-
final_schema = validate_schema(final_name, final_desc, signature, docs)
279+
validated_schema = validate_schema(final_name, final_desc, signature, docs)
262280

263281
# Create the tool
264282
def factory():
265283
return Tool(
266284
func=func,
267285
name=final_name,
268286
description=final_desc,
269-
schema=final_schema,
287+
schema=validated_schema["schema"],
288+
args=validated_schema["args"],
270289
max_length=max_length,
271290
env_cls=env_cls,
272291
env_kwargs=env_kwargs,

agents/agents/tools/utils/schema.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
import inspect
33
import warnings
4-
4+
from copy import deepcopy
55

66
def extract_signatures(func):
77
sig = inspect.signature(func)
@@ -185,7 +185,14 @@ def validate_schema(name, description, signature, docs):
185185
}
186186
}
187187
}
188-
return schema
188+
arguments = deepcopy(signature)
189+
if "env" in arguments:
190+
del arguments["env"]
191+
192+
return {
193+
"schema": schema,
194+
"args": arguments
195+
}
189196

190197
if __name__ == '__main__':
191198
# Retrieve and parse the docstring using inspect.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
from agents.tools import tool
3+
4+
5+
@pytest.mark.asyncio
6+
async def test_args_validation():
7+
@tool(name="add", description="Adds two numbers.")
8+
async def add(a, b):
9+
return a + b
10+
11+
result = await add(a=1, b=2, c=3)
12+
assert result == {"name": "add", "arguments": {"a": 1, "b": 2, "c": 3}, "observation": "Invalid argument \"c\" for tool add.", "status": "success", "info": {}}
13+

0 commit comments

Comments
 (0)