Skip to content

Commit fdb0a12

Browse files
committed
switch ModelOptions.TOOLS to be a list; standardize tool handling; add tests
1 parent f92a286 commit fdb0a12

File tree

7 files changed

+60
-73
lines changed

7 files changed

+60
-73
lines changed

mellea/backends/huggingface.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
3333
from mellea.backends.model_ids import ModelIdentifier
3434
from mellea.backends.tools import (
35+
add_tools_from_model_options,
3536
convert_tools_to_json,
3637
get_tools_from_action,
3738
parse_tools,
@@ -329,24 +330,8 @@ def _generate_from_context_standard(
329330
action.format_for_llm(), TemplateRepresentation
330331
):
331332
tools = get_tools_from_action(action)
332-
333-
model_options_tools = model_options.get(ModelOption.TOOLS, None)
334-
if model_options_tools is not None:
335-
assert isinstance(model_options_tools, dict)
336-
for fn_name in model_options_tools:
337-
# invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
338-
assert fn_name not in tools.keys(), (
339-
f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
340-
)
341-
# type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
342-
assert type(fn_name) is str, (
343-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
344-
)
345-
assert callable(model_options_tools[fn_name]), (
346-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
347-
)
348-
# Add the model_options tool to the existing set of tools.
349-
tools[fn_name] = model_options_tools[fn_name]
333+
add_tools_from_model_options(tools, model_options)
334+
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
350335

351336
seed = model_options.get(ModelOption.SEED, None)
352337
if seed is not None:

mellea/backends/ollama.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from mellea.backends import BaseModelSubclass
1414
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
1515
from mellea.backends.model_ids import ModelIdentifier
16-
from mellea.backends.tools import get_tools_from_action
16+
from mellea.backends.tools import add_tools_from_model_options, get_tools_from_action
1717
from mellea.backends.types import ModelOption
1818
from mellea.helpers.fancy_logger import FancyLogger
1919
from mellea.stdlib.base import (
@@ -300,23 +300,7 @@ def generate_from_chat_context(
300300
):
301301
tools = get_tools_from_action(action)
302302

303-
model_options_tools = model_opts.get(ModelOption.TOOLS, None)
304-
if model_options_tools is not None:
305-
assert isinstance(model_options_tools, dict)
306-
for fn_name in model_options_tools:
307-
# invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
308-
assert fn_name not in tools.keys(), (
309-
f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
310-
)
311-
# type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
312-
assert type(fn_name) is str, (
313-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
314-
)
315-
assert callable(model_options_tools[fn_name]), (
316-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
317-
)
318-
# Add the model_options tool to the existing set of tools.
319-
tools[fn_name] = model_options_tools[fn_name]
303+
add_tools_from_model_options(tools, model_opts)
320304
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
321305

322306
# Generate a chat response from ollama, using the chat messages.

mellea/backends/openai.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from mellea.backends.aloras import Alora, AloraBackendMixin
2121
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
2222
from mellea.backends.model_ids import ModelIdentifier
23-
from mellea.backends.tools import convert_tools_to_json, get_tools_from_action
23+
from mellea.backends.tools import (
24+
add_tools_from_model_options,
25+
convert_tools_to_json,
26+
get_tools_from_action,
27+
)
2428
from mellea.backends.types import ModelOption
2529
from mellea.helpers.fancy_logger import FancyLogger
2630
from mellea.stdlib.base import (
@@ -409,23 +413,8 @@ def _generate_from_chat_context_standard(
409413
):
410414
tools = get_tools_from_action(action)
411415

412-
model_options_tools = model_opts.get(ModelOption.TOOLS, None)
413-
if model_options_tools is not None:
414-
assert isinstance(model_options_tools, dict)
415-
for fn_name in model_options_tools:
416-
# invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
417-
assert fn_name not in tools.keys(), (
418-
f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
419-
)
420-
# type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
421-
assert type(fn_name) is str, (
422-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
423-
)
424-
assert callable(model_options_tools[fn_name]), (
425-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
426-
)
427-
# Add the model_options tool to the existing set of tools.
428-
tools[fn_name] = model_options_tools[fn_name]
416+
add_tools_from_model_options(tools, model_opts)
417+
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
429418

430419
thinking = model_opts.get(ModelOption.THINKING, None)
431420
if type(thinking) is bool and thinking:

mellea/backends/tools.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
"""Utilities for dealing with tools."""
22

33
import json
4-
from collections.abc import Callable, Generator, Mapping
4+
from collections.abc import Callable, Generator, Iterable, Mapping
55
from typing import Any
66

77
from ollama._utils import convert_function_to_tool
88

9+
from mellea.backends.types import ModelOption
910
from mellea.stdlib.base import Component, TemplateRepresentation
1011

1112

13+
def add_tools_from_model_options(
14+
tools_dict: dict[str, Callable], model_options: dict[str, Any]
15+
):
16+
"""If model_options has tools, it will add those tools to the tools_dict."""
17+
model_opts_tools = model_options.get(ModelOption.TOOLS, None)
18+
19+
if model_opts_tools is None:
20+
return
21+
22+
assert isinstance(model_opts_tools, Iterable), (
23+
"ModelOption.TOOLS must be a list of Callables"
24+
)
25+
for func in model_opts_tools:
26+
assert callable(func), (
27+
f"ModelOption.TOOLS must be a list of Callables, found {type(func)}"
28+
)
29+
tools_dict[func.__name__] = func
30+
31+
1232
def get_tools_from_action(action: Any) -> dict[str, Callable]:
1333
"""If an object is a Component with a TemplateRepresentation, grabs it's tools field.
1434

mellea/backends/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class ModelOption:
1717
"""
1818

1919
TOOLS = "@@@tools@@@"
20+
"""Must be a list of callables."""
21+
2022
MAX_NEW_TOKENS = "@@@max_new_tokens@@@"
2123
SYSTEM_PROMPT = "@@@system_prompt@@@"
2224
TEMPERATURE = "temperature"

mellea/backends/watsonx.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from mellea.backends import BaseModelSubclass, model_ids
1414
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
1515
from mellea.backends.model_ids import ModelIdentifier
16-
from mellea.backends.tools import convert_tools_to_json, get_tools_from_action
16+
from mellea.backends.tools import (
17+
add_tools_from_model_options,
18+
convert_tools_to_json,
19+
get_tools_from_action,
20+
)
1721
from mellea.backends.types import ModelOption
1822
from mellea.helpers.fancy_logger import FancyLogger
1923
from mellea.stdlib.base import (
@@ -262,24 +266,9 @@ def generate_from_chat_context(
262266
)
263267
else:
264268
tools = get_tools_from_action(action)
269+
add_tools_from_model_options(tools, model_opts)
265270

266-
model_options_tools = model_opts.get(ModelOption.TOOLS, None)
267-
if model_options_tools is not None:
268-
assert isinstance(model_options_tools, dict)
269-
for fn_name in model_options_tools:
270-
# invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools
271-
assert fn_name not in tools.keys(), (
272-
f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
273-
)
274-
# type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
275-
assert type(fn_name) is str, (
276-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
277-
)
278-
assert callable(model_options_tools[fn_name]), (
279-
"When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function."
280-
)
281-
# Add the model_options tool to the existing set of tools.
282-
tools[fn_name] = model_options_tools[fn_name]
271+
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
283272

284273
formatted_tools = convert_tools_to_json(tools)
285274
chat_response = self._model.chat(

test/test_tool_calls.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from mellea.backends import Backend
44
from mellea.backends.ollama import OllamaModelBackend
5+
from mellea.backends.tools import add_tools_from_model_options
6+
from mellea.backends.types import ModelOption
57
from mellea.stdlib.base import ModelOutputThunk
68
from mellea.stdlib.docs.richdocument import Table
79
from mellea.stdlib.session import LinearContext, MelleaSession
@@ -27,6 +29,22 @@ def table() -> Table:
2729
assert t is not None, "test setup failed: could not create table from markdown"
2830
return t
2931

32+
def test_add_tools_from_model_options(table: Table):
33+
def get_weather(location: str) -> int:
34+
"""Returns the weather in Celsius."""
35+
return 21
36+
37+
model_options = {
38+
ModelOption.TOOLS: [get_weather, table.content_as_string]
39+
}
40+
41+
tools = {}
42+
add_tools_from_model_options(tools, model_options)
43+
44+
assert tools["get_weather"] == get_weather
45+
46+
# Must use `==` for bound methods.
47+
assert tools["content_as_string"] == table.content_as_string, f"{tools["content_as_string"]} is not {table.content_as_string}"
3048

3149
def test_tool_called(m: MelleaSession, table: Table):
3250
"""We don't force tools to be called. As a result, this test might unexpectedly fail."""

0 commit comments

Comments
 (0)