Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 219 additions & 1 deletion integrations/nvidia/tests/test_nvidia_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytz
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.tools import Tool
from haystack.tools import Tool, Toolset
from haystack.utils.auth import Secret
from openai import AsyncOpenAI, OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
Expand Down Expand Up @@ -329,6 +329,114 @@ def test_live_run_with_json_object(self):
assert isinstance(output["rating"], int)
assert "Inception" in output["title"]

@pytest.mark.skipif(
not os.environ.get("NVIDIA_API_KEY", None),
reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
)
@pytest.mark.integration
def test_integration_mixing_init_and_runtime_tools(self):
"""Test mixing tools from init and runtime by passing tools to both __init__ and run()."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this test was intended. We can rather have a test for ser/deser of mixed tool lists like in GoogleGenAiChatGenerator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vblagoje lets remove this test. I still dont see the need of keeping it. The rest looks good.


def weather_function(city: str) -> str:
"""Get weather information for a city."""
return f"Weather in {city}: 22°C, sunny"

def time_function(city: str) -> str:
"""Get current time in a city."""
return f"Current time in {city}: 14:30"

# Create tools
weather_tool = Tool(
name="weather",
description="Get weather information for a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)

time_tool = Tool(
name="time",
description="Get current time in a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=time_function,
)

# Initialize with weather_tool
component = NvidiaChatGenerator(tools=[weather_tool])

# Pass time_tool at runtime - runtime tools should take precedence
messages = [ChatMessage.from_user("What's the time in Tokyo?")]
results = component.run(messages, tools=[time_tool])

assert len(results["replies"]) == 1
message = results["replies"][0]

# Should use time_tool since it was passed at runtime
assert message.tool_calls is not None
tool_call = message.tool_calls[0]
assert tool_call.tool_name == "time"
assert tool_call.arguments == {"city": "Tokyo"}

@pytest.mark.skipif(
not os.environ.get("NVIDIA_API_KEY", None),
reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
)
@pytest.mark.integration
def test_integration_mixing_tools_and_toolset(self):
"""Test mixing Tool list and Toolset at runtime."""

def weather_function(city: str) -> str:
"""Get weather information for a city."""
return f"Weather in {city}: 22°C, sunny"

def time_function(city: str) -> str:
"""Get current time in a city."""
return f"Current time in {city}: 14:30"

def echo_function(text: str) -> str:
"""Echo a text."""
return text

# Create tools
weather_tool = Tool(
name="weather",
description="Get weather information for a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)

time_tool = Tool(
name="time",
description="Get current time in a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=time_function,
)

echo_tool = Tool(
name="echo",
description="Echo a text",
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
function=echo_function,
)

# Create Toolset with weather and time tools
toolset = Toolset([weather_tool, time_tool])

# Initialize with toolset
component = NvidiaChatGenerator(tools=toolset)

# Pass echo_tool as a list at runtime - runtime tools should take precedence
messages = [ChatMessage.from_user("Echo this: Hello World")]
results = component.run(messages, tools=[echo_tool])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about mixing tools and toolsets at runtime.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but this one is different: there is no mention of tools and we need to wrap single tool into a list


assert len(results["replies"]) == 1
message = results["replies"][0]

# Should use echo_tool since it was passed at runtime
assert message.tool_calls is not None
tool_call = message.tool_calls[0]
assert tool_call.tool_name == "echo"
assert tool_call.arguments == {"text": "Hello World"}


class TestNvidiaChatGeneratorAsync:
def test_init_default_async(self, monkeypatch):
Expand Down Expand Up @@ -439,3 +547,113 @@ async def callback(chunk: StreamingChunk):

assert counter > 1
assert "Paris" in responses

@pytest.mark.skipif(
not os.environ.get("NVIDIA_API_KEY", None),
reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_integration_mixing_init_and_runtime_tools_async(self):
"""Test mixing tools from init and runtime in async mode."""

def weather_function(city: str) -> str:
"""Get weather information for a city."""
return f"Weather in {city}: 22°C, sunny"

def time_function(city: str) -> str:
"""Get current time in a city."""
return f"Current time in {city}: 14:30"

# Create tools
weather_tool = Tool(
name="weather",
description="Get weather information for a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)

time_tool = Tool(
name="time",
description="Get current time in a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=time_function,
)

# Initialize with weather_tool
component = NvidiaChatGenerator(tools=[weather_tool])

# Pass time_tool at runtime - runtime tools should take precedence
messages = [ChatMessage.from_user("What's the time in Tokyo?")]
results = await component.run_async(messages, tools=[time_tool])

assert len(results["replies"]) == 1
message = results["replies"][0]

# Should use time_tool since it was passed at runtime
assert message.tool_calls is not None
tool_call = message.tool_calls[0]
assert tool_call.tool_name == "time"
assert tool_call.arguments == {"city": "Tokyo"}

@pytest.mark.skipif(
not os.environ.get("NVIDIA_API_KEY", None),
reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_integration_mixing_tools_and_toolset_async(self):
"""Test mixing Tool list and Toolset at runtime in async mode."""

def weather_function(city: str) -> str:
"""Get weather information for a city."""
return f"Weather in {city}: 22°C, sunny"

def time_function(city: str) -> str:
"""Get current time in a city."""
return f"Current time in {city}: 14:30"

def echo_function(text: str) -> str:
"""Echo a text."""
return text

# Create tools
weather_tool = Tool(
name="weather",
description="Get weather information for a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=weather_function,
)

time_tool = Tool(
name="time",
description="Get current time in a city",
parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
function=time_function,
)

echo_tool = Tool(
name="echo",
description="Echo a text",
parameters={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]},
function=echo_function,
)

# Create Toolset with weather and time tools
toolset = Toolset([weather_tool, time_tool])

# Initialize with toolset
component = NvidiaChatGenerator(tools=toolset)

# Pass echo_tool as a list at runtime - runtime tools should take precedence
messages = [ChatMessage.from_user("Echo this: Hello World")]
results = await component.run_async(messages, tools=[echo_tool])

assert len(results["replies"]) == 1
message = results["replies"][0]

# Should use echo_tool since it was passed at runtime
assert message.tool_calls is not None
tool_call = message.tool_calls[0]
assert tool_call.tool_name == "echo"
assert tool_call.arguments == {"text": "Hello World"}
Loading