-
Notifications
You must be signed in to change notification settings - Fork 198
feat: NvidiaChatGenerator add integration tests for mixing Tool/Toolset
#2422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,7 @@ | |
| import pytz | ||
| from haystack.components.generators.utils import print_streaming_chunk | ||
| from haystack.dataclasses import ChatMessage, StreamingChunk | ||
| from haystack.tools import Tool | ||
| from haystack.tools import Tool, Toolset | ||
| from haystack.utils.auth import Secret | ||
| from openai import AsyncOpenAI, OpenAIError | ||
| from openai.types.chat import ChatCompletion, ChatCompletionMessage | ||
|
|
@@ -329,6 +329,114 @@ def test_live_run_with_json_object(self): | |
| assert isinstance(output["rating"], int) | ||
| assert "Inception" in output["title"] | ||
|
|
||
| @pytest.mark.skipif( | ||
| not os.environ.get("NVIDIA_API_KEY", None), | ||
| reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", | ||
| ) | ||
| @pytest.mark.integration | ||
| def test_integration_mixing_init_and_runtime_tools(self): | ||
| """Test mixing tools from init and runtime by passing tools to both __init__ and run().""" | ||
|
|
||
| def weather_function(city: str) -> str: | ||
| """Get weather information for a city.""" | ||
| return f"Weather in {city}: 22°C, sunny" | ||
|
|
||
| def time_function(city: str) -> str: | ||
| """Get current time in a city.""" | ||
| return f"Current time in {city}: 14:30" | ||
|
|
||
| # Create tools | ||
| weather_tool = Tool( | ||
| name="weather", | ||
| description="Get weather information for a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=weather_function, | ||
| ) | ||
|
|
||
| time_tool = Tool( | ||
| name="time", | ||
| description="Get current time in a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=time_function, | ||
| ) | ||
|
|
||
| # Initialize with weather_tool | ||
| component = NvidiaChatGenerator(tools=[weather_tool]) | ||
|
|
||
| # Pass time_tool at runtime - runtime tools should take precedence | ||
| messages = [ChatMessage.from_user("What's the time in Tokyo?")] | ||
| results = component.run(messages, tools=[time_tool]) | ||
|
|
||
| assert len(results["replies"]) == 1 | ||
| message = results["replies"][0] | ||
|
|
||
| # Should use time_tool since it was passed at runtime | ||
| assert message.tool_calls is not None | ||
| tool_call = message.tool_calls[0] | ||
| assert tool_call.tool_name == "time" | ||
| assert tool_call.arguments == {"city": "Tokyo"} | ||
|
|
||
| @pytest.mark.skipif( | ||
| not os.environ.get("NVIDIA_API_KEY", None), | ||
| reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", | ||
| ) | ||
| @pytest.mark.integration | ||
| def test_integration_mixing_tools_and_toolset(self): | ||
| """Test mixing Tool list and Toolset at runtime.""" | ||
|
|
||
| def weather_function(city: str) -> str: | ||
| """Get weather information for a city.""" | ||
| return f"Weather in {city}: 22°C, sunny" | ||
|
|
||
| def time_function(city: str) -> str: | ||
| """Get current time in a city.""" | ||
| return f"Current time in {city}: 14:30" | ||
|
|
||
| def echo_function(text: str) -> str: | ||
| """Echo a text.""" | ||
| return text | ||
|
|
||
| # Create tools | ||
| weather_tool = Tool( | ||
| name="weather", | ||
| description="Get weather information for a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=weather_function, | ||
| ) | ||
|
|
||
| time_tool = Tool( | ||
| name="time", | ||
| description="Get current time in a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=time_function, | ||
| ) | ||
|
|
||
| echo_tool = Tool( | ||
| name="echo", | ||
| description="Echo a text", | ||
| parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, | ||
| function=echo_function, | ||
| ) | ||
|
|
||
| # Create Toolset with weather and time tools | ||
| toolset = Toolset([weather_tool, time_tool]) | ||
|
|
||
| # Initialize with toolset | ||
| component = NvidiaChatGenerator(tools=toolset) | ||
|
|
||
| # Pass echo_tool as a list at runtime - runtime tools should take precedence | ||
| messages = [ChatMessage.from_user("Echo this: Hello World")] | ||
| results = component.run(messages, tools=[echo_tool]) | ||
|
||
|
|
||
| assert len(results["replies"]) == 1 | ||
| message = results["replies"][0] | ||
|
|
||
| # Should use echo_tool since it was passed at runtime | ||
| assert message.tool_calls is not None | ||
| tool_call = message.tool_calls[0] | ||
| assert tool_call.tool_name == "echo" | ||
| assert tool_call.arguments == {"text": "Hello World"} | ||
|
|
||
|
|
||
| class TestNvidiaChatGeneratorAsync: | ||
| def test_init_default_async(self, monkeypatch): | ||
|
|
@@ -439,3 +547,113 @@ async def callback(chunk: StreamingChunk): | |
|
|
||
| assert counter > 1 | ||
| assert "Paris" in responses | ||
|
|
||
| @pytest.mark.skipif( | ||
| not os.environ.get("NVIDIA_API_KEY", None), | ||
| reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", | ||
| ) | ||
| @pytest.mark.integration | ||
| @pytest.mark.asyncio | ||
| async def test_integration_mixing_init_and_runtime_tools_async(self): | ||
| """Test mixing tools from init and runtime in async mode.""" | ||
|
|
||
| def weather_function(city: str) -> str: | ||
| """Get weather information for a city.""" | ||
| return f"Weather in {city}: 22°C, sunny" | ||
|
|
||
| def time_function(city: str) -> str: | ||
| """Get current time in a city.""" | ||
| return f"Current time in {city}: 14:30" | ||
|
|
||
| # Create tools | ||
| weather_tool = Tool( | ||
| name="weather", | ||
| description="Get weather information for a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=weather_function, | ||
| ) | ||
|
|
||
| time_tool = Tool( | ||
| name="time", | ||
| description="Get current time in a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=time_function, | ||
| ) | ||
|
|
||
| # Initialize with weather_tool | ||
| component = NvidiaChatGenerator(tools=[weather_tool]) | ||
|
|
||
| # Pass time_tool at runtime - runtime tools should take precedence | ||
| messages = [ChatMessage.from_user("What's the time in Tokyo?")] | ||
| results = await component.run_async(messages, tools=[time_tool]) | ||
|
|
||
| assert len(results["replies"]) == 1 | ||
| message = results["replies"][0] | ||
|
|
||
| # Should use time_tool since it was passed at runtime | ||
| assert message.tool_calls is not None | ||
| tool_call = message.tool_calls[0] | ||
| assert tool_call.tool_name == "time" | ||
| assert tool_call.arguments == {"city": "Tokyo"} | ||
|
|
||
| @pytest.mark.skipif( | ||
| not os.environ.get("NVIDIA_API_KEY", None), | ||
| reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.", | ||
| ) | ||
| @pytest.mark.integration | ||
| @pytest.mark.asyncio | ||
| async def test_integration_mixing_tools_and_toolset_async(self): | ||
| """Test mixing Tool list and Toolset at runtime in async mode.""" | ||
|
|
||
| def weather_function(city: str) -> str: | ||
| """Get weather information for a city.""" | ||
| return f"Weather in {city}: 22°C, sunny" | ||
|
|
||
| def time_function(city: str) -> str: | ||
| """Get current time in a city.""" | ||
| return f"Current time in {city}: 14:30" | ||
|
|
||
| def echo_function(text: str) -> str: | ||
| """Echo a text.""" | ||
| return text | ||
|
|
||
| # Create tools | ||
| weather_tool = Tool( | ||
| name="weather", | ||
| description="Get weather information for a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=weather_function, | ||
| ) | ||
|
|
||
| time_tool = Tool( | ||
| name="time", | ||
| description="Get current time in a city", | ||
| parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, | ||
| function=time_function, | ||
| ) | ||
|
|
||
| echo_tool = Tool( | ||
| name="echo", | ||
| description="Echo a text", | ||
| parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, | ||
| function=echo_function, | ||
| ) | ||
|
|
||
| # Create Toolset with weather and time tools | ||
| toolset = Toolset([weather_tool, time_tool]) | ||
|
|
||
| # Initialize with toolset | ||
| component = NvidiaChatGenerator(tools=toolset) | ||
|
|
||
| # Pass echo_tool as a list at runtime - runtime tools should take precedence | ||
| messages = [ChatMessage.from_user("Echo this: Hello World")] | ||
| results = await component.run_async(messages, tools=[echo_tool]) | ||
|
|
||
| assert len(results["replies"]) == 1 | ||
| message = results["replies"][0] | ||
|
|
||
| # Should use echo_tool since it was passed at runtime | ||
| assert message.tool_calls is not None | ||
| tool_call = message.tool_calls[0] | ||
| assert tool_call.tool_name == "echo" | ||
| assert tool_call.arguments == {"text": "Hello World"} | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if this test was intended. We can rather have a test for ser/deser of mixed tool lists like in GoogleGenAiChatGenerator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vblagoje lets remove this test. I still dont see the need of keeping it. The rest looks good.