Skip to content

Commit 45507ea

Browse files
committed
update handling of tools for the given action
1 parent 4dea34f commit 45507ea

File tree

5 files changed

+24
-44
lines changed

5 files changed

+24
-44
lines changed

mellea/backends/huggingface.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import datetime
1111
import inspect
1212
import json
13-
import os
1413
from collections.abc import Callable
15-
from typing import TYPE_CHECKING, Any, Optional
14+
from typing import TYPE_CHECKING, Any
1615

1716
import outlines
1817
import outlines_core
@@ -35,7 +34,6 @@
3534
add_tools_from_context_actions,
3635
add_tools_from_model_options,
3736
convert_tools_to_json,
38-
get_tools_from_action,
3937
parse_tools,
4038
)
4139
from mellea.backends.types import ModelOption
@@ -47,7 +45,6 @@
4745
GenerateLog,
4846
ModelOutputThunk,
4947
ModelToolCall,
50-
TemplateRepresentation,
5148
)
5249
from mellea.stdlib.chat import Message
5350
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
@@ -327,13 +324,15 @@ def _generate_from_context_standard(
327324
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
328325
)
329326
else:
330-
if isinstance(action, Component) and isinstance(
331-
action.format_for_llm(), TemplateRepresentation
332-
):
333-
tools = get_tools_from_action(action)
334-
add_tools_from_model_options(tools, model_options)
335-
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
336-
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
327+
add_tools_from_model_options(tools, model_options)
328+
add_tools_from_context_actions(
329+
tools, ctx.actions_for_available_tools()
330+
)
331+
332+
# Add the tools from the action for this generation last so that
333+
# they overwrite conflicting names.
334+
add_tools_from_context_actions(tools, [action])
335+
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
337336

338337
seed = model_options.get(ModelOption.SEED, None)
339338
if seed is not None:

mellea/backends/ollama.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import asyncio
44
import datetime
5-
import os
65
from collections.abc import Callable
76
from typing import Any
87

@@ -16,7 +15,6 @@
1615
from mellea.backends.tools import (
1716
add_tools_from_context_actions,
1817
add_tools_from_model_options,
19-
get_tools_from_action,
2018
)
2119
from mellea.backends.types import ModelOption
2220
from mellea.helpers.fancy_logger import FancyLogger
@@ -299,13 +297,12 @@ def generate_from_chat_context(
299297
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
300298
)
301299
else:
302-
if isinstance(action, Component) and isinstance(
303-
action.format_for_llm(), TemplateRepresentation
304-
):
305-
tools = get_tools_from_action(action)
306-
307300
add_tools_from_model_options(tools, model_opts)
308301
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
302+
303+
# Add the tools from the action for this generation last so that
304+
# they overwrite conflicting names.
305+
add_tools_from_context_actions(tools, [action])
309306
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
310307

311308
# Generate a chat response from ollama, using the chat messages.

mellea/backends/openai.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
add_tools_from_context_actions,
2525
add_tools_from_model_options,
2626
convert_tools_to_json,
27-
get_tools_from_action,
2827
)
2928
from mellea.backends.types import ModelOption
3029
from mellea.helpers.fancy_logger import FancyLogger
@@ -35,7 +34,6 @@
3534
GenerateLog,
3635
ModelOutputThunk,
3736
ModelToolCall,
38-
TemplateRepresentation,
3937
)
4038
from mellea.stdlib.chat import Message
4139
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
@@ -409,14 +407,13 @@ def _generate_from_chat_context_standard(
409407
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
410408
)
411409
else:
412-
if isinstance(action, Component) and isinstance(
413-
action.format_for_llm(), TemplateRepresentation
414-
):
415-
tools = get_tools_from_action(action)
416-
417410
add_tools_from_model_options(tools, model_opts)
418411
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
419-
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
412+
413+
# Add the tools from the action for this generation last so that
414+
# they overwrite conflicting names.
415+
add_tools_from_context_actions(tools, [action])
416+
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
420417

421418
thinking = model_opts.get(ModelOption.THINKING, None)
422419
if type(thinking) is bool and thinking:

mellea/backends/tools.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,13 @@ def add_tools_from_context_actions(
5454
continue # Only components have template representations.
5555

5656
tr = action.format_for_llm()
57-
if isinstance(tr, str) or tr.tools is None:
57+
if not isinstance(tr, TemplateRepresentation) or tr.tools is None:
5858
continue
5959

6060
for tool_name, func in tr.tools.items():
6161
tools_dict[tool_name] = func
6262

6363

64-
def get_tools_from_action(action: Any) -> dict[str, Callable]:
65-
"""If an object is a Component with a TemplateRepresentation, grabs it's tools field.
66-
67-
Returns:
68-
dict: mapping function names to callables
69-
"""
70-
if isinstance(action, Component):
71-
tr = action.format_for_llm()
72-
if isinstance(tr, TemplateRepresentation):
73-
if tr.tools:
74-
return tr.tools
75-
76-
return {}
77-
78-
7964
def convert_tools_to_json(tools: dict[str, Callable]) -> list[dict]:
8065
"""Convert tools to json dict representation.
8166

mellea/backends/watsonx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
add_tools_from_context_actions,
1818
add_tools_from_model_options,
1919
convert_tools_to_json,
20-
get_tools_from_action,
2120
)
2221
from mellea.backends.types import ModelOption
2322
from mellea.helpers.fancy_logger import FancyLogger
@@ -266,9 +265,12 @@ def generate_from_chat_context(
266265
f"tool calling is superseded by format; will not call tools for request: {action}"
267266
)
268267
else:
269-
tools = get_tools_from_action(action)
270268
add_tools_from_model_options(tools, model_opts)
271269
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
270+
271+
# Add the tools from the action for this generation last so that
272+
# they overwrite conflicting names.
273+
add_tools_from_context_actions(tools, [action])
272274
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
273275

274276
formatted_tools = convert_tools_to_json(tools)

0 commit comments

Comments
 (0)