|
1 | 1 | """Utilities for dealing with tools.""" |
2 | 2 |
|
3 | 3 | import json |
4 | | -from collections.abc import Callable, Generator, Mapping |
| 4 | +from collections.abc import Callable, Generator, Iterable, Mapping |
5 | 5 | from typing import Any |
6 | 6 |
|
7 | 7 | from ollama._utils import convert_function_to_tool |
8 | 8 |
|
9 | | -from mellea.stdlib.base import Component, TemplateRepresentation |
| 9 | +from mellea.backends.types import ModelOption |
| 10 | +from mellea.stdlib.base import CBlock, Component, TemplateRepresentation |
10 | 11 |
|
11 | 12 |
|
12 | | -def get_tools_from_action(action: Any) -> dict[str, Callable]: |
13 | | - """If an object is a Component with a TemplateRepresentation, grabs it's tools field. |
| 13 | +def add_tools_from_model_options( |
| 14 | + tools_dict: dict[str, Callable], model_options: dict[str, Any] |
| 15 | +): |
| 16 | + """If model_options has tools, add those tools to the tools_dict.""" |
| 17 | + model_opts_tools = model_options.get(ModelOption.TOOLS, None) |
| 18 | + if model_opts_tools is None: |
| 19 | + return |
| 20 | + |
| 21 | + # Mappings are iterable. |
| 22 | + assert isinstance(model_opts_tools, Iterable), ( |
| 23 | + "ModelOption.TOOLS must be a list of Callables or dict[str, Callable]" |
| 24 | + ) |
| 25 | + |
| 26 | + if isinstance(model_opts_tools, Mapping): |
| 27 | + # Handle the dict case. |
| 28 | + for func_name, func in model_opts_tools.items(): |
| 29 | + assert isinstance(func_name, str), ( |
| 30 | + f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Callable]; found {type(func_name)} as the key instead" |
| 31 | + ) |
| 32 | + assert callable(func), ( |
| 33 | + f"If ModelOption.TOOLS is a dict, it must be a dict of [str, Callable]; found {type(func)} as the value instead" |
| 34 | + ) |
| 35 | + tools_dict[func_name] = func |
| 36 | + else: |
| 37 | + # Handle any other iterable / list here. |
| 38 | + for func in model_opts_tools: |
| 39 | + assert callable(func), ( |
| 40 | + f"If ModelOption.TOOLS is a list, it must be a list of Callables; found {type(func)}" |
| 41 | + ) |
| 42 | + tools_dict[func.__name__] = func |
| 43 | + |
| 44 | + |
| 45 | +def add_tools_from_context_actions( |
| 46 | + tools_dict: dict[str, Callable], ctx_actions: list[Component | CBlock] | None |
| 47 | +): |
| 48 | + """If any of the actions in ctx_actions have tools in their template_representation, add those to the tools_dict.""" |
| 49 | + if ctx_actions is None: |
| 50 | + return |
| 51 | + |
| 52 | + for action in ctx_actions: |
| 53 | + if not isinstance(action, Component): |
| 54 | + continue # Only components have template representations. |
14 | 55 |
|
15 | | - Returns: |
16 | | - dict: mapping function names to callables |
17 | | - """ |
18 | | - if isinstance(action, Component): |
19 | 56 | tr = action.format_for_llm() |
20 | | - if isinstance(tr, TemplateRepresentation): |
21 | | - if tr.tools: |
22 | | - return tr.tools |
| 57 | + if not isinstance(tr, TemplateRepresentation) or tr.tools is None: |
| 58 | + continue |
23 | 59 |
|
24 | | - return {} |
| 60 | + for tool_name, func in tr.tools.items(): |
| 61 | + tools_dict[tool_name] = func |
25 | 62 |
|
26 | 63 |
|
27 | 64 | def convert_tools_to_json(tools: dict[str, Callable]) -> list[dict]: |
|
0 commit comments