diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index ae8bb249..5903b9c2 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -332,6 +332,7 @@ def _generate_from_context_standard(
         input_ids = self._tokenizer.apply_chat_template(  # type: ignore
             ctx_as_conversation,
             tools=convert_tools_to_json(tools),  # type: ignore
+            add_generation_prompt=True,
             return_tensors="pt",
             **self._make_backend_specific_and_remove(model_options),
         ).to(self._device)  # type: ignore
@@ -401,6 +402,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
+            format=format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -457,6 +459,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
+        format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py
index 93a1ef5d..f665b985 100644
--- a/mellea/backends/litellm.py
+++ b/mellea/backends/litellm.py
@@ -293,6 +293,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
+            format=format,
         )

         try:
@@ -369,6 +370,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py
index 9c16d0d6..a9e779b5 100644
--- a/mellea/backends/ollama.py
+++ b/mellea/backends/ollama.py
@@ -343,7 +343,7 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools
+            self.post_processing, conversation=conversation, tools=tools, format=format
         )

         try:
@@ -506,6 +506,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
+        format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index 2164eafd..fe6b1505 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -502,6 +502,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )

         try:
@@ -569,6 +570,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py
index 155773aa..15cbf1f6 100644
--- a/mellea/backends/watsonx.py
+++ b/mellea/backends/watsonx.py
@@ -340,6 +340,7 @@ def generate_from_chat_context(
             conversation=conversation,
             tools=tools,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )

         try:
@@ -406,6 +407,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
diff --git a/test/backends/test_huggingface_tools.py b/test/backends/test_huggingface_tools.py
new file mode 100644
index 00000000..cc88ec0a
--- /dev/null
+++ b/test/backends/test_huggingface_tools.py
@@ -0,0 +1,77 @@
+import pydantic
+import pytest
+from typing_extensions import Annotated
+
+from mellea import MelleaSession
+from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
+from mellea.backends.cache import SimpleLRUCache
+from mellea.backends.formatter import TemplateFormatter
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.types import ModelOption
+from mellea.stdlib.base import CBlock, ChatContext
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)
+import mellea.backends.model_ids as model_ids
+
+
+@pytest.fixture(scope="module")
+def backend():
+    """Shared HuggingFace backend for all tests in this module."""
+    backend = LocalHFBackend(
+        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
+        cache=SimpleLRUCache(5),
+    )
+    # add_granite_aloras(backend)
+    return backend
+
+
+@pytest.fixture(scope="function")
+def session(backend):
+    """Fresh HuggingFace session for each test."""
+    session = MelleaSession(backend, ctx=ChatContext())
+    yield session
+    session.reset()
+
+
+
+@pytest.mark.qualitative
+def test_tool(session):
+
+    tool_call_history = []
+    def get_temperature(location: str) -> int:
+        """Returns today's temperature of the given city in Celsius.
+
+        Args:
+            location: a city name.
+        """
+        tool_call_history.append(location)
+        return 21
+
+    output = session.instruct(
+        "What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
+        model_options={
+            ModelOption.TOOLS: [get_temperature,],
+            ModelOption.MAX_NEW_TOKENS: 1000,
+        },
+        tool_calls = True,
+    )
+
+    assert output.tool_calls is not None
+
+    result = output.tool_calls["get_temperature"].call_func()
+    print(result)
+
+    assert len(tool_call_history) > 0
+    assert tool_call_history[0].lower() == "boston"
+    assert 21 == result
+
+
+if __name__ == "__main__":
+    import pytest
+
+    pytest.main([__file__])