diff --git a/docs/tutorial.md b/docs/tutorial.md index 288c3587..913e2fd0 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -20,6 +20,7 @@ - [Chapter 9: Interoperability with Other Frameworks](#chapter-9-interoperability-with-other-frameworks) - [Chapter 10: Prompt Engineering for Mellea](#chapter-10-prompt-engineering-for-m) - [Custom Templates](#custom-templates) +- [Chapter 11: Tool Calling](#chapter-11-tool-calling) - [Appendix: Contributing to Melles](#appendix-contributing-to-mellea) ## Chapter 1: What Is Generative Programming @@ -1281,6 +1282,33 @@ To customize the template and template representation of an existing class, simp See [`mellea/docs/examples/mify/rich_document_advanced.py`](./examples/mify/rich_document_advanced.py) +## Chapter 11: Tool Calling +Mellea supports tool calling for providers/models that support it. Most session level functions support setting a tool_calls boolean. Setting this to true allows tools to be called, but there's no guarantee that a model will call them. + +Tools can be made available for the model to call in a few ways: +1. Components: components can have a TemplateRepresentation object that contains tools. +2. Context: depending on the context, the components in that context can be used as sources of additional tools in the exact same way they would if they were the current action. +3. `ModelOptions.TOOLS`: model options can include a tools parameter. The preferred way of passing these tools is as a list of function objects. + +Currently, tools are identified by the name of the function. If there are conflicts, the most recent tool with that name will be preferred. This means the tools available to the model will have the same priority listed above: +1. Tools from the current component will always be included +2. Tools from the context will be included if there are no name conflicts. A given context can decide what tools to surface, but in most cases, tools from the most recent component in the context will take priority over tools from older requests. +3. Tools from `ModelOptions.TOOLS` will only be added if they do not conflict with any of the above functions. + +For examples on adding tools to the template representation of a component, see the `Table` object in [richdocument.py](../mellea/stdlib/docs/richdocument.py). + +Here's an example of adding a tool through model options. This can be useful when you want to add a tool like web search that should almost always be available: +```python +from mellea.backends.types import ModelOption + +def web_search(query: str) -> str: + ... + +model_opts = { + ModelOptions.TOOLS: [web_search] +} +``` + ## Appendix: Contributing to Mellea ### Contributor Guide: Requirements and Verifiers diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index b3c4d09a..360437bf 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -10,9 +10,8 @@ import datetime import inspect import json -import os from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import outlines import outlines_core @@ -32,8 +31,9 @@ from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier from mellea.backends.tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, convert_tools_to_json, - get_tools_from_action, parse_tools, ) from mellea.backends.types import ModelOption @@ -45,7 +45,6 @@ GenerateLog, ModelOutputThunk, ModelToolCall, - TemplateRepresentation, ) from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement @@ -325,28 +324,15 @@ def _generate_from_context_standard( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: - if isinstance(action, Component) and isinstance( - action.format_for_llm(), TemplateRepresentation - ): - tools = get_tools_from_action(action) - - model_options_tools = model_options.get(ModelOption.TOOLS, None) - if model_options_tools is not None: - assert isinstance(model_options_tools, dict) - for fn_name in model_options_tools: - # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools - assert fn_name not in tools.keys(), ( - f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action." - ) - # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries. - assert type(fn_name) is str, ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - assert callable(model_options_tools[fn_name]), ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - # Add the model_options tool to the existing set of tools. - tools[fn_name] = model_options_tools[fn_name] + add_tools_from_model_options(tools, model_options) + add_tools_from_context_actions( + tools, ctx.actions_for_available_tools() + ) + + # Add the tools from the action for this generation last so that + # they overwrite conflicting names. + add_tools_from_context_actions(tools, [action]) + FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") seed = model_options.get(ModelOption.SEED, None) if seed is not None: diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 89f89259..15b7fd7d 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -2,7 +2,6 @@ import asyncio import datetime -import os from collections.abc import Callable from typing import Any @@ -13,7 +12,10 @@ from mellea.backends import BaseModelSubclass from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import get_tools_from_action +from mellea.backends.tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, +) from mellea.backends.types import ModelOption from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( @@ -295,28 +297,12 @@ def generate_from_chat_context( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: - if isinstance(action, Component) and isinstance( - action.format_for_llm(), TemplateRepresentation - ): - tools = get_tools_from_action(action) - - model_options_tools = model_opts.get(ModelOption.TOOLS, None) - if model_options_tools is not None: - assert isinstance(model_options_tools, dict) - for fn_name in model_options_tools: - # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools - assert fn_name not in tools.keys(), ( - f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action." - ) - # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries. - assert type(fn_name) is str, ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - assert callable(model_options_tools[fn_name]), ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - # Add the model_options tool to the existing set of tools. - tools[fn_name] = model_options_tools[fn_name] + add_tools_from_model_options(tools, model_opts) + add_tools_from_context_actions(tools, ctx.actions_for_available_tools()) + + # Add the tools from the action for this generation last so that + # they overwrite conflicting names. + add_tools_from_context_actions(tools, [action]) FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") # Generate a chat response from ollama, using the chat messages. diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 4496bf54..9e9a9557 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -20,7 +20,11 @@ from mellea.backends.aloras import Alora, AloraBackendMixin from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import convert_tools_to_json, get_tools_from_action +from mellea.backends.tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, + convert_tools_to_json, +) from mellea.backends.types import ModelOption from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( @@ -30,7 +34,6 @@ GenerateLog, ModelOutputThunk, ModelToolCall, - TemplateRepresentation, ) from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement @@ -404,28 +407,13 @@ def _generate_from_chat_context_standard( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: - if isinstance(action, Component) and isinstance( - action.format_for_llm(), TemplateRepresentation - ): - tools = get_tools_from_action(action) - - model_options_tools = model_opts.get(ModelOption.TOOLS, None) - if model_options_tools is not None: - assert isinstance(model_options_tools, dict) - for fn_name in model_options_tools: - # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools - assert fn_name not in tools.keys(), ( - f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action." - ) - # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries. - assert type(fn_name) is str, ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - assert callable(model_options_tools[fn_name]), ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - # Add the model_options tool to the existing set of tools. - tools[fn_name] = model_options_tools[fn_name] + add_tools_from_model_options(tools, model_opts) + add_tools_from_context_actions(tools, ctx.actions_for_available_tools()) + + # Add the tools from the action for this generation last so that + # they overwrite conflicting names. + add_tools_from_context_actions(tools, [action]) + FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") thinking = model_opts.get(ModelOption.THINKING, None) if type(thinking) is bool and thinking: diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 49681806..272860af 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -1,27 +1,64 @@ """Utilities for dealing with tools.""" import json -from collections.abc import Callable, Generator, Mapping +from collections.abc import Callable, Generator, Iterable, Mapping from typing import Any from ollama._utils import convert_function_to_tool -from mellea.stdlib.base import Component, TemplateRepresentation +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, Component, TemplateRepresentation -def get_tools_from_action(action: Any) -> dict[str, Callable]: - """If an object is a Component with a TemplateRepresentation, grabs it's tools field. +def add_tools_from_model_options( + tools_dict: dict[str, Callable], model_options: dict[str, Any] +): + """If model_options has tools, add those tools to the tools_dict.""" + model_opts_tools = model_options.get(ModelOption.TOOLS, None) + if model_opts_tools is None: + return + + # Mappings are iterable. + assert isinstance(model_opts_tools, Iterable), ( + "ModelOption.TOOLS must be a list of Callables or dict[str, Callable]" + ) + + if isinstance(model_opts_tools, Mapping): + # Handle the dict case. + for func_name, func in model_opts_tools.items(): + assert isinstance(func_name, str), ( + f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Callable]; found {type(func_name)} as the key instead" + ) + assert callable(func), ( + f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Callable]; found {type(func)} as the value instead" + ) + tools_dict[func_name] = func + else: + # Handle any other iterable / list here. + for func in model_opts_tools: + assert callable(func), ( + f"If ModelOption.TOOLS is a list, it must be a list of Callables; found {type(func)}" + ) + tools_dict[func.__name__] = func + + +def add_tools_from_context_actions( + tools_dict: dict[str, Callable], ctx_actions: list[Component | CBlock] | None +): + """If any of the actions in ctx_actions have tools in their template_representation, add those to the tools_dict.""" + if ctx_actions is None: + return + + for action in ctx_actions: + if not isinstance(action, Component): + continue # Only components have template representations. - Returns: - dict: mapping function names to callables - """ - if isinstance(action, Component): tr = action.format_for_llm() - if isinstance(tr, TemplateRepresentation): - if tr.tools: - return tr.tools + if not isinstance(tr, TemplateRepresentation) or tr.tools is None: + continue - return {} + for tool_name, func in tr.tools.items(): + tools_dict[tool_name] = func def convert_tools_to_json(tools: dict[str, Callable]) -> list[dict]: diff --git a/mellea/backends/types.py b/mellea/backends/types.py index 1b5c4da4..2b6f2e26 100644 --- a/mellea/backends/types.py +++ b/mellea/backends/types.py @@ -17,6 +17,8 @@ class ModelOption: """ TOOLS = "@@@tools@@@" + """Must be a list of callables or a dict[str, Callable].""" + MAX_NEW_TOKENS = "@@@max_new_tokens@@@" SYSTEM_PROMPT = "@@@system_prompt@@@" TEMPERATURE = "temperature" diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 84204db2..c7ed0b8a 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -13,7 +13,11 @@ from mellea.backends import BaseModelSubclass, model_ids from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier -from mellea.backends.tools import convert_tools_to_json, get_tools_from_action +from mellea.backends.tools import ( + add_tools_from_context_actions, + add_tools_from_model_options, + convert_tools_to_json, +) from mellea.backends.types import ModelOption from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( @@ -261,25 +265,13 @@ def generate_from_chat_context( f"tool calling is superseded by format; will not call tools for request: {action}" ) else: - tools = get_tools_from_action(action) - - model_options_tools = model_opts.get(ModelOption.TOOLS, None) - if model_options_tools is not None: - assert isinstance(model_options_tools, dict) - for fn_name in model_options_tools: - # invariant re: relationship between the model_options set of tools and the TemplateRepresentation set of tools - assert fn_name not in tools.keys(), ( - f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action." - ) - # type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries. - assert type(fn_name) is str, ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - assert callable(model_options_tools[fn_name]), ( - "When providing a `ModelOption.TOOLS` parameter to `model_options`, always used the type Dict[str, Callable] where `str` is the function name and the callable is the function." - ) - # Add the model_options tool to the existing set of tools. - tools[fn_name] = model_options_tools[fn_name] + add_tools_from_model_options(tools, model_opts) + add_tools_from_context_actions(tools, ctx.actions_for_available_tools()) + + # Add the tools from the action for this generation last so that + # they overwrite conflicting names. + add_tools_from_context_actions(tools, [action]) + FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") formatted_tools = convert_tools_to_json(tools) chat_response = self._model.chat( diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 8f9fd817..d48cb6d7 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -161,6 +161,14 @@ def render_for_generation(self) -> list[Component | CBlock] | None: """Provides a linear list of context components to use for generation, or None if that is not possible to construct.""" ... + @abc.abstractmethod + def actions_for_available_tools(self) -> list[Component | CBlock] | None: + """Provides a list of actions to extract tools from for use with during generation, or None if that's not possible. + + Can be used to make the available tools differ from the tools of all the actions in the context. + """ + ... + @abc.abstractmethod def full_event_log(self) -> list[Component | CBlock]: """Provides a list of all events stored in the context.""" @@ -210,6 +218,14 @@ def __init__(self): self._ctx = [] self._log_ctx = [] + def actions_for_available_tools(self) -> list[Component | CBlock] | None: + """Provides a list of actions to extract tools from for use with during generation, or None if that's not possible. + + Can be used to make the available tools differ from the tools of all the actions in the context. + In most cases, this will just be the same context as `render_for_generation`. + """ + return self.render_for_generation() + def last_output(self): """The last output thunk of the context.""" for c in self._ctx[::-1]: diff --git a/test/backends/test_tool_helpers.py b/test/backends/test_tool_helpers.py new file mode 100644 index 00000000..021913f0 --- /dev/null +++ b/test/backends/test_tool_helpers.py @@ -0,0 +1,101 @@ + +import pytest +from mellea.backends.tools import add_tools_from_context_actions, add_tools_from_model_options +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, Component, TemplateRepresentation + +class FakeToolComponent(Component): + def __init__(self) -> None: + super().__init__() + + def tool1(self): + return + + def parts(self): + return [] + + def format_for_llm(self) -> TemplateRepresentation: + return TemplateRepresentation( + obj=self, + args={"arg": None}, + tools={ + self.tool1.__name__: self.tool1 + } + ) + +class FakeToolComponentWithExtraTool(FakeToolComponent): + def __init__(self) -> None: + super().__init__() + + def tool2(self): + return + + def format_for_llm(self) -> TemplateRepresentation: + tr = super().format_for_llm() + assert tr.tools is not None + tr.tools[self.tool2.__name__] = self.tool2 + return tr + + +def test_add_tools_from_model_options_list(): + def get_weather(location: str) -> int: + """Returns the weather in Celsius.""" + return 21 + + ftc = FakeToolComponent() + model_options = { + ModelOption.TOOLS: [get_weather, ftc.tool1] + } + + tools = {} + add_tools_from_model_options(tools, model_options) + + assert tools["get_weather"] == get_weather + + # Must use `==` for bound methods. + tool1 = tools['tool1'] + assert tool1 == ftc.tool1, f"{tool1} should == {ftc.tool1}" + + +def test_add_tools_from_model_options_map(): + def get_weather(location: str) -> int: + """Returns the weather in Celsius.""" + return 21 + + ftc = FakeToolComponent() + model_options = { + ModelOption.TOOLS: { + get_weather.__name__: get_weather, + ftc.tool1.__name__: ftc.tool1 + } + } + + tools = {} + add_tools_from_model_options(tools, model_options) + + assert tools["get_weather"] == get_weather + + # Must use `==` for bound methods. + tool1 = tools['tool1'] + assert tool1 == ftc.tool1, f"{tool1} should == {ftc.tool1}" + + +def test_add_tools_from_context_actions(): + + ftc1 = FakeToolComponentWithExtraTool() + ftc2 = FakeToolComponent() + + ctx_actions = [CBlock("Hello"), ftc1, ftc2] + tools = {} + add_tools_from_context_actions(tools, ctx_actions) + + # Check that tools with the same name get properly overwritten in order of ctx. + tool1 = tools['tool1'] + assert tool1 == ftc2.tool1, f"{tool1} should == {ftc2.tool1}" + + # Check that tools that aren't overwritten are still there. + tool2 = tools["tool2"] + assert tool2 == ftc1.tool2, f"{tool2} should == {ftc1.tool2}" + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/stdlib_basics/test_base.py b/test/stdlib_basics/test_base.py index 6d4008bd..3353742c 100644 --- a/test/stdlib_basics/test_base.py +++ b/test/stdlib_basics/test_base.py @@ -1,3 +1,4 @@ +import pytest from mellea.stdlib.base import CBlock, Component, LinearContext @@ -32,3 +33,21 @@ def test_context(): ctx.insert(CBlock("b")) ctx.insert(CBlock("c")) ctx.insert(CBlock("d")) + + +def test_actions_for_available_tools(): + ctx = LinearContext(window_size=3) + ctx.insert(CBlock("a")) + ctx.insert(CBlock("b")) + for_generation = ctx.render_for_generation() + assert for_generation is not None + + actions = ctx.actions_for_available_tools() + assert actions is not None + + assert len(for_generation) == len(actions) + for i in range(len(actions)): + assert actions[i] == for_generation[i] + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_tool_calls.py b/test/test_tool_calls.py index d669011a..4b966040 100644 --- a/test/test_tool_calls.py +++ b/test/test_tool_calls.py @@ -2,7 +2,9 @@ from mellea.backends import Backend from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import ModelOutputThunk +from mellea.backends.tools import add_tools_from_context_actions, add_tools_from_model_options +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, Component, ModelOutputThunk, TemplateRepresentation from mellea.stdlib.docs.richdocument import Table from mellea.stdlib.session import LinearContext, MelleaSession @@ -28,9 +30,34 @@ def table() -> Table: return t +def test_tool_called_from_context_action(m: MelleaSession, table: Table): + """Make sure tools can be called from actions in the context.""" + r = 10 + m.ctx.reset() + + # Insert a component with tools into the context. + m.ctx.insert(table) + + returned_tool = False + for i in range(r): + # Make sure the specific generate call is on a different action with + # no tools to make sure it's a tool from the context. + result = m.backend.generate_from_context( + CBlock("Add a row to the table."), + m.ctx, + tool_calls=True + ) + if result.tool_calls is not None and len(result.tool_calls) > 0: + returned_tool = True + break + + assert returned_tool, f"did not return a tool after {r} attempts" + + def test_tool_called(m: MelleaSession, table: Table): """We don't force tools to be called. As a result, this test might unexpectedly fail.""" r = 10 + m.ctx.reset() returned_tool = False for i in range(r): @@ -45,6 +72,7 @@ def test_tool_called(m: MelleaSession, table: Table): def test_tool_not_called(m: MelleaSession, table: Table): """Ensure tools aren't always called when provided.""" r = 10 + m.ctx.reset() returned_no_tool = False for i in range(r):