2 changes: 1 addition & 1 deletion integrations/meta_llama/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
-dependencies = ["haystack-ai>=2.13.2"]
+dependencies = ["haystack-ai>=2.19.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/meta_llama#readme"
MetaLlamaChatGenerator source module (file path not captured in the diff view)
@@ -3,12 +3,12 @@
#
# SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional

from haystack import component, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingCallbackT
-from haystack.tools import Tool, Toolset
+from haystack.tools import ToolsType
from haystack.utils import serialize_callable
from haystack.utils.auth import Secret

@@ -60,7 +60,7 @@ def __init__(
streaming_callback: Optional[StreamingCallbackT] = None,
api_base_url: Optional[str] = "https://api.llama.com/compat/v1/",
generation_kwargs: Optional[Dict[str, Any]] = None,
-tools: Optional[Union[List[Tool], Toolset]] = None,
+tools: Optional[ToolsType] = None,
):
"""
Creates an instance of LlamaChatGenerator. Unless specified otherwise in the `model`, this is for Llama's
@@ -92,7 +92,8 @@ def __init__(
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
- `random_seed`: The seed to use for random sampling.
:param tools:
-A list of tools for which the model can prepare calls.
+A list of Tool and/or Toolset objects, or a single Toolset, for which the model can prepare calls.
+Each tool should have a unique name.
"""
super(MetaLlamaChatGenerator, self).__init__( # noqa: UP008
api_key=api_key,
@@ -110,7 +111,7 @@ def _prepare_api_call(
messages: list[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: Optional[dict[str, Any]] = None,
-tools: Optional[Union[list[Tool], Toolset]] = None,
+tools: Optional[ToolsType] = None,
tools_strict: Optional[bool] = None,
) -> dict[str, Any]:
api_args = super(MetaLlamaChatGenerator, self)._prepare_api_call( # noqa: UP008
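With this change, `tools` is typed as `ToolsType`, so a flat list of `Tool` objects, a single `Toolset`, or a mixed list of both are all accepted. A minimal sketch of the three call shapes (hypothetical tools; assumes the integration exposes `MetaLlamaChatGenerator` under `haystack_integrations.components.generators.meta_llama` and that `LLAMA_API_KEY` is set):

from haystack.tools import Tool, Toolset
from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator

# Hypothetical example tools, not part of this PR
city_params = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
weather_tool = Tool(name="weather", description="Weather lookup", parameters=city_params,
                    function=lambda city: f"The weather in {city} is sunny and 32°C")
population_tool = Tool(name="population", description="Population lookup", parameters=city_params,
                       function=lambda city: f"The population of {city} is 2.2 million")

# All three shapes satisfy the new ToolsType annotation
gen_list = MetaLlamaChatGenerator(tools=[weather_tool, population_tool])              # list[Tool]
gen_set = MetaLlamaChatGenerator(tools=Toolset([weather_tool]))                       # single Toolset
gen_mixed = MetaLlamaChatGenerator(tools=[weather_tool, Toolset([population_tool])])  # mixed list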
116 changes: 115 additions & 1 deletion integrations/meta_llama/tests/test_llama_chat_generator.py
@@ -10,7 +10,7 @@
from haystack.components.generators.utils import print_streaming_chunk
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
-from haystack.tools import Tool
+from haystack.tools import Tool, Toolset
from haystack.utils.auth import Secret
from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
@@ -33,6 +33,11 @@ def weather(city: str):
return f"The weather in {city} is sunny and 32°C"


def population(city: str):
"""Get population for a given city."""
return f"The population of {city} is 2.2 million"


@pytest.fixture
def tools():
tool_parameters = {
@@ -50,6 +55,25 @@ def tools():
return [tool]


@pytest.fixture
def mixed_tools():
"""Fixture that returns a mixed list of Tool and Toolset."""
weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather,
)
population_tool = Tool(
name="population",
description="useful to determine the population of a given location",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=population,
)
toolset = Toolset([population_tool])
return [weather_tool, toolset]


@pytest.fixture
def mock_chat_completion():
"""
@@ -538,3 +562,93 @@ def test_serde_in_pipeline(self, monkeypatch):
assert loaded_generator.tools[0].name == generator.tools[0].name
assert loaded_generator.tools[0].description == generator.tools[0].description
assert loaded_generator.tools[0].parameters == generator.tools[0].parameters

def test_init_with_mixed_tools(self, monkeypatch):
def tool_fn(city: str) -> str:
return city

weather_tool = Tool(
name="weather",
description="Weather lookup",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=tool_fn,
)
population_tool = Tool(
name="population",
description="Population lookup",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=tool_fn,
)
toolset = Toolset([population_tool])

monkeypatch.setenv("LLAMA_API_KEY", "test-api-key")
generator = MetaLlamaChatGenerator(
model="Llama-4-Scout-17B-16E-Instruct-FP8",
tools=[weather_tool, toolset],
)

assert generator.tools == [weather_tool, toolset]

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_mixed_tools(self, mixed_tools):
"""
Integration test that verifies MetaLlamaChatGenerator works with mixed Tool and Toolset.
This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset.
"""
initial_messages = [
ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
]
component = MetaLlamaChatGenerator(tools=mixed_tools)
results = component.run(messages=initial_messages)

assert len(results["replies"]) > 0, "No replies received"

# Find the message with tool calls
tool_call_message = None
for message in results["replies"]:
if message.tool_calls:
tool_call_message = message
break

assert tool_call_message is not None, "No message with tool call found"
assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance"
assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"

tool_calls = tool_call_message.tool_calls
assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"

# Verify we got calls to both weather and population tools
tool_names = {tc.tool_name for tc in tool_calls}
assert "weather" in tool_names, "Expected 'weather' tool call"
assert "population" in tool_names, "Expected 'population' tool call"

# Verify tool call details
for tool_call in tool_calls:
assert tool_call.id, "Tool call does not contain value for 'id' key"
assert tool_call.tool_name in ["weather", "population"]
assert "city" in tool_call.arguments
assert tool_call.arguments["city"] in ["Paris", "Berlin"]
assert tool_call_message.meta["finish_reason"] == "tool_calls"

# Mock the response we'd get from ToolInvoker
tool_result_messages = []
for tool_call in tool_calls:
if tool_call.tool_name == "weather":
result = "The weather in Paris is sunny and 32°C"
else: # population
result = "The population of Berlin is 2.2 million"
tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call))

new_messages = [*initial_messages, tool_call_message, *tool_result_messages]
results = component.run(messages=new_messages)

assert len(results["replies"]) == 1
final_message = results["replies"][0]
assert not final_message.tool_call
assert len(final_message.text) > 0
assert "paris" in final_message.text.lower()
assert "berlin" in final_message.text.lower()