Skip to content

Commit b859ca1

Browse files
committed
allow modeloptions.tools to be mapping or iterable; add func for extracting tools from list of actions
1 parent fdb0a12 commit b859ca1

File tree

3 files changed

+142
-37
lines changed

3 files changed

+142
-37
lines changed

mellea/backends/tools.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,58 @@
77
from ollama._utils import convert_function_to_tool
88

99
from mellea.backends.types import ModelOption
10-
from mellea.stdlib.base import Component, TemplateRepresentation
10+
from mellea.stdlib.base import CBlock, Component, TemplateRepresentation
1111

1212

1313
def add_tools_from_model_options(
1414
tools_dict: dict[str, Callable], model_options: dict[str, Any]
1515
):
16-
"""If model_options has tools, it will add those tools to the tools_dict."""
16+
"""If model_options has tools, add those tools to the tools_dict."""
1717
model_opts_tools = model_options.get(ModelOption.TOOLS, None)
18-
1918
if model_opts_tools is None:
2019
return
2120

21+
# Mappings are iterable.
2222
assert isinstance(model_opts_tools, Iterable), (
23-
"ModelOption.TOOLS must be a list of Callables"
23+
"ModelOption.TOOLS must be a list of Callables or dict[str, Callable]"
2424
)
25-
for func in model_opts_tools:
26-
assert callable(func), (
27-
f"ModelOption.TOOLS must be a list of Callables, found {type(func)}"
28-
)
29-
tools_dict[func.__name__] = func
25+
26+
if isinstance(model_opts_tools, Mapping):
27+
# Handle the dict case.
28+
for func_name, func in model_opts_tools.items():
29+
assert isinstance(func_name, str), (
30+
f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Callable]; found {type(func_name)} as the key instead"
31+
)
32+
assert callable(func), (
33+
f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Callable]; found {type(func)} as the value instead"
34+
)
35+
tools_dict[func_name] = func
36+
else:
37+
# Handle any other iterable / list here.
38+
for func in model_opts_tools:
39+
assert callable(func), (
40+
f"If ModelOption.TOOLS is a list, it must be a list of Callables; found {type(func)}"
41+
)
42+
tools_dict[func.__name__] = func
43+
44+
45+
def add_tools_from_context_actions(
46+
tools_dict: dict[str, Callable], ctx_actions: list[Component | CBlock] | None
47+
):
48+
"""If any of the actions in ctx_actions have tools in their template_representation, add those to the tools_dict."""
49+
if ctx_actions is None:
50+
return
51+
52+
for action in ctx_actions:
53+
if not isinstance(action, Component):
54+
continue # Only components have template representations.
55+
56+
tr = action.format_for_llm()
57+
if isinstance(tr, str) or tr.tools is None:
58+
continue
59+
60+
for tool_name, func in tr.tools.items():
61+
tools_dict[tool_name] = func
3062

3163

3264
def get_tools_from_action(action: Any) -> dict[str, Callable]:

mellea/backends/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ModelOption:
1717
"""
1818

1919
TOOLS = "@@@tools@@@"
20-
"""Must be a list of callables."""
20+
"""Must be a list of callables or a dict[str, Callable]."""
2121

2222
MAX_NEW_TOKENS = "@@@max_new_tokens@@@"
2323
SYSTEM_PROMPT = "@@@system_prompt@@@"

test/test_tool_calls.py

Lines changed: 100 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from mellea.backends import Backend
44
from mellea.backends.ollama import OllamaModelBackend
5-
from mellea.backends.tools import add_tools_from_model_options
5+
from mellea.backends.tools import add_tools_from_context_actions, add_tools_from_model_options
66
from mellea.backends.types import ModelOption
7-
from mellea.stdlib.base import ModelOutputThunk
7+
from mellea.stdlib.base import CBlock, Component, ModelOutputThunk, TemplateRepresentation
88
from mellea.stdlib.docs.richdocument import Table
99
from mellea.stdlib.session import LinearContext, MelleaSession
1010

@@ -29,13 +29,47 @@ def table() -> Table:
2929
assert t is not None, "test setup failed: could not create table from markdown"
3030
return t
3131

32-
def test_add_tools_from_model_options(table: Table):
32+
class FakeToolComponent(Component):
33+
def __init__(self) -> None:
34+
super().__init__()
35+
36+
def tool1(self):
37+
return
38+
39+
def parts(self):
40+
return []
41+
42+
def format_for_llm(self) -> TemplateRepresentation:
43+
return TemplateRepresentation(
44+
obj=self,
45+
args={"arg": None},
46+
tools={
47+
self.tool1.__name__: self.tool1
48+
}
49+
)
50+
51+
class FakeToolComponentWithExtraTool(FakeToolComponent):
52+
def __init__(self) -> None:
53+
super().__init__()
54+
55+
def tool2(self):
56+
return
57+
58+
def format_for_llm(self) -> TemplateRepresentation:
59+
tr = super().format_for_llm()
60+
assert tr.tools is not None
61+
tr.tools[self.tool2.__name__] = self.tool2
62+
return tr
63+
64+
65+
def test_add_tools_from_model_options_list():
3366
def get_weather(location: str) -> int:
3467
"""Returns the weather in Celsius."""
3568
return 21
3669

70+
ftc = FakeToolComponent()
3771
model_options = {
38-
ModelOption.TOOLS: [get_weather, table.content_as_string]
72+
ModelOption.TOOLS: [get_weather, ftc.tool1]
3973
}
4074

4175
tools = {}
@@ -44,36 +78,75 @@ def get_weather(location: str) -> int:
4478
assert tools["get_weather"] == get_weather
4579

4680
# Must use `==` for bound methods.
47-
assert tools["content_as_string"] == table.content_as_string, f"{tools["content_as_string"]} is not {table.content_as_string}"
81+
assert tools["tool1"] == ftc.tool1, f"{tools["tool1"]} should == {ftc.tool1}"
4882

49-
def test_tool_called(m: MelleaSession, table: Table):
50-
"""We don't force tools to be called. As a result, this test might unexpectedly fail."""
51-
r = 10
5283

