2 changes: 1 addition & 1 deletion integrations/ollama/pyproject.toml
@@ -27,7 +27,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.17.1", "ollama>=0.5.0", "pydantic"]
dependencies = ["haystack-ai>=2.19.0", "ollama>=0.5.0", "pydantic"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -12,12 +12,12 @@
)
from haystack.dataclasses.streaming_chunk import ComponentInfo, FinishReason, StreamingChunk, ToolCallDelta
from haystack.tools import (
Tool,
ToolsType,
_check_duplicate_tool_names,
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
)
from haystack.tools.toolset import Toolset
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from pydantic.json_schema import JsonSchemaValue

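For context: the `haystack-ai>=2.19.0` bump in `pyproject.toml` above is presumably what makes `ToolsType` and `flatten_tools_or_toolsets` available to import here. A minimal sketch of the shape `ToolsType` accepts; the exact alias definition below is an assumption, not copied from Haystack:

```python
from typing import List, Union

from haystack.tools import Tool, Toolset

# Assumed shape of haystack.tools.ToolsType: either a single Toolset, or a
# flat list whose entries may mix Tool and Toolset objects.
ToolsType = Union[List[Union[Tool, Toolset]], Toolset]
```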
@@ -215,7 +215,7 @@ def __init__(
timeout: int = 120,
keep_alive: Optional[Union[float, str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None,
think: Union[bool, Literal["low", "medium", "high"]] = False,
):
@@ -248,25 +248,26 @@ def __init__(
A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param tools:
A list of `haystack.tools.Tool` or a `haystack.tools.Toolset`. Duplicate tool names raise a `ValueError`.
Not all models support tools. For a list of models compatible with tools, see the
[models page](https://ollama.com/search?c=tools).
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name. Not all models support tools. For a list of models compatible
with tools, see the [models page](https://ollama.com/search?c=tools).
:param response_format:
The format for structured model outputs. The value can be:
- None: No specific structure or format is applied to the response. The response is returned as-is.
- "json": The response is formatted as a JSON object.
- JSON Schema: The response is formatted as a JSON object
that adheres to the specified JSON Schema. (needs Ollama ≥ 0.1.34)
"""
_check_duplicate_tool_names(list(tools or []))
flattened_tools = flatten_tools_or_toolsets(tools)
_check_duplicate_tool_names(flattened_tools)

self.model = model
self.url = url
self.generation_kwargs = generation_kwargs or {}
self.timeout = timeout
self.keep_alive = keep_alive
self.streaming_callback = streaming_callback
self.tools = tools
self.tools = tools # Store original tools for serialization
self.think = think
self.response_format = response_format

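With the widened `tools` signature, a mixed list of `Tool` and `Toolset` objects validates at construction time. A minimal sketch mirroring the tests in this PR; the tool bodies are illustrative stand-ins, and a local Ollama model named `qwen3` is assumed:

```python
from typing import Annotated

from haystack.tools import Toolset, tool
from haystack_integrations.components.generators.ollama import OllamaChatGenerator


@tool
def weather(city: Annotated[str, "The city to get the weather for"]) -> str:
    """Get the weather for a given city."""
    return f"The weather in {city} is sunny"


@tool
def population(city: Annotated[str, "The city to get the population for"]) -> str:
    """Get the population of a given city."""
    return f"The population of {city} is 1 million"


# A Tool and a Toolset in one list. Duplicate tool names across all entries
# raise a ValueError at init time, since the list is flattened before the
# duplicate-name check runs.
generator = OllamaChatGenerator(model="qwen3", tools=[weather, Toolset([population])])
```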
@@ -470,7 +471,7 @@ def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
*,
streaming_callback: Optional[StreamingCallbackT] = None,
) -> Dict[str, List[ChatMessage]]:
@@ -485,9 +486,8 @@ def run(
Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, etc. See the
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param tools:
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
list of `Tool` objects or a `Toolset` instance. If set, it will override the `tools` parameter set
during component initialization.
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
If set, it will override the `tools` parameter set during component initialization.
:param streaming_callback:
A callable to receive `StreamingChunk` objects as they
arrive. Supplying a callback (here or in the constructor) switches
@@ -501,13 +501,15 @@
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
)
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
tools = tools or self.tools
_check_duplicate_tool_names(list(tools or []))

# Convert Toolset → list[Tool] for JSON serialization
if isinstance(tools, Toolset):
tools = list(tools)
ollama_tools = [{"type": "function", "function": {**tool.tool_spec}} for tool in tools] if tools else None
tools_to_use = tools or self.tools
flattened_tools = flatten_tools_or_toolsets(tools_to_use)
_check_duplicate_tool_names(flattened_tools)

ollama_tools = (
[{"type": "function", "function": {**tool.tool_spec}} for tool in flattened_tools]
if flattened_tools
else None
)

is_stream = callback is not None

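The flatten step is what lets the rest of `run()` stay unchanged: whatever shape `tools` arrives in, downstream code only ever sees a flat list of `Tool` objects before conversion to Ollama's function-calling format. A rough illustration, reusing `weather` and `population` from the constructor sketch above; the commented payload shape is an assumption about what `tool_spec` contains:

```python
from haystack.tools import Toolset, flatten_tools_or_toolsets

flattened = flatten_tools_or_toolsets([weather, Toolset([population])])
# flattened is now a flat list of two Tool objects: [weather, population]

ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in flattened]
# Presumed entry shape, following the list comprehension in the diff above:
# {"type": "function",
#  "function": {"name": "weather", "description": "...", "parameters": {...}}}
```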
@@ -535,7 +537,7 @@ async def run_async(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
*,
streaming_callback: Optional[StreamingCallbackT] = None,
) -> Dict[str, List[ChatMessage]]:
@@ -548,7 +550,7 @@ async def run_async(
Per-call overrides for Ollama inference options.
These are merged on top of the instance-level `generation_kwargs`.
:param tools:
A list of tools or a Toolset for which the model can prepare calls.
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
If set, it will override the `tools` parameter set during component initialization.
:param streaming_callback:
A callable to receive `StreamingChunk` objects as they arrive.
@@ -560,13 +562,15 @@
callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
tools = tools or self.tools
_check_duplicate_tool_names(list(tools or []))

# Convert Toolset → list[Tool] for JSON serialization
if isinstance(tools, Toolset):
tools = list(tools)
ollama_tools = [{"type": "function", "function": {**tool.tool_spec}} for tool in tools] if tools else None
tools_to_use = tools or self.tools
flattened_tools = flatten_tools_or_toolsets(tools_to_use)
_check_duplicate_tool_names(flattened_tools)

ollama_tools = (
[{"type": "function", "function": {**tool.tool_spec}} for tool in flattened_tools]
if flattened_tools
else None
)

is_stream = callback is not None

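`run_async` mirrors `run` exactly, including the runtime override of constructor-level tools. A brief usage sketch reusing `generator` and `weather` from the earlier sketch; it assumes an Ollama server is reachable at the default URL:

```python
import asyncio

from haystack.dataclasses import ChatMessage


async def main() -> None:
    # Tools passed at call time override the tools given to the constructor.
    result = await generator.run_async(
        [ChatMessage.from_user("What is the weather in Paris?")],
        tools=[weather],
    )
    print(result["replies"][0])


asyncio.run(main())
```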
82 changes: 82 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
@@ -672,6 +672,58 @@ def test_from_dict(self):
"properties": {"name": {"type": "string"}, "age": {"type": "number"}},
}

def test_init_with_mixed_tools(self, tools):
"""Test that the OllamaChatGenerator can be initialized with mixed Tool and Toolset objects."""

@tool
def population(city: Annotated[str, "The city to get the population for"]) -> str:
"""Get the population of a given city."""
return f"The population of {city} is 1 million"

population_toolset = Toolset([population])

# Mix individual Tool and Toolset
mixed_tools = [tools[0], population_toolset]
generator = OllamaChatGenerator(model="qwen3", tools=mixed_tools)

# The tools should be stored as the original ToolsType
assert isinstance(generator.tools, list)
assert len(generator.tools) == 2
# Check that we have a Tool and a Toolset
assert isinstance(generator.tools[0], Tool)
assert isinstance(generator.tools[1], Toolset)
assert generator.tools[0].name == "weather"
# Check that the Toolset contains the population tool
assert len(generator.tools[1]) == 1
assert generator.tools[1][0].name == "population"

def test_run_with_mixed_tools(self, tools):
"""Test that the OllamaChatGenerator can run with mixed Tool and Toolset objects."""

@tool
def population(city: Annotated[str, "The city to get the population for"]) -> str:
"""Get the population of a given city."""
return f"The population of {city} is 1 million"

population_toolset = Toolset([population])

# Mix individual Tool and Toolset
mixed_tools = [tools[0], population_toolset]
generator = OllamaChatGenerator(model="qwen3", tools=mixed_tools)

# Test that the tools are stored as the original ToolsType
tools_list = generator.tools
assert len(tools_list) == 2
# Check that we have a Tool and a Toolset
assert isinstance(tools_list[0], Tool)
assert isinstance(tools_list[1], Toolset)

# Verify tool names
assert tools_list[0].name == "weather"
# Check that the Toolset contains the population tool
assert len(tools_list[1]) == 1
assert tools_list[1][0].name == "population"

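One detail worth noting: `self.tools` keeps the original, unflattened value (see the `# Store original tools for serialization` comment in the diff), so `to_dict()` serializes the user's structure rather than a flattened list. A hedged round-trip sketch, assuming the imported `serialize_tools_or_toolset` and `deserialize_tools_or_toolset_inplace` helpers handle mixed lists in haystack-ai >= 2.19.0:

```python
# generator as constructed in the tests above, with tools=[Tool, Toolset]
data = generator.to_dict()
restored = OllamaChatGenerator.from_dict(data)

# The mixed structure should survive the round trip.
assert isinstance(restored.tools[0], Tool)
assert isinstance(restored.tools[1], Toolset)
```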

class TestOllamaChatGeneratorRun:
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
@@ -1093,6 +1145,36 @@ def test_live_run_with_tools_and_format(self, tools):
assert isinstance(response_data["population"], (int, float))
assert response_data["capital"].lower() == "paris"

@pytest.mark.parametrize("streaming_callback", [None, Mock()])
def test_live_run_with_mixed_tools(self, tools, streaming_callback):
"""Test live run with mixed Tool and Toolset objects."""

@tool
def population(city: Annotated[str, "The city to get the population for"]) -> str:
"""Get the population of a given city."""
return f"The population of {city} is 1 million"

population_toolset = Toolset([population])

# Mix individual Tool and Toolset
mixed_tools = [tools[0], population_toolset]
component = OllamaChatGenerator(model="qwen3:0.6b", tools=mixed_tools, streaming_callback=streaming_callback)

message = ChatMessage.from_user("What is the weather and population in Paris?")
response = component.run([message])

assert len(response["replies"]) == 1
message = response["replies"][0]

assert message.tool_calls
tool_call = message.tool_call
assert isinstance(tool_call, ToolCall)
assert tool_call.tool_name in ["weather", "population"]
assert tool_call.arguments == {"city": "Paris"}

if streaming_callback:
streaming_callback.assert_called()


@pytest.mark.asyncio
@pytest.mark.integration