feat: LlamaStackChatGenerator update tools param to ToolsType
#2436
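To make the change concrete before the diff: with this PR, the `tools` parameter accepts Haystack's `ToolsType`, i.e. a list that can mix standalone `Tool` objects and `Toolset` instances, instead of only a list of `Tool`. Below is a minimal usage sketch (not code from the diff); the generator's import path is an assumption based on this integration's module layout, and the tool definitions mirror the test fixtures further down.

```python
from haystack.tools import Tool, Toolset

# Assumed public import path for this integration:
from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator


def weather(city: str) -> str:
    """Toy tool function, mirroring the test fixture."""
    return f"The weather in {city} is sunny and 32°C"


def population(city: str) -> str:
    """Toy tool function, mirroring the test fixture."""
    return f"The population of {city} is 2.2 million"


city_params = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
weather_tool = Tool(name="weather", description="Weather lookup", parameters=city_params, function=weather)
population_tool = Tool(name="population", description="Population lookup", parameters=city_params, function=population)

# With this PR, a plain Tool and a Toolset can be passed together:
generator = LlamaStackChatGenerator(
    model="ollama/llama3.2:3b",
    tools=[weather_tool, Toolset([population_tool])],
)
```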
Changes from all commits
```diff
@@ -5,7 +5,7 @@
 import pytz
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
-from haystack.tools import Tool
+from haystack.tools import Tool, Toolset
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
```
```diff
@@ -25,6 +25,11 @@ def weather(city: str):
     return f"The weather in {city} is sunny and 32°C"


+def population(city: str):
+    """Get population for a given city."""
+    return f"The population of {city} is 2.2 million"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
```
```diff
@@ -38,6 +43,25 @@ def tools():
     return [tool]


+@pytest.fixture
+def mixed_tools():
+    """Fixture that returns a mixed list of Tool and Toolset."""
+    weather_tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=weather,
+    )
+    population_tool = Tool(
+        name="population",
+        description="useful to determine the population of a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=population,
+    )
+    toolset = Toolset([population_tool])
+    return [weather_tool, toolset]
+
+
 @pytest.fixture
 def mock_chat_completion():
     """
```
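As a side note on why a mixed list like this can work downstream: a `Toolset` is iterable over its `Tool`s, so a consumer can flatten a `[Tool, Toolset]` list into plain tools before serializing them for the chat API. A rough sketch of that flattening (an assumption about the mechanics, not code from this PR):

```python
from haystack.tools import Tool, Toolset


def flatten_tools(tools: list) -> list[Tool]:
    """Flatten a mixed list of Tool and Toolset into a flat list of Tools."""
    flat: list[Tool] = []
    for item in tools:
        if isinstance(item, Toolset):
            flat.extend(item)  # a Toolset yields its Tools when iterated
        else:
            flat.append(item)
    return flat
```

With the `mixed_tools` fixture above, this would yield the `weather` tool followed by the `population` tool.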
```diff
@@ -46,7 +70,7 @@ def mock_chat_completion():
     with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
         completion = ChatCompletion(
             id="foo",
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             object="chat.completion",
             choices=[
                 Choice(
```
```diff
@@ -66,27 +90,27 @@
 class TestLlamaStackChatGenerator:
     def test_init_default(self):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
-        assert component.model == "llama3.2:3b"
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
+        assert component.model == "ollama/llama3.2:3b"
         assert component.api_base_url == "http://localhost:8321/v1/openai/v1"
         assert component.streaming_callback is None
         assert not component.generation_kwargs

     def test_init_with_parameters(self):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

     def test_to_dict_default(
         self,
     ):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         data = component.to_dict()

         assert (
```

Review discussion on this change:

Reviewer: Why are we changing the way we pass the model name? Is it necessary?

Author: It seems to be necessary: it was failing for "llama3.2:3b". I looked at the CI run, and the model was booted and ready. Then I tried this naming schema and it worked.
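For reference, the naming change discussed above boils down to the following sketch; the `ollama/` prefix appears to be the inference-provider prefix under which the model is registered with the Llama Stack server (an assumption based on the author's comments, not stated in the diff).

```python
# Assumed public import path for this integration:
from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator

# Old, bare model id, which CI rejected with "no such model":
# generator = LlamaStackChatGenerator(model="llama3.2:3b")

# New, provider-prefixed model id used throughout the updated tests:
generator = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
```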
```diff
@@ -95,7 +119,7 @@ def test_to_dict_default(
         )

         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "streaming_callback": None,
             "api_base_url": "http://localhost:8321/v1/openai/v1",
             "generation_kwargs": {},
```
```diff
@@ -113,7 +137,7 @@ def test_to_dict_with_parameters(
         self,
     ):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
```
```diff
@@ -131,7 +155,7 @@ def test_to_dict_with_parameters(
         )

         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
             "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
```
```diff
@@ -153,7 +177,7 @@ def test_from_dict(
                 "haystack_integrations.components.generators.llama_stack.chat.chat_generator.LlamaStackChatGenerator"
             ),
             "init_parameters": {
-                "model": "llama3.2:3b",
+                "model": "ollama/llama3.2:3b",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
```
```diff
@@ -165,7 +189,7 @@ def test_from_dict(
             },
         }
         component = LlamaStackChatGenerator.from_dict(data)
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.api_base_url == "test-base-url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
```
```diff
@@ -175,8 +199,33 @@ def test_from_dict(
         assert component.max_retries == 10
         assert not component.tools_strict

+    def test_init_with_mixed_tools(self):
+        def tool_fn(city: str) -> str:
+            return city
+
+        weather_tool = Tool(
+            name="weather",
+            description="Weather lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        population_tool = Tool(
+            name="population",
+            description="Population lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        toolset = Toolset([population_tool])
+
+        generator = LlamaStackChatGenerator(
+            model="ollama/llama3.2:3b",
+            tools=[weather_tool, toolset],
+        )
+
+        assert generator.tools == [weather_tool, toolset]
+
     def test_run(self, chat_messages, mock_chat_completion):  # noqa: ARG002
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         response = component.run(chat_messages)

         # check that the component returns the correct ChatMessage response
```
```diff
@@ -188,7 +237,7 @@ def test_run(self, chat_messages, mock_chat_completion):  # noqa: ARG002

     def test_run_with_params(self, chat_messages, mock_chat_completion):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             generation_kwargs={"max_tokens": 10, "temperature": 0.5},
         )
         response = component.run(chat_messages)
```
```diff
@@ -208,7 +257,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion):
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
```
```diff
@@ -228,7 +277,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
             self.responses += chunk.content if chunk.content else ""

         callback = Callback()
-        component = LlamaStackChatGenerator(model="llama3.2:3b", streaming_callback=callback)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", streaming_callback=callback)
         results = component.run([ChatMessage.from_user("What's the capital of France?")])

         assert len(results["replies"]) == 1
```
```diff
@@ -244,7 +293,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
     @pytest.mark.integration
     def test_live_run_with_tools(self, tools):
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message = results["replies"][0]
```
```diff
@@ -263,7 +312,7 @@ def test_live_run_with_tools_and_response(self, tools):
         Integration test that the LlamaStackChatGenerator component can run with tools and get a response.
         """
         initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "auto"})

         assert len(results["replies"]) > 0, "No replies received"
```
```diff
@@ -298,3 +347,63 @@ def test_live_run_with_tools_and_response(self, tools):
         assert not final_message.tool_call
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
+
+    @pytest.mark.integration
+    def test_live_run_with_mixed_tools(self, mixed_tools):
+        """
+        Integration test that verifies LlamaStackChatGenerator works with mixed Tool and Toolset.
+        This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset.
+        """
+        initial_messages = [
+            ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
+        ]
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=mixed_tools)
+        results = component.run(messages=initial_messages)
+
+        assert len(results["replies"]) > 0, "No replies received"
+
+        # Find the message with tool calls
+        tool_call_message = None
+        for message in results["replies"]:
+            if message.tool_calls:
+                tool_call_message = message
+                break
+
+        assert tool_call_message is not None, "No message with tool call found"
+        assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_calls = tool_call_message.tool_calls
+        assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
+
+        # Verify we got calls to both weather and population tools
+        tool_names = {tc.tool_name for tc in tool_calls}
+        assert "weather" in tool_names, "Expected 'weather' tool call"
+        assert "population" in tool_names, "Expected 'population' tool call"
+
+        # Verify tool call details
+        for tool_call in tool_calls:
+            assert tool_call.id, "Tool call does not contain value for 'id' key"
+            assert tool_call.tool_name in ["weather", "population"]
+            assert "city" in tool_call.arguments
+            assert tool_call.arguments["city"] in ["Paris", "Berlin"]
+        assert tool_call_message.meta["finish_reason"] == "tool_calls"
+
+        # Mock the response we'd get from ToolInvoker
+        tool_result_messages = []
+        for tool_call in tool_calls:
+            if tool_call.tool_name == "weather":
+                result = "The weather in Paris is sunny and 32°C"
+            else:  # population
+                result = "The population of Berlin is 2.2 million"
+            tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call))
+
+        new_messages = [*initial_messages, tool_call_message, *tool_result_messages]
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_call
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+        assert "berlin" in final_message.text.lower()
```
@Amnah199 please take a look
So what was happening: look at the date stamp, two days ago. It seems Llama Stack no longer has a `build` command, so I switched to the `run` command, which made sense; using `run` made the difference and the server started properly.

Then I faced the issue of model names: tests would fail saying there is no such model. I then switched to the full naming schema and everything started to work. cc @anakin87 @Amnah199
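When debugging this kind of "no such model" failure, one way to confirm the exact ids the server exposes is to query the same OpenAI-compatible endpoint the generator defaults to (a debugging sketch, not part of this PR):

```python
from openai import OpenAI

# The base URL matches the generator's default api_base_url; the API key is unused locally.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="unused")
for model in client.models.list():
    print(model.id)  # expect provider-prefixed ids such as "ollama/llama3.2:3b"
```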