diff --git a/.github/workflows/llama_stack.yml b/.github/workflows/llama_stack.yml
index 1faa4d1909..08c51802d3 100644
--- a/.github/workflows/llama_stack.yml
+++ b/.github/workflows/llama_stack.yml
@@ -82,13 +82,30 @@ jobs:
           OLLAMA_URL: http://localhost:11434
         shell: bash
         run: |
-          pip install uv
-          uv run --with llama-stack llama stack build --distro starter --image-type venv --run < /dev/null > server.log 2>&1 &
-          sleep 120
-          # Verify it's running
-          curl -f http://localhost:8321/v1/models || { cat server.log; exit 1; }
-
-          echo "Llama Stack Server started successfully."
+          set -euo pipefail
+          pip install -q uv
+
+          # Install the starter distro's deps into the uv environment
+          uv run --with llama-stack bash -lc 'llama stack list-deps starter | xargs -L1 uv pip install'
+
+          # Start Llama Stack (no more --image-type flag)
+          uv run --with llama-stack llama stack run starter > server.log 2>&1 &
+          SERVER_PID=$!
+
+          # Wait up to ~120s for health; fail fast if process dies
+          for i in {1..60}; do
+            if curl -fsS http://localhost:8321/v1/models >/dev/null; then
+              echo "Llama Stack Server started successfully."
+              break
+            fi
+            if ! kill -0 "$SERVER_PID" 2>/dev/null; then
+              echo "Server exited early. Logs:"; cat server.log; exit 1
+            fi
+            sleep 2
+          done
+
+          # Final health check
+          curl -fsS http://localhost:8321/v1/models || { echo "Health check failed. Logs:"; cat server.log; exit 1; }

       - name: Install Hatch
         run: pip install --upgrade hatch
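The readiness loop above is plain bash, but the same check is handy when starting the starter distro locally. A minimal Python sketch, assuming only the defaults already used in the workflow (port 8321 and the `/v1/models` endpoint); the function name and timings are illustrative:

```python
import time
import urllib.error
import urllib.request

BASE_URL = "http://localhost:8321"  # default Llama Stack port used in the workflow above


def wait_for_llama_stack(timeout_s: float = 120.0, interval_s: float = 2.0) -> None:
    """Poll the /v1/models endpoint until the server answers or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{BASE_URL}/v1/models", timeout=5) as resp:
                if resp.status == 200:
                    print("Llama Stack Server started successfully.")
                    return
        except (urllib.error.URLError, OSError):
            pass  # server not up yet; retry after a short sleep
        time.sleep(interval_s)
    raise TimeoutError("Llama Stack Server did not become healthy in time")


if __name__ == "__main__":
    wait_for_llama_stack()
```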
api_base_url="http://localhost:8321/v1/openai/v1", ) messages = [ChatMessage.from_user("What's the weather in Tokyo?")] diff --git a/integrations/llama_stack/pyproject.toml b/integrations/llama_stack/pyproject.toml index 8f4e4b5cba..fa477bcc1b 100644 --- a/integrations/llama_stack/pyproject.toml +++ b/integrations/llama_stack/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.14", "llama-stack>=0.2.17"] +dependencies = ["haystack-ai>=2.19.0", "llama-stack>=0.2.17"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama-stack#readme" diff --git a/integrations/llama_stack/src/haystack_integrations/components/generators/llama_stack/chat/chat_generator.py b/integrations/llama_stack/src/haystack_integrations/components/generators/llama_stack/chat/chat_generator.py index 7120795f6d..10adb009e3 100644 --- a/integrations/llama_stack/src/haystack_integrations/components/generators/llama_stack/chat/chat_generator.py +++ b/integrations/llama_stack/src/haystack_integrations/components/generators/llama_stack/chat/chat_generator.py @@ -2,12 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import StreamingCallbackT -from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset +from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset from haystack.utils import deserialize_callable, serialize_callable from haystack.utils.auth import Secret @@ -41,7 +41,7 @@ class LlamaStackChatGenerator(OpenAIChatGenerator): messages = [ChatMessage.from_user("What's Natural Language Processing?")] - client = LlamaStackChatGenerator(model="llama3.2:3b") + client = LlamaStackChatGenerator(model="ollama/llama3.2:3b") response = client.run(messages) print(response) @@ -49,7 +49,7 @@ class LlamaStackChatGenerator(OpenAIChatGenerator): is a branch of artificial intelligence >>that focuses on enabling computers to understand, interpret, and generate human language in a way that is >>meaningful and useful.')], _role=, _name=None, - >>_meta={'model': 'llama3.2:3b', 'index': 0, 'finish_reason': 'stop', + >>_meta={'model': 'ollama/llama3.2:3b', 'index': 0, 'finish_reason': 'stop', >>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]} """ @@ -62,7 +62,7 @@ def __init__( streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, - tools: Optional[Union[List[Tool], Toolset]] = None, + tools: Optional[ToolsType] = None, tools_strict: bool = False, max_retries: Optional[int] = None, http_client_kwargs: Optional[Dict[str, Any]] = None, @@ -98,8 +98,8 @@ def __init__( Timeout for client calls using OpenAI API. If not set, it defaults to either the `OPENAI_TIMEOUT` environment variable, or 30 seconds. :param tools: - A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a - list of `Tool` objects or a `Toolset` instance. + A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. 
diff --git a/integrations/llama_stack/tests/test_llama_stack_chat_generator.py b/integrations/llama_stack/tests/test_llama_stack_chat_generator.py
index e01263a137..4723e9ebb2 100644
--- a/integrations/llama_stack/tests/test_llama_stack_chat_generator.py
+++ b/integrations/llama_stack/tests/test_llama_stack_chat_generator.py
@@ -5,7 +5,7 @@
 import pytz
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
-from haystack.tools import Tool
+from haystack.tools import Tool, Toolset
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 
@@ -25,6 +25,11 @@ def weather(city: str):
     return f"The weather in {city} is sunny and 32°C"
 
 
+def population(city: str):
+    """Get population for a given city."""
+    return f"The population of {city} is 2.2 million"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -38,6 +43,25 @@ def tools():
     return [tool]
 
 
+@pytest.fixture
+def mixed_tools():
+    """Fixture that returns a mixed list of Tool and Toolset."""
+    weather_tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=weather,
+    )
+    population_tool = Tool(
+        name="population",
+        description="useful to determine the population of a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=population,
+    )
+    toolset = Toolset([population_tool])
+    return [weather_tool, toolset]
+
+
 @pytest.fixture
 def mock_chat_completion():
     """
@@ -46,7 +70,7 @@ def mock_chat_completion():
     with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
         completion = ChatCompletion(
             id="foo",
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             object="chat.completion",
             choices=[
                 Choice(
@@ -66,27 +90,27 @@ class TestLlamaStackChatGenerator:
     def test_init_default(self):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
-        assert component.model == "llama3.2:3b"
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
+        assert component.model == "ollama/llama3.2:3b"
         assert component.api_base_url == "http://localhost:8321/v1/openai/v1"
         assert component.streaming_callback is None
         assert not component.generation_kwargs
 
     def test_init_with_parameters(self):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
 
     def test_to_dict_default(
         self,
     ):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         data = component.to_dict()
 
         assert (
@@ -95,7 +119,7 @@ def test_to_dict_default(
         )
 
         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "streaming_callback": None,
             "api_base_url": "http://localhost:8321/v1/openai/v1",
             "generation_kwargs": {},
@@ -113,7 +137,7 @@ def test_to_dict_with_parameters(
         self,
     ):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@@ -131,7 +155,7 @@ def test_to_dict_with_parameters(
         )
 
         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
             "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -153,7 +177,7 @@ def test_from_dict(
                 "haystack_integrations.components.generators.llama_stack.chat.chat_generator.LlamaStackChatGenerator"
             ),
             "init_parameters": {
-                "model": "llama3.2:3b",
+                "model": "ollama/llama3.2:3b",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -165,7 +189,7 @@ def test_from_dict(
             },
         }
         component = LlamaStackChatGenerator.from_dict(data)
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.api_base_url == "test-base-url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@@ -175,8 +199,33 @@ def test_from_dict(
         assert component.max_retries == 10
         assert not component.tools_strict
 
+    def test_init_with_mixed_tools(self):
+        def tool_fn(city: str) -> str:
+            return city
+
+        weather_tool = Tool(
+            name="weather",
+            description="Weather lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        population_tool = Tool(
+            name="population",
+            description="Population lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        toolset = Toolset([population_tool])
+
+        generator = LlamaStackChatGenerator(
+            model="ollama/llama3.2:3b",
+            tools=[weather_tool, toolset],
+        )
+
+        assert generator.tools == [weather_tool, toolset]
+
     def test_run(self, chat_messages, mock_chat_completion):  # noqa: ARG002
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         response = component.run(chat_messages)
 
         # check that the component returns the correct ChatMessage response
@@ -188,7 +237,7 @@ def test_run(self, chat_messages, mock_chat_completion):  # noqa: ARG002
 
     def test_run_with_params(self, chat_messages, mock_chat_completion):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             generation_kwargs={"max_tokens": 10, "temperature": 0.5},
         )
         response = component.run(chat_messages)
@@ -208,7 +257,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion):
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
@@ -228,7 +277,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
                 self.responses += chunk.content if chunk.content else ""
 
         callback = Callback()
-        component = LlamaStackChatGenerator(model="llama3.2:3b", streaming_callback=callback)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", streaming_callback=callback)
         results = component.run([ChatMessage.from_user("What's the capital of France?")])
 
         assert len(results["replies"]) == 1
@@ -244,7 +293,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
     @pytest.mark.integration
     def test_live_run_with_tools(self, tools):
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message = results["replies"][0]
@@ -263,7 +312,7 @@ def test_live_run_with_tools_and_response(self, tools):
         Integration test that the LlamaStackChatGenerator component can run with tools and get a response.
         """
         initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "auto"})
 
         assert len(results["replies"]) > 0, "No replies received"
+ """ + initial_messages = [ + ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?") + ] + component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=mixed_tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_call_message = None + for message in results["replies"]: + if message.tool_calls: + tool_call_message = message + break + + assert tool_call_message is not None, "No message with tool call found" + assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_calls = tool_call_message.tool_calls + assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}" + + # Verify we got calls to both weather and population tools + tool_names = {tc.tool_name for tc in tool_calls} + assert "weather" in tool_names, "Expected 'weather' tool call" + assert "population" in tool_names, "Expected 'population' tool call" + + # Verify tool call details + for tool_call in tool_calls: + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name in ["weather", "population"] + assert "city" in tool_call.arguments + assert tool_call.arguments["city"] in ["Paris", "Berlin"] + assert tool_call_message.meta["finish_reason"] == "tool_calls" + + # Mock the response we'd get from ToolInvoker + tool_result_messages = [] + for tool_call in tool_calls: + if tool_call.tool_name == "weather": + result = "The weather in Paris is sunny and 32°C" + else: # population + result = "The population of Berlin is 2.2 million" + tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call)) + + new_messages = [*initial_messages, tool_call_message, *tool_result_messages] + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + assert "berlin" in final_message.text.lower()