Skip to content

Commit 4d912a1

Browse files
committed
move to separate test file
1 parent 45507ea commit 4d912a1

File tree

2 files changed

+97
-89
lines changed

2 files changed

+97
-89
lines changed

test/backends/test_tool_helpers.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
2+
import pytest
3+
from mellea.backends.tools import add_tools_from_context_actions, add_tools_from_model_options
4+
from mellea.backends.types import ModelOption
5+
from mellea.stdlib.base import CBlock, Component, TemplateRepresentation
6+
7+
class FakeToolComponent(Component):
8+
def __init__(self) -> None:
9+
super().__init__()
10+
11+
def tool1(self):
12+
return
13+
14+
def parts(self):
15+
return []
16+
17+
def format_for_llm(self) -> TemplateRepresentation:
18+
return TemplateRepresentation(
19+
obj=self,
20+
args={"arg": None},
21+
tools={
22+
self.tool1.__name__: self.tool1
23+
}
24+
)
25+
26+
class FakeToolComponentWithExtraTool(FakeToolComponent):
27+
def __init__(self) -> None:
28+
super().__init__()
29+
30+
def tool2(self):
31+
return
32+
33+
def format_for_llm(self) -> TemplateRepresentation:
34+
tr = super().format_for_llm()
35+
assert tr.tools is not None
36+
tr.tools[self.tool2.__name__] = self.tool2
37+
return tr
38+
39+
40+
def test_add_tools_from_model_options_list():
41+
def get_weather(location: str) -> int:
42+
"""Returns the weather in Celsius."""
43+
return 21
44+
45+
ftc = FakeToolComponent()
46+
model_options = {
47+
ModelOption.TOOLS: [get_weather, ftc.tool1]
48+
}
49+
50+
tools = {}
51+
add_tools_from_model_options(tools, model_options)
52+
53+
assert tools["get_weather"] == get_weather
54+
55+
# Must use `==` for bound methods.
56+
assert tools["tool1"] == ftc.tool1, f"{tools["tool1"]} should == {ftc.tool1}"
57+
58+
59+
def test_add_tools_from_model_options_map():
60+
def get_weather(location: str) -> int:
61+
"""Returns the weather in Celsius."""
62+
return 21
63+
64+
ftc = FakeToolComponent()
65+
model_options = {
66+
ModelOption.TOOLS: {
67+
get_weather.__name__: get_weather,
68+
ftc.tool1.__name__: ftc.tool1
69+
}
70+
}
71+
72+
tools = {}
73+
add_tools_from_model_options(tools, model_options)
74+
75+
assert tools["get_weather"] == get_weather
76+
77+
# Must use `==` for bound methods.
78+
assert tools["tool1"] == ftc.tool1, f"{tools["tool1"]} should == {ftc.tool1}"
79+
80+
81+
def test_add_tools_from_context_actions():
82+
83+
ftc1 = FakeToolComponentWithExtraTool()
84+
ftc2 = FakeToolComponent()
85+
86+
ctx_actions = [CBlock("Hello"), ftc1, ftc2]
87+
tools = {}
88+
add_tools_from_context_actions(tools, ctx_actions)
89+
90+
# Check that tools with the same name get properly overwritten in order of ctx.
91+
assert tools["tool1"] == ftc2.tool1, f"{tools["tool1"]} should == {ftc2.tool1}"
92+
93+
# Check that tools that aren't overwritten are still there.
94+
assert tools["tool2"] == ftc1.tool2, f"{tools["tool2"]} should == {ftc1.tool2}"
95+
96+
if __name__ == "__main__":
97+
pytest.main([__file__])

test/test_tool_calls.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -29,95 +29,6 @@ def table() -> Table:
2929
assert t is not None, "test setup failed: could not create table from markdown"
3030
return t
3131

32-
class FakeToolComponent(Component):
33-
def __init__(self) -> None:
34-
super().__init__()
35-
36-
def tool1(self):
37-
return
38-
39-
def parts(self):
40-
return []
41-
42-
def format_for_llm(self) -> TemplateRepresentation:
43-
return TemplateRepresentation(
44-
obj=self,
45-
args={"arg": None},
46-
tools={
47-
self.tool1.__name__: self.tool1
48-
}
49-
)
50-
51-
class FakeToolComponentWithExtraTool(FakeToolComponent):
52-
def __init__(self) -> None:
53-
super().__init__()
54-
55-
def tool2(self):
56-
return
57-
58-
def format_for_llm(self) -> TemplateRepresentation:
59-
tr = super().format_for_llm()
60-
assert tr.tools is not None
61-
tr.tools[self.tool2.__name__] = self.tool2
62-
return tr
63-
64-
65-
def test_add_tools_from_model_options_list():
66-
def get_weather(location: str) -> int:
67-
"""Returns the weather in Celsius."""
68-
return 21
69-
70-
ftc = FakeToolComponent()
71-
model_options = {
72-
ModelOption.TOOLS: [get_weather, ftc.tool1]
73-
}
74-
75-
tools = {}
76-
add_tools_from_model_options(tools, model_options)
77-
78-
assert tools["get_weather"] == get_weather
79-
80-
# Must use `==` for bound methods.
81-
assert tools["tool1"] == ftc.tool1, f"{tools["tool1"]} should == {ftc.tool1}"
82-
83-
84-
def test_add_tools_from_model_options_map():
85-
def get_weather(location: str) -> int:
86-
"""Returns the weather in Celsius."""
87-
return 21
88-
89-
ftc = FakeToolComponent()
90-
model_options = {
91-
ModelOption.TOOLS: {
92-
get_weather.__name__: get_weather,
93-
ftc.tool1.__name__: ftc.tool1
94-
}
95-
}
96-
97-
tools = {}
98-
add_tools_from_model_options(tools, model_options)
99-
100-
assert tools["get_weather"] == get_weather
101-
102-
# Must use `==` for bound methods.
103-
assert tools["tool1"] == ftc.tool1, f"{tools["tool1"]} should == {ftc.tool1}"
104-
105-
106-
def test_add_tools_from_context_actions():
107-
108-
ftc1 = FakeToolComponentWithExtraTool()
109-
ftc2 = FakeToolComponent()
110-
111-
ctx_actions = [CBlock("Hello"), ftc1, ftc2]
112-
tools = {}
113-
add_tools_from_context_actions(tools, ctx_actions)
114-
115-
# Check that tools with the same name get properly overwritten in order of ctx.
116-
assert tools["tool1"] == ftc2.tool1, f"{tools["tool1"]} should == {ftc2.tool1}"
117-
118-
# Check that tools that aren't overwritten are still there.
119-
assert tools["tool2"] == ftc1.tool2, f"{tools["tool2"]} should == {ftc1.tool2}"
120-
12132

12233
def test_tool_called_from_context_action(m: MelleaSession, table: Table):
12334
"""Make sure tools can be called from actions in the context."""

0 commit comments

Comments
 (0)