Skip to content

Commit 1ad7973

Browse files
hwchase17jxnlJason Liu
authored
Harrison/tool decorator (#790)
Co-authored-by: Jason Liu <[email protected]> Co-authored-by: Jason Liu <[email protected]>
1 parent 5f73d06 commit 1ad7973

File tree

4 files changed

+226
-5
lines changed

4 files changed

+226
-5
lines changed

docs/modules/agents/examples/custom_tools.ipynb

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. A Tool is defined as below.\n",
1111
"\n",
1212
"```python\n",
13-
"class Tool(NamedTuple):\n",
13+
"@dataclass \n",
14+
"class Tool:\n",
1415
" \"\"\"Interface for tools.\"\"\"\n",
1516
"\n",
1617
" name: str\n",
1718
" func: Callable[[str], str]\n",
1819
" description: Optional[str] = None\n",
20+
" return_direct: bool = True\n",
1921
"```\n",
2022
"\n",
21-
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all."
23+
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all. You can create these tools directly, but we also provide a decorator to easily convert any function into a tool."
2224
]
2325
},
2426
{
@@ -151,6 +153,94 @@
151153
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
152154
]
153155
},
156+
{
157+
"cell_type": "markdown",
158+
"id": "824eaf74",
159+
"metadata": {},
160+
"source": [
161+
"## Using the `tool` decorator\n",
162+
"\n",
163+
"To make it easier to define custom tools, a `@tool` decorator is provided. This decorator can be used to quickly create a `Tool` from a simple function. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description."
164+
]
165+
},
166+
{
167+
"cell_type": "code",
168+
"execution_count": 1,
169+
"id": "8f15307d",
170+
"metadata": {},
171+
"outputs": [],
172+
"source": [
173+
"from langchain.agents import tool\n",
174+
"\n",
175+
"@tool\n",
176+
"def search_api(query: str) -> str:\n",
177+
" \"\"\"Searches the API for the query.\"\"\"\n",
178+
" return \"Results\""
179+
]
180+
},
181+
{
182+
"cell_type": "code",
183+
"execution_count": 2,
184+
"id": "0a23b91b",
185+
"metadata": {},
186+
"outputs": [
187+
{
188+
"data": {
189+
"text/plain": [
190+
"Tool(name='search_api', func=<function search_api at 0x10dad7d90>, description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False)"
191+
]
192+
},
193+
"execution_count": 2,
194+
"metadata": {},
195+
"output_type": "execute_result"
196+
}
197+
],
198+
"source": [
199+
"search_api"
200+
]
201+
},
202+
{
203+
"cell_type": "markdown",
204+
"id": "cc6ee8c1",
205+
"metadata": {},
206+
"source": [
207+
"You can also provide arguments like the tool name and whether to return directly."
208+
]
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": 3,
213+
"id": "28cdf04d",
214+
"metadata": {},
215+
"outputs": [],
216+
"source": [
217+
"@tool(\"search\", return_direct=True)\n",
218+
"def search_api(query: str) -> str:\n",
219+
" \"\"\"Searches the API for the query.\"\"\"\n",
220+
" return \"Results\""
221+
]
222+
},
223+
{
224+
"cell_type": "code",
225+
"execution_count": 4,
226+
"id": "1085a4bd",
227+
"metadata": {},
228+
"outputs": [
229+
{
230+
"data": {
231+
"text/plain": [
232+
"Tool(name='search', func=<function search_api at 0x112301bd0>, description='search(query: str) -> str - Searches the API for the query.', return_direct=True)"
233+
]
234+
},
235+
"execution_count": 4,
236+
"metadata": {},
237+
"output_type": "execute_result"
238+
}
239+
],
240+
"source": [
241+
"search_api"
242+
]
243+
},
154244
{
155245
"cell_type": "markdown",
156246
"id": "1d0430d6",
@@ -432,7 +522,7 @@
432522
},
433523
"vscode": {
434524
"interpreter": {
435-
"hash": "cb23c3a7a387ab03496baa08507270f8e0861b23170e79d5edc545893cdca840"
525+
"hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34"
436526
}
437527
}
438528
},

langchain/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
88
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
99
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
10-
from langchain.agents.tools import Tool
10+
from langchain.agents.tools import Tool, tool
1111

