Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 190 additions & 1 deletion integrations/nvidia/tests/test_nvidia_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytz
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.tools import Tool
from haystack.tools import Tool, Toolset
from haystack.utils.auth import Secret
from openai import AsyncOpenAI, OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
Expand All @@ -34,6 +34,11 @@ def weather(city: str):
return f"The weather in {city} is sunny and 32°C"


def echo_function(text: str) -> str:
"""Echo a text."""
return text


@pytest.fixture
def tools():
tool_parameters = {
Expand Down Expand Up @@ -329,6 +334,126 @@ def test_live_run_with_json_object(self):
assert isinstance(output["rating"], int)
assert "Inception" in output["title"]

@pytest.mark.skipif(
not os.environ.get("NVIDIA_API_KEY", None),
reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
)
@pytest.mark.integration
def test_integration_mixing_tools_and_toolset(self):
"""Test mixing Tool list and Toolset at runtime."""

def weather_function(city: str) -> str:
"""Get weather information for a city."""
return f"Weather in {city}: 22°C, sunny"

def time_function(city: str) -> str:
"""Get current time in a city."""
return f"Current time in {city}: 14:30"

def echo_function(text: str) -> str:
"""Echo a text."""
return text

# Create tools
weather_tool = Tool(
name="weather",
description="Get weather information for a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)

time_tool = Tool(
name="time",
description="Get current time in a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=time_function,
)

echo_tool = Tool(
name="echo",
description="Echo a text",
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
function=echo_function,
)

# Create Toolset with weather and time tools
toolset = Toolset([weather_tool, time_tool])

# Initialize without tools
component = NvidiaChatGenerator()

# Mix tools and toolset at runtime
messages = [ChatMessage.from_user("What's the weather in Tokyo and echo 'test'")]
results = component.run(messages, tools=[echo_tool, toolset])

assert len(results["replies"]) == 1
message = results["replies"][0]

# Should have access to both echo tool and tools from toolset
assert message.tool_calls is not None
assert len(message.tool_calls) >= 1

# Check that we can use tools from both the list and toolset
tool_names = [call.tool_name for call in message.tool_calls]
assert "echo" in tool_names or "weather" in tool_names

def test_to_dict_with_mixed_tools_and_toolset(self, tools, monkeypatch):
"""Test serialization with a mixed list containing both Tool and Toolset objects."""
monkeypatch.setenv("NVIDIA_API_KEY", "test-api-key")

# Create additional tools for the toolset using module-level function
echo_tool = Tool(
name="echo",
description="Echo a text",
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
function=echo_function,
)

# Create a mixed list: some individual tools + a toolset
toolset = Toolset([echo_tool])
mixed_tools = [*tools, toolset] # List containing both Tool objects and a Toolset

component = NvidiaChatGenerator(model="meta/llama-3.1-8b-instruct", tools=mixed_tools)
data = component.to_dict()

assert data["init_parameters"]["tools"] is not None
assert isinstance(data["init_parameters"]["tools"], list)
assert len(data["init_parameters"]["tools"]) == len(mixed_tools)

# Check that we have both Tool and Toolset in the serialized data
tool_types = [tool["type"] for tool in data["init_parameters"]["tools"]]
assert "haystack.tools.tool.Tool" in tool_types
assert "haystack.tools.toolset.Toolset" in tool_types

def test_from_dict_with_mixed_tools_and_toolset(self, tools, monkeypatch):
"""Test deserialization with a mixed list containing both Tool and Toolset objects."""
monkeypatch.setenv("NVIDIA_API_KEY", "test-api-key")

# Create additional tools for the toolset using module-level function
echo_tool = Tool(
name="echo",
description="Echo a text",
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
function=echo_function,
)

# Create a mixed list: some individual tools + a toolset
toolset = Toolset([echo_tool])
mixed_tools = [*tools, toolset] # List containing both Tool objects and a Toolset

component = NvidiaChatGenerator(model="meta/llama-3.1-8b-instruct", tools=mixed_tools)
data = component.to_dict()

deserialized_component = NvidiaChatGenerator.from_dict(data)

assert isinstance(deserialized_component.tools, list)
assert len(deserialized_component.tools) == len(mixed_tools)

# Check that we have both Tool and Toolset objects in the deserialized list
tool_types = [type(tool).__name__ for tool in deserialized_component.tools]
assert "Tool" in tool_types
assert "Toolset" in tool_types


class TestNvidiaChatGeneratorAsync:
def test_init_default_async(self, monkeypatch):
Expand Down Expand Up @@ -439,3 +564,67 @@ async def callback(chunk: StreamingChunk):

assert counter > 1
assert "Paris" in responses

@pytest.mark.skipif(
not os.environ.get("NVIDIA_API_KEY", None),
reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_integration_mixing_tools_and_toolset_async(self):
"""Test mixing Tool list and Toolset at runtime in async mode."""

def weather_function(city: str) -> str:
"""Get weather information for a city."""
return f"Weather in {city}: 22°C, sunny"

def time_function(city: str) -> str:
"""Get current time in a city."""
return f"Current time in {city}: 14:30"

def echo_function(text: str) -> str:
"""Echo a text."""
return text

# Create tools
weather_tool = Tool(
name="weather",
description="Get weather information for a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)

time_tool = Tool(
name="time",
description="Get current time in a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=time_function,
)

echo_tool = Tool(
name="echo",
description="Echo a text",
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
function=echo_function,
)

# Create Toolset with weather and time tools
toolset = Toolset([weather_tool, time_tool])

# Initialize without tools
component = NvidiaChatGenerator()

# Mix tools and toolset at runtime
messages = [ChatMessage.from_user("What's the weather in Tokyo and echo 'test'")]
results = await component.run_async(messages, tools=[echo_tool, toolset])

assert len(results["replies"]) == 1
message = results["replies"][0]

# Should have access to both echo tool and tools from toolset
assert message.tool_calls is not None
assert len(message.tool_calls) >= 1

# Check that we can use tools from both the list and toolset
tool_names = [call.tool_name for call in message.tool_calls]
assert "echo" in tool_names or "weather" in tool_names
Loading