Skip to content

Commit 4dea34f

Browse files
committed
add tools from context to backend calls; add test for tool from context calling
1 parent d766a3e commit 4dea34f

File tree

5 files changed

+38
-2
lines changed

5 files changed

+38
-2
lines changed

mellea/backends/huggingface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
3333
from mellea.backends.model_ids import ModelIdentifier
3434
from mellea.backends.tools import (
35+
add_tools_from_context_actions,
3536
add_tools_from_model_options,
3637
convert_tools_to_json,
3738
get_tools_from_action,
@@ -331,6 +332,7 @@ def _generate_from_context_standard(
331332
):
332333
tools = get_tools_from_action(action)
333334
add_tools_from_model_options(tools, model_options)
335+
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
334336
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
335337

336338
seed = model_options.get(ModelOption.SEED, None)

mellea/backends/ollama.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from mellea.backends import BaseModelSubclass
1414
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
1515
from mellea.backends.model_ids import ModelIdentifier
16-
from mellea.backends.tools import add_tools_from_model_options, get_tools_from_action
16+
from mellea.backends.tools import (
17+
add_tools_from_context_actions,
18+
add_tools_from_model_options,
19+
get_tools_from_action,
20+
)
1721
from mellea.backends.types import ModelOption
1822
from mellea.helpers.fancy_logger import FancyLogger
1923
from mellea.stdlib.base import (
@@ -301,6 +305,7 @@ def generate_from_chat_context(
301305
tools = get_tools_from_action(action)
302306

303307
add_tools_from_model_options(tools, model_opts)
308+
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
304309
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
305310

306311
# Generate a chat response from ollama, using the chat messages.

mellea/backends/openai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
2222
from mellea.backends.model_ids import ModelIdentifier
2323
from mellea.backends.tools import (
24+
add_tools_from_context_actions,
2425
add_tools_from_model_options,
2526
convert_tools_to_json,
2627
get_tools_from_action,
@@ -414,6 +415,7 @@ def _generate_from_chat_context_standard(
414415
tools = get_tools_from_action(action)
415416

416417
add_tools_from_model_options(tools, model_opts)
418+
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
417419
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
418420

419421
thinking = model_opts.get(ModelOption.THINKING, None)

mellea/backends/watsonx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
1515
from mellea.backends.model_ids import ModelIdentifier
1616
from mellea.backends.tools import (
17+
add_tools_from_context_actions,
1718
add_tools_from_model_options,
1819
convert_tools_to_json,
1920
get_tools_from_action,
@@ -267,7 +268,7 @@ def generate_from_chat_context(
267268
else:
268269
tools = get_tools_from_action(action)
269270
add_tools_from_model_options(tools, model_opts)
270-
271+
add_tools_from_context_actions(tools, ctx.actions_for_available_tools())
271272
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
272273

273274
formatted_tools = convert_tools_to_json(tools)

test/test_tool_calls.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,34 @@ def test_add_tools_from_context_actions():
119119
assert tools["tool2"] == ftc1.tool2, f"{tools["tool2"]} should == {ftc1.tool2}"
120120

121121

122+
def test_tool_called_from_context_action(m: MelleaSession, table: Table):
123+
"""Make sure tools can be called from actions in the context."""
124+
r = 10
125+
m.ctx.reset()
126+
127+
# Insert a component with tools into the context.
128+
m.ctx.insert(table)
129+
130+
returned_tool = False
131+
for i in range(r):
132+
# Make sure the specific generate call is on a different action with
133+
# no tools to make sure it's a tool from the context.
134+
result = m.backend.generate_from_context(
135+
CBlock("Add a row to the table."),
136+
m.ctx,
137+
tool_calls=True
138+
)
139+
if result.tool_calls is not None and len(result.tool_calls) > 0:
140+
returned_tool = True
141+
break
142+
143+
assert returned_tool, f"did not return a tool after {r} attempts"
144+
145+
122146
def test_tool_called(m: MelleaSession, table: Table):
123147
"""We don't force tools to be called. As a result, this test might unexpectedly fail."""
124148
r = 10
149+
m.ctx.reset()
125150

126151
returned_tool = False
127152
for i in range(r):
@@ -136,6 +161,7 @@ def test_tool_called(m: MelleaSession, table: Table):
136161
def test_tool_not_called(m: MelleaSession, table: Table):
137162
"""Ensure tools aren't always called when provided."""
138163
r = 10
164+
m.ctx.reset()
139165

140166
returned_no_tool = False
141167
for i in range(r):

0 commit comments

Comments
 (0)