3 changes: 3 additions & 0 deletions mellea/backends/huggingface.py
@@ -332,6 +332,7 @@ def _generate_from_context_standard(
         input_ids = self._tokenizer.apply_chat_template(  # type: ignore
             ctx_as_conversation,
             tools=convert_tools_to_json(tools),  # type: ignore
+            add_generation_prompt=True,
             return_tensors="pt",
             **self._make_backend_specific_and_remove(model_options),
         ).to(self._device)  # type: ignore
@@ -401,6 +402,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
+            format=format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -457,6 +459,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
+        format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
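Background for the add_generation_prompt=True addition: in Hugging Face transformers, apply_chat_template only appends the template's assistant-turn header when this flag is set, which cues the model to start a reply instead of continuing the user's message. A minimal standalone illustration of the flag's effect, not part of this PR (the model name is only an example):

from transformers import AutoTokenizer

# Any checkpoint that ships a chat template works here; this one is just an example.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

conversation = [{"role": "user", "content": "What is today's temperature in Boston?"}]

without_prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
with_prompt = tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)

# Only the second string ends with the assistant-turn header (for this template,
# "<|assistant|>"), so generation begins a new reply rather than extending the user turn.
print(without_prompt)
print(with_prompt)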
2 changes: 2 additions & 0 deletions mellea/backends/litellm.py
@@ -293,6 +293,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
+            format=format,
         )

         try:
@@ -369,6 +370,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
3 changes: 2 additions & 1 deletion mellea/backends/ollama.py
@@ -343,7 +343,7 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools
+            self.post_processing, conversation=conversation, tools=tools, format=format
         )

         try:
@@ -506,6 +506,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
+        format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (
2 changes: 2 additions & 0 deletions mellea/backends/openai.py
@@ -502,6 +502,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )

         try:
@@ -569,6 +570,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
2 changes: 2 additions & 0 deletions mellea/backends/watsonx.py
@@ -340,6 +340,7 @@ def generate_from_chat_context(
             conversation=conversation,
             tools=tools,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )

         try:
@@ -406,6 +407,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
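All five backend diffs above follow the same pattern: the format argument that the generate path already receives is now bound, alongside the conversation and tools, into the functools.partial that builds the post_processing callback, and post_processing grows a matching parameter, presumably so structured-output handling can use the requested schema once generation (including streaming) has finished. A minimal sketch of that wiring, with made-up stand-in names (ToyBackend, generate, Temperature are illustrations, not mellea APIs):

import functools

from pydantic import BaseModel


class ToyBackend:
    def generate(self, conversation: list[dict], format: type[BaseModel] | None = None) -> object:
        # Mirror the PR: bind format into the post-processing callback so it is
        # still available after generation completes.
        post_process = functools.partial(
            self.post_processing,
            conversation=conversation,
            format=format,
        )
        raw_text = '{"temperature": 21}'  # stand-in for the model's raw output
        return post_process(raw_text)

    def post_processing(
        self, raw_text: str, conversation: list[dict], format: type[BaseModel] | None
    ) -> object:
        # With format threaded through, the finished response can be parsed into
        # the requested schema instead of being returned as plain text.
        if format is not None:
            return format.model_validate_json(raw_text)
        return raw_text


class Temperature(BaseModel):
    temperature: int


if __name__ == "__main__":
    backend = ToyBackend()
    print(backend.generate([{"role": "user", "content": "How warm is it?"}], format=Temperature))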
77 changes: 77 additions & 0 deletions test/backends/test_huggingface_tools.py
@@ -0,0 +1,77 @@
import pydantic
import pytest
from typing_extensions import Annotated

from mellea import MelleaSession
from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
from mellea.backends.cache import SimpleLRUCache
from mellea.backends.formatter import TemplateFormatter
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock, ChatContext
from mellea.stdlib.requirement import (
ALoraRequirement,
LLMaJRequirement,
Requirement,
ValidationResult,
default_output_to_bool,
)
import mellea.backends.model_ids as model_ids


@pytest.fixture(scope="module")
def backend():
    """Shared HuggingFace backend for all tests in this module."""
    backend = LocalHFBackend(
        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
        cache=SimpleLRUCache(5),
    )
    # add_granite_aloras(backend)
    return backend


@pytest.fixture(scope="function")
def session(backend):
    """Fresh HuggingFace session for each test."""
    session = MelleaSession(backend, ctx=ChatContext())
    yield session
    session.reset()



@pytest.mark.qualitative
def test_tool(session):
    tool_call_history = []

    def get_temperature(location: str) -> int:
        """Returns today's temperature of the given city in Celsius.

        Args:
            location: a city name.
        """
        tool_call_history.append(location)
        return 21

    output = session.instruct(
        "What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
        model_options={
            ModelOption.TOOLS: [get_temperature],
            ModelOption.MAX_NEW_TOKENS: 1000,
        },
        tool_calls=True,
    )

    assert output.tool_calls is not None

    result = output.tool_calls["get_temperature"].call_func()
    print(result)

    assert len(tool_call_history) > 0
    assert tool_call_history[0].lower() == "boston"
    assert result == 21


if __name__ == "__main__":
    pytest.main([__file__])