1212
__all__ = [
1313
"MRKLChain",
@@ -16,6 +16,7 @@
1616
"AgentExecutor",
1717
"Agent",
1818
"Tool",
19+
"tool",
1920
"initialize_agent",
2021
"ZeroShotAgent",
2122
"ReActTextWorldAgent",

langchain/agents/tools.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Interface for tools."""
22
from dataclasses import dataclass
3-
from typing import Callable, Optional
3+
from inspect import signature
4+
from typing import Any, Callable, Optional, Union
45

56

67
@dataclass
@@ -11,3 +12,65 @@ class Tool:
1112
func: Callable[[str], str]
1213
description: Optional[str] = None
1314
return_direct: bool = False
15+
16+
def __call__(self, *args: Any, **kwargs: Any) -> str:
17+
"""Make tools callable by piping through to `func`."""
18+
return self.func(*args, **kwargs)
19+
20+
21+
def tool(
22+
*args: Union[str, Callable], return_direct: bool = False
23+
) -> Union[Callable, Tool]:
24+
"""Make tools out of functions, can be used with or without arguments.
25+
26+
Requires:
27+
- Function must be of type (str) -> str
28+
- Function must have a docstring
29+
30+
Examples:
31+
.. code-block:: python
32+
33+
@tool
34+
def search_api(query: str) -> str:
35+
# Searches the API for the query.
36+
return
37+
38+
@tool("search", return_direct=True)
39+
def search_api(query: str) -> str:
40+
# Searches the API for the query.
41+
return
42+
"""
43+
44+
def _make_with_name(tool_name: str) -> Callable:
45+
def _make_tool(func: Callable[[str], str]) -> Tool:
46+
assert func.__doc__, "Function must have a docstring"
47+
# Description example:
48+
# search_api(query: str) - Searches the API for the query.
49+
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
50+
tool = Tool(
51+
name=tool_name,
52+
func=func,
53+
description=description,
54+
return_direct=return_direct,
55+
)
56+
return tool
57+
58+
return _make_tool
59+
60+
if len(args) == 1 and isinstance(args[0], str):
61+
# if the argument is a string, then we use the string as the tool name
62+
# Example usage: @tool("search", return_direct=True)
63+
return _make_with_name(args[0])
64+
elif len(args) == 1 and callable(args[0]):
65+
# if the argument is a function, then we use the function name as the tool name
66+
# Example usage: @tool
67+
return _make_with_name(args[0].__name__)(args[0])
68+
elif len(args) == 0:
69+
# if there are no arguments, then we use the function name as the tool name
70+
# Example usage: @tool(return_direct=True)
71+
def _partial(func: Callable[[str], str]) -> Tool:
72+
return _make_with_name(func.__name__)(func)
73+
74+
return _partial
75+
else:
76+
raise ValueError("Too many arguments for tool decorator")

tests/unit_tests/agents/test_tools.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Test tool utils."""
2+
import pytest
3+
4+
from langchain.agents.tools import Tool, tool
5+
6+
7+
def test_unnamed_decorator() -> None:
8+
"""Test functionality with unnamed decorator."""
9+
10+
@tool
11+
def search_api(query: str) -> str:
12+
"""Search the API for the query."""
13+
return "API result"
14+
15+
assert isinstance(search_api, Tool)
16+
assert search_api.name == "search_api"
17+
assert not search_api.return_direct
18+
assert search_api("test") == "API result"
19+
20+
21+
def test_named_tool_decorator() -> None:
22+
"""Test functionality when arguments are provided as input to decorator."""
23+
24+
@tool("search")
25+
def search_api(query: str) -> str:
26+
"""Search the API for the query."""
27+
return "API result"
28+
29+
assert isinstance(search_api, Tool)
30+
assert search_api.name == "search"
31+
assert not search_api.return_direct
32+
33+
34+
def test_named_tool_decorator_return_direct() -> None:
35+
"""Test functionality when arguments and return direct are provided as input."""
36+
37+
@tool("search", return_direct=True)
38+
def search_api(query: str) -> str:
39+
"""Search the API for the query."""
40+
return "API result"
41+
42+
assert isinstance(search_api, Tool)
43+
assert search_api.name == "search"
44+
assert search_api.return_direct
45+
46+
47+
def test_unnamed_tool_decorator_return_direct() -> None:
48+
"""Test functionality when only return direct is provided."""
49+
50+
@tool(return_direct=True)
51+
def search_api(query: str) -> str:
52+
"""Search the API for the query."""
53+
return "API result"
54+
55+
assert isinstance(search_api, Tool)
56+
assert search_api.name == "search_api"
57+
assert search_api.return_direct
58+
59+
60+
def test_missing_docstring() -> None:
61+
"""Test error is raised when docstring is missing."""
62+
# expect to throw a value error if theres no docstring
63+
with pytest.raises(AssertionError):
64+
65+
@tool
66+
def search_api(query: str) -> str:
67+
return "API result"

0 commit comments

Comments
 (0)