Skip to content

Commit 6c55f63

Browse files
Added toolkit models (#1382)
* Added toolkit models. * Fixes 3.9 compatibility * Updated example in docs * Removed ToolSet, unique name requirement * Updated method name * Fixed hash, added None as default, updated logic. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f66e4d0 commit 6c55f63

File tree

3 files changed

+460
-0
lines changed

3 files changed

+460
-0
lines changed
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""Tests for toolkit"""
2+
3+
from jupyter_ai.tools import Tool, Toolkit
4+
5+
6+
def sample_function():
7+
"""This is a sample function for testing."""
8+
return "Hello, World!"
9+
10+
11+
def another_function(name: str):
12+
"""Greet someone by name.
13+
14+
This function takes a name and returns a greeting.
15+
"""
16+
return f"Hello, {name}!"
17+
18+
19+
def no_doc_function():
20+
return "No documentation"
21+
22+
23+
class TestTool:
24+
"""Test the Tool class."""
25+
26+
def test_creation_with_name_and_description(self):
27+
"""Test creating a tool with explicit name and description."""
28+
tool = Tool(
29+
callable=sample_function,
30+
name="custom_name",
31+
description="Custom description",
32+
read=True,
33+
)
34+
35+
assert tool.name == "custom_name"
36+
assert tool.description == "Custom description"
37+
assert tool.read is True
38+
assert tool.write is False
39+
assert tool.execute is False
40+
assert tool.delete is False
41+
42+
def test_creation_auto_name_description(self):
43+
"""Test creating a tool with auto-generated name and description."""
44+
tool = Tool(callable=another_function, write=True)
45+
46+
assert tool.name == "another_function"
47+
assert tool.description == "Greet someone by name."
48+
assert tool.read is False
49+
assert tool.write is True
50+
51+
def test_creation_no_docstring(self):
52+
"""Test creating a tool with function that has no docstring."""
53+
tool = Tool(callable=no_doc_function)
54+
55+
assert tool.name == "no_doc_function"
56+
assert tool.description == ""
57+
58+
def test_equality(self):
59+
"""Test tool equality based on name."""
60+
tool1 = Tool(callable=sample_function, name="test")
61+
tool2 = Tool(callable=sample_function, name="test")
62+
tool3 = Tool(callable=sample_function, name="different")
63+
64+
assert tool1 == tool2 # Same name
65+
assert tool1 != tool3 # Different name
66+
assert tool1 != "not_a_tool" # Different type
67+
68+
def test_hash(self):
69+
"""Test tool hashing based on name."""
70+
tool1 = Tool(callable=sample_function, name="test")
71+
tool2 = Tool(callable=sample_function, name="test")
72+
73+
assert hash(tool1) == hash(tool2) # Same name, same hash
74+
75+
76+
class TestToolkit:
77+
"""Test the Toolkit class."""
78+
79+
def test_creation(self):
80+
"""Test creating a toolkit."""
81+
toolkit = Toolkit(name="TestToolkit", description="A test toolkit")
82+
83+
assert toolkit.name == "TestToolkit"
84+
assert toolkit.description == "A test toolkit"
85+
assert len(toolkit.tools) == 0
86+
87+
def test_add_tool(self):
88+
"""Test adding a tool to toolkit."""
89+
toolkit = Toolkit(name="TestToolkit")
90+
tool = Tool(callable=sample_function, read=True)
91+
92+
toolkit.add_tool(tool)
93+
assert len(toolkit.tools) == 1
94+
assert tool in toolkit.tools
95+
96+
def test_find_tools(self):
97+
"""Test finding tools with various capability filters."""
98+
# Create a toolkit
99+
toolkit = Toolkit(name="test_toolkit")
100+
101+
# Create tools with different permission combinations
102+
read_only_tool = Tool(callable=lambda: None, name="read_only", read=True)
103+
write_tool = Tool(callable=lambda: None, name="write_tool", write=True)
104+
execute_tool = Tool(callable=lambda: None, name="execute_tool", execute=True)
105+
delete_tool = Tool(callable=lambda: None, name="delete_tool", delete=True)
106+
read_execute_tool = Tool(
107+
callable=lambda: None, name="read_execute", read=True, execute=True
108+
)
109+
write_execute_tool = Tool(
110+
callable=lambda: None, name="write_execute", write=True, execute=True
111+
)
112+
all_perms_tool = Tool(
113+
callable=lambda: None,
114+
name="all_perms",
115+
read=True,
116+
write=True,
117+
execute=True,
118+
delete=True,
119+
)
120+
121+
# Add tools to the toolkit
122+
toolkit.add_tool(read_only_tool)
123+
toolkit.add_tool(write_tool)
124+
toolkit.add_tool(execute_tool)
125+
toolkit.add_tool(delete_tool)
126+
toolkit.add_tool(read_execute_tool)
127+
toolkit.add_tool(write_execute_tool)
128+
toolkit.add_tool(all_perms_tool)
129+
130+
# Test 1: Default parameters (all None) - should return all tools
131+
default_tools = toolkit.get_tools()
132+
assert (
133+
len(default_tools) == 7
134+
), "All tools should be returned with default parameters"
135+
assert read_only_tool in default_tools
136+
assert write_tool in default_tools
137+
assert execute_tool in default_tools
138+
assert delete_tool in default_tools
139+
assert read_execute_tool in default_tools
140+
assert write_execute_tool in default_tools
141+
assert all_perms_tool in default_tools
142+
143+
# Test 2: Find tools with read permission
144+
read_tools = toolkit.get_tools(read=True)
145+
assert len(read_tools) == 3
146+
assert read_only_tool in read_tools
147+
assert read_execute_tool in read_tools
148+
assert all_perms_tool in read_tools
149+
assert write_tool not in read_tools
150+
assert execute_tool not in read_tools
151+
assert delete_tool not in read_tools
152+
assert write_execute_tool not in read_tools
153+
154+
# Test 3: Find tools with write permission
155+
write_tools = toolkit.get_tools(write=True)
156+
assert len(write_tools) == 3
157+
assert write_tool in write_tools
158+
assert write_execute_tool in write_tools
159+
assert all_perms_tool in write_tools
160+
assert read_only_tool not in write_tools
161+
assert read_execute_tool not in write_tools
162+
assert execute_tool not in write_tools
163+
assert delete_tool not in write_tools
164+
165+
# Test 4: Find tools with execute permission
166+
execute_tools = toolkit.get_tools(execute=True)
167+
assert len(execute_tools) == 4
168+
assert execute_tool in execute_tools
169+
assert read_execute_tool in execute_tools
170+
assert write_execute_tool in execute_tools
171+
assert all_perms_tool in execute_tools
172+
assert read_only_tool not in execute_tools
173+
assert write_tool not in execute_tools
174+
assert delete_tool not in execute_tools
175+
176+
# Test 5: Find tools with delete permission
177+
delete_tools = toolkit.get_tools(delete=True)
178+
assert len(delete_tools) == 2
179+
assert delete_tool in delete_tools
180+
assert all_perms_tool in delete_tools
181+
assert read_only_tool not in delete_tools
182+
assert write_tool not in delete_tools
183+
assert execute_tool not in delete_tools
184+
assert read_execute_tool not in delete_tools
185+
assert write_execute_tool not in delete_tools
186+
187+
# Test 6: Combined permissions (read and execute)
188+
read_execute_tools = toolkit.get_tools(read=True, execute=True)
189+
assert len(read_execute_tools) == 2
190+
assert read_execute_tool in read_execute_tools
191+
assert all_perms_tool in read_execute_tools
192+
assert read_only_tool not in read_execute_tools
193+
assert write_tool not in read_execute_tools
194+
assert execute_tool not in read_execute_tools
195+
assert write_execute_tool not in read_execute_tools
196+
assert delete_tool not in read_execute_tools
197+
198+
# Test 7: Combined permissions (read and write)
199+
read_write_tools = toolkit.get_tools(read=True, write=True)
200+
assert len(read_write_tools) == 1
201+
assert all_perms_tool in read_write_tools
202+
assert read_only_tool not in read_write_tools
203+
assert write_tool not in read_write_tools
204+
assert execute_tool not in read_write_tools
205+
assert read_execute_tool not in read_write_tools
206+
assert write_execute_tool not in read_write_tools
207+
assert delete_tool not in read_write_tools
208+
209+
# Test 8: Combined permissions (write and execute)
210+
write_execute_tools = toolkit.get_tools(write=True, execute=True)
211+
assert len(write_execute_tools) == 2
212+
assert write_execute_tool in write_execute_tools
213+
assert all_perms_tool in write_execute_tools
214+
assert read_only_tool not in write_execute_tools
215+
assert write_tool not in write_execute_tools
216+
assert execute_tool not in write_execute_tools
217+
assert read_execute_tool not in write_execute_tools
218+
assert delete_tool not in write_execute_tools
219+
220+
# Test 9: Combined permissions (read, write, and execute)
221+
read_write_execute_tools = toolkit.get_tools(
222+
read=True, write=True, execute=True
223+
)
224+
assert len(read_write_execute_tools) == 1
225+
assert all_perms_tool in read_write_execute_tools
226+
assert read_only_tool not in read_write_execute_tools
227+
assert write_tool not in read_write_execute_tools
228+
assert execute_tool not in read_write_execute_tools
229+
assert read_execute_tool not in read_write_execute_tools
230+
assert write_execute_tool not in read_write_execute_tools
231+
assert delete_tool not in read_write_execute_tools
232+
233+
# Test 10: All permissions
234+
all_perm_tools = toolkit.get_tools(
235+
read=True, write=True, execute=True, delete=True
236+
)
237+
assert len(all_perm_tools) == 1
238+
assert all_perms_tool in all_perm_tools
239+
assert read_only_tool not in all_perm_tools
240+
assert write_tool not in all_perm_tools
241+
assert execute_tool not in all_perm_tools
242+
assert read_execute_tool not in all_perm_tools
243+
assert write_execute_tool not in all_perm_tools
244+
assert delete_tool not in all_perm_tools
245+
246+
write_not_execute_tools = toolkit.get_tools(write=True, execute=False)
247+
assert len(write_not_execute_tools) == 1
248+
assert write_tool in write_not_execute_tools
249+
assert execute_tool not in write_not_execute_tools
250+
assert read_execute_tool not in write_not_execute_tools
251+
assert all_perms_tool not in write_not_execute_tools
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Tools package for Jupyter AI."""
2+
3+
from .models import Tool, Toolkit
4+
5+
__all__ = ["Tool", "Toolkit"]

0 commit comments

Comments
 (0)