diff --git a/integrations/meta_llama/pyproject.toml b/integrations/meta_llama/pyproject.toml
index 56790d4f5a..b2c13d2d19 100644
--- a/integrations/meta_llama/pyproject.toml
+++ b/integrations/meta_llama/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.13.2"]
+dependencies = ["haystack-ai>=2.19.0"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/meta_llama#readme"
diff --git a/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py b/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py
index c0b2353428..37edb2187b 100644
--- a/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py
+++ b/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py
@@ -3,12 +3,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional
 
 from haystack import component, default_to_dict, logging
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses import ChatMessage, StreamingCallbackT
-from haystack.tools import Tool, Toolset
+from haystack.tools import ToolsType
 from haystack.utils import serialize_callable
 from haystack.utils.auth import Secret
 
@@ -60,7 +60,7 @@ def __init__(
         streaming_callback: Optional[StreamingCallbackT] = None,
         api_base_url: Optional[str] = "https://api.llama.com/compat/v1/",
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[Union[List[Tool], Toolset]] = None,
+        tools: Optional[ToolsType] = None,
     ):
         """
         Creates an instance of LlamaChatGenerator. Unless specified otherwise in the `model`, this is for Llama's
@@ -92,7 +92,8 @@ def __init__(
            - `safe_prompt`: Whether to inject a safety prompt before all conversations.
            - `random_seed`: The seed to use for random sampling.
        :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of Tool and/or Toolset objects, or a single Toolset, for which the model can prepare calls.
+            Each tool should have a unique name.
        """
        super(MetaLlamaChatGenerator, self).__init__(  # noqa: UP008
            api_key=api_key,
@@ -110,7 +111,7 @@ def _prepare_api_call(
        messages: list[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
-        tools: Optional[Union[list[Tool], Toolset]] = None,
+        tools: Optional[ToolsType] = None,
        tools_strict: Optional[bool] = None,
    ) -> dict[str, Any]:
        api_args = super(MetaLlamaChatGenerator, self)._prepare_api_call(  # noqa: UP008
diff --git a/integrations/meta_llama/tests/test_llama_chat_generator.py b/integrations/meta_llama/tests/test_llama_chat_generator.py
index 713fe46a9a..ef5dd36ac5 100644
--- a/integrations/meta_llama/tests/test_llama_chat_generator.py
+++ b/integrations/meta_llama/tests/test_llama_chat_generator.py
@@ -10,7 +10,7 @@
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.components.tools import ToolInvoker
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
-from haystack.tools import Tool
+from haystack.tools import Tool, Toolset
 from haystack.utils.auth import Secret
 from openai import OpenAIError
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
@@ -33,6 +33,11 @@ def weather(city: str):
     return f"The weather in {city} is sunny and 32°C"
 
 
+def population(city: str):
+    """Get the population of a given city."""
+    return f"The population of {city} is 2.2 million"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {
@@ -50,6 +55,25 @@ def tools():
     return [tool]
 
 
+@pytest.fixture
+def mixed_tools():
+    """Fixture that returns a mixed list of Tool and Toolset."""
+    weather_tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=weather,
+    )
+    population_tool = Tool(
+        name="population",
+        description="useful to determine the population of a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=population,
+    )
+    toolset = Toolset([population_tool])
+    return [weather_tool, toolset]
+
+
 @pytest.fixture
 def mock_chat_completion():
     """
@@ -538,3 +562,95 @@ def test_serde_in_pipeline(self, monkeypatch):
         assert loaded_generator.tools[0].name == generator.tools[0].name
         assert loaded_generator.tools[0].description == generator.tools[0].description
         assert loaded_generator.tools[0].parameters == generator.tools[0].parameters
+
+    def test_init_with_mixed_tools(self, monkeypatch):
+        def tool_fn(city: str) -> str:
+            return city
+
+        weather_tool = Tool(
+            name="weather",
+            description="Weather lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        population_tool = Tool(
+            name="population",
+            description="Population lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        toolset = Toolset([population_tool])
+
+        monkeypatch.setenv("LLAMA_API_KEY", "test-api-key")
+        generator = MetaLlamaChatGenerator(
+            model="Llama-4-Scout-17B-16E-Instruct-FP8",
+            tools=[weather_tool, toolset],
+        )
+
+        assert generator.tools == [weather_tool, toolset]
+
+    @pytest.mark.skipif(
+        not os.environ.get("LLAMA_API_KEY", None),
+        reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_with_mixed_tools(self, mixed_tools):
+        """
+        Integration test that verifies MetaLlamaChatGenerator works with a mixed list of Tool and Toolset objects.
+        It checks that the LLM can correctly invoke tools coming from both a standalone Tool and a Toolset.
+        """
+        initial_messages = [
+            ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
+        ]
+        # The default model, Llama-4-Scout-17B-16E-Instruct-FP8, doesn't handle multiple tool responses well,
+        # so we use the stronger Llama-3.3-70B-Instruct for this test.
+        component = MetaLlamaChatGenerator(model="Llama-3.3-70B-Instruct", tools=mixed_tools)
+        results = component.run(messages=initial_messages)
+
+        assert len(results["replies"]) > 0, "No replies received"
+
+        # Find the message with tool calls
+        tool_call_message = None
+        for message in results["replies"]:
+            if message.tool_calls:
+                tool_call_message = message
+                break
+
+        assert tool_call_message is not None, "No message with tool call found"
+        assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_calls = tool_call_message.tool_calls
+        assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
+
+        # Verify we got calls to both the weather and population tools
+        tool_names = {tc.tool_name for tc in tool_calls}
+        assert "weather" in tool_names, "Expected 'weather' tool call"
+        assert "population" in tool_names, "Expected 'population' tool call"
+
+        # Verify tool call details
+        for tool_call in tool_calls:
+            assert tool_call.id, "Tool call does not contain value for 'id' key"
+            assert tool_call.tool_name in ["weather", "population"]
+            assert "city" in tool_call.arguments
+            assert tool_call.arguments["city"] in ["Paris", "Berlin"]
+        assert tool_call_message.meta["finish_reason"] == "tool_calls"
+
+        # Simulate the tool results we'd get back from ToolInvoker
+        tool_result_messages = []
+        for tool_call in tool_calls:
+            if tool_call.tool_name == "weather":
+                result = "The weather in Paris is sunny and 32°C"
+            else:  # population
+                result = "The population of Berlin is 2.2 million"
+            tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call))
+
+        new_messages = [*initial_messages, tool_call_message, *tool_result_messages]
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_calls
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+        assert "berlin" in final_message.text.lower()
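
Usage note, not part of the patch: a minimal sketch of the behavior this change enables, passing a mixed list of Tool and Toolset objects (the new ToolsType) to MetaLlamaChatGenerator. It assumes a valid LLAMA_API_KEY in the environment and the package's usual public import path; the tool functions and their canned return values are illustrative, mirroring the test fixtures above.

    from haystack.dataclasses import ChatMessage
    from haystack.tools import Tool, Toolset

    from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator


    def weather(city: str) -> str:
        # Illustrative stand-in; a real tool would call a weather API.
        return f"The weather in {city} is sunny and 32°C"


    def population(city: str) -> str:
        # Illustrative stand-in; a real tool would look up census data.
        return f"The population of {city} is 2.2 million"


    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=weather,
    )
    population_tool = Tool(
        name="population",
        description="useful to determine the population of a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=population,
    )

    # tools now accepts a list mixing Tool and Toolset objects (or a single Toolset);
    # tool names should be unique across the whole collection.
    generator = MetaLlamaChatGenerator(tools=[weather_tool, Toolset([population_tool])])
    result = generator.run(messages=[ChatMessage.from_user("What's the weather like in Paris?")])
    print(result["replies"][0])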