53-
returned_tool = False
54-
for i in range(r):
55-
transformed = m.transform(table, "add a new row to this table")
56-
if isinstance(transformed, Table):
57-
returned_tool = True
58-
break
84+
def test_add_tools_from_model_options_map():
85+
def get_weather(location: str) -> int:
86+
"""Returns the weather in Celsius."""
87+
return 21
88+
89+
ftc = FakeToolComponent()
90+
model_options = {
91+
ModelOption.TOOLS: {
92+
get_weather.__name__: get_weather,
93+
ftc.tool1.__name__: ftc.tool1
94+
}
95+
}
96+
97+
tools = {}
98+
add_tools_from_model_options(tools, model_options)
99+
100+
assert tools["get_weather"] == get_weather
101+
102+
# Must use `==` for bound methods.
103+
assert tools["tool1"] == ftc.tool1, f"{tools["tool1"]} should == {ftc.tool1}"
104+
105+
106+
def test_add_tools_from_context_actions():
107+
108+
ftc1 = FakeToolComponentWithExtraTool()
109+
ftc2 = FakeToolComponent()
110+
111+
ctx_actions = [CBlock("Hello"), ftc1, ftc2]
112+
tools = {}
113+
add_tools_from_context_actions(tools, ctx_actions)
114+
115+
# Check that tools with the same name get properly overwritten in order of ctx.
116+
assert tools["tool1"] == ftc2.tool1, f"{tools["tool1"]} should == {ftc2.tool1}"
117+
118+
# Check that tools that aren't overwritten are still there.
119+
assert tools["tool2"] == ftc1.tool2, f"{tools["tool2"]} should == {ftc1.tool2}"
120+
121+
122+
# def test_tool_called(m: MelleaSession, table: Table):
123+
# """We don't force tools to be called. As a result, this test might unexpectedly fail."""
124+
# r = 10
125+
126+
# returned_tool = False
127+
# for i in range(r):
128+
# transformed = m.transform(table, "add a new row to this table")
129+
# if isinstance(transformed, Table):
130+
# returned_tool = True
131+
# break
59132

60-
assert returned_tool, f"did not return a tool after {r} attempts"
133+
# assert returned_tool, f"did not return a tool after {r} attempts"
61134

62135

63-
def test_tool_not_called(m: MelleaSession, table: Table):
64-
"""Ensure tools aren't always called when provided."""
65-
r = 10
136+
# def test_tool_not_called(m: MelleaSession, table: Table):
137+
# """Ensure tools aren't always called when provided."""
138+
# r = 10
66139

67-
returned_no_tool = False
68-
for i in range(r):
69-
transformed = m.transform(table, "output a text description of this table")
70-
if isinstance(transformed, ModelOutputThunk):
71-
returned_no_tool = True
72-
break
140+
# returned_no_tool = False
141+
# for i in range(r):
142+
# transformed = m.transform(table, "output a text description of this table")
143+
# if isinstance(transformed, ModelOutputThunk):
144+
# returned_no_tool = True
145+
# break
73146

74-
assert (
75-
returned_no_tool
76-
), f"only returned tools after {r} attempts, should've returned a response with no tools"
147+
# assert (
148+
# returned_no_tool
149+
# ), f"only returned tools after {r} attempts, should've returned a response with no tools"
77150

78151
if __name__ == "__main__":
79152
pytest.main([__file__])

0 commit comments

Comments
 (0)