Skip to content

Commit 8098e9c

Browse files
vblagojesjrltstadel
authored
feat: Update tools param to Optional[Union[list[Union[Tool, Toolset]], Toolset]] (#9886)
* Update tools param to Optional[Union[list[Union[Tool, Toolset]], Toolset]] * Exclude tools from schema generation * Different approach * Lint * Use ToolsType * Fixes * Reno note * Update haystack/tools/utils.py Co-authored-by: Sebastian Husch Lee <[email protected]> * Update haystack/tools/serde_utils.py Co-authored-by: tstadel <[email protected]> * Revert "Update haystack/tools/utils.py" This reverts commit ebdec91. * PR feedback * Improve serde tests * Update releasenotes/notes/mixed-tools-toolsets-support-d944c5770e2e6e7b.yaml Co-authored-by: Sebastian Husch Lee <[email protected]> * Pydoc polish * Update FallbackChatGenerator for new ToolsType --------- Co-authored-by: Sebastian Husch Lee <[email protected]> Co-authored-by: tstadel <[email protected]>
1 parent e03add6 commit 8098e9c

File tree

19 files changed

+941
-147
lines changed

19 files changed

+941
-147
lines changed

haystack/components/agents/agent.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
from dataclasses import dataclass
7-
from typing import Any, Optional, Union
7+
from typing import Any, Optional, Union, cast
88

99
from haystack import logging, tracing
1010
from haystack.components.generators.chat.types import ChatGenerator
@@ -25,7 +25,14 @@
2525
from haystack.dataclasses import ChatMessage, ChatRole
2626
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, PipelineSnapshot, ToolBreakpoint
2727
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
28-
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
28+
from haystack.tools import (
29+
Tool,
30+
Toolset,
31+
ToolsType,
32+
deserialize_tools_or_toolset_inplace,
33+
flatten_tools_or_toolsets,
34+
serialize_tools_or_toolset,
35+
)
2936
from haystack.utils import _deserialize_value_with_schema
3037
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
3138
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
@@ -97,7 +104,7 @@ def __init__(
97104
self,
98105
*,
99106
chat_generator: ChatGenerator,
100-
tools: Optional[Union[list[Tool], Toolset]] = None,
107+
tools: Optional[ToolsType] = None,
101108
system_prompt: Optional[str] = None,
102109
exit_conditions: Optional[list[str]] = None,
103110
state_schema: Optional[dict[str, Any]] = None,
@@ -110,7 +117,7 @@ def __init__(
110117
Initialize the agent component.
111118
112119
:param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
113-
:param tools: List of Tool objects or a Toolset that the agent can use.
120+
:param tools: A list of Tool and/or Toolset objects, or a single Toolset that the agent can use.
114121
:param system_prompt: System prompt for the agent.
115122
:param exit_conditions: List of conditions that will cause the agent to return.
116123
Can include "text" if the agent should return when it generates a message without tool calls,
@@ -134,7 +141,7 @@ def __init__(
134141
"The Agent component requires a chat generator that supports tools."
135142
)
136143

137-
valid_exits = ["text"] + [tool.name for tool in tools or []]
144+
valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
138145
if exit_conditions is None:
139146
exit_conditions = ["text"]
140147
if not all(condition in valid_exits for condition in exit_conditions):
@@ -259,7 +266,7 @@ def _initialize_fresh_execution(
259266
requires_async: bool,
260267
*,
261268
system_prompt: Optional[str] = None,
262-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
269+
tools: Optional[Union[ToolsType, list[str]]] = None,
263270
**kwargs,
264271
) -> _ExecutionContext:
265272
"""
@@ -301,9 +308,7 @@ def _initialize_fresh_execution(
301308
tool_invoker_inputs=tool_invoker_inputs,
302309
)
303310

304-
def _select_tools(
305-
self, tools: Optional[Union[list[Tool], Toolset, list[str]]] = None
306-
) -> Union[list[Tool], Toolset]:
311+
def _select_tools(self, tools: Optional[Union[ToolsType, list[str]]] = None) -> ToolsType:
307312
"""
308313
Select tools for the current run based on the provided tools parameter.
309314
@@ -314,32 +319,40 @@ def _select_tools(
314319
or if any provided tool name is not valid.
315320
:raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings).
316321
"""
317-
selected_tools: Union[list[Tool], Toolset] = self.tools
318-
if isinstance(tools, Toolset) or isinstance(tools, list) and all(isinstance(t, Tool) for t in tools):
319-
selected_tools = tools # type: ignore[assignment] # mypy thinks this could still be list[str]
320-
elif isinstance(tools, list) and all(isinstance(t, str) for t in tools):
322+
if tools is None:
323+
return self.tools
324+
325+
if isinstance(tools, list) and all(isinstance(t, str) for t in tools):
321326
if not self.tools:
322327
raise ValueError("No tools were configured for the Agent at initialization.")
323-
selected_tool_names: list[str] = tools # type: ignore[assignment] # mypy thinks this could still be list[Tool] or Toolset
324-
valid_tool_names = {tool.name for tool in self.tools}
328+
available_tools = flatten_tools_or_toolsets(self.tools)
329+
selected_tool_names = cast(list[str], tools) # mypy thinks this could still be list[Tool] or Toolset
330+
valid_tool_names = {tool.name for tool in available_tools}
325331
invalid_tool_names = {name for name in selected_tool_names if name not in valid_tool_names}
326332
if invalid_tool_names:
327333
raise ValueError(
328334
f"The following tool names are not valid: {invalid_tool_names}. "
329335
f"Valid tool names are: {valid_tool_names}."
330336
)
331-
selected_tools = [tool for tool in self.tools if tool.name in selected_tool_names]
332-
elif tools is not None:
333-
raise TypeError("tools must be a list of Tool objects, a Toolset, or a list of tool names (strings).")
334-
return selected_tools
337+
return [tool for tool in available_tools if tool.name in selected_tool_names]
338+
339+
if isinstance(tools, Toolset):
340+
return tools
341+
342+
if isinstance(tools, list):
343+
return cast(list[Union[Tool, Toolset]], tools) # mypy can't narrow the Union type from isinstance check
344+
345+
raise TypeError(
346+
"tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)."
347+
)
335348

336349
def _initialize_from_snapshot(
337350
self,
338351
snapshot: AgentSnapshot,
339352
streaming_callback: Optional[StreamingCallbackT],
340353
requires_async: bool,
341354
*,
342-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
355+
tools: Optional[Union[ToolsType, list[str]]] = None,
343356
) -> _ExecutionContext:
344357
"""
345358
Initialize execution context from an AgentSnapshot.
@@ -459,7 +472,7 @@ def run( # noqa: PLR0915
459472
break_point: Optional[AgentBreakpoint] = None,
460473
snapshot: Optional[AgentSnapshot] = None,
461474
system_prompt: Optional[str] = None,
462-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
475+
tools: Optional[Union[ToolsType, list[str]]] = None,
463476
**kwargs: Any,
464477
) -> dict[str, Any]:
465478
"""
@@ -616,7 +629,7 @@ async def run_async(
616629
break_point: Optional[AgentBreakpoint] = None,
617630
snapshot: Optional[AgentSnapshot] = None,
618631
system_prompt: Optional[str] = None,
619-
tools: Optional[Union[list[Tool], Toolset, list[str]]] = None,
632+
tools: Optional[Union[ToolsType, list[str]]] = None,
620633
**kwargs: Any,
621634
) -> dict[str, Any]:
622635
"""

haystack/components/generators/chat/azure.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
from haystack.components.generators.chat import OpenAIChatGenerator
1414
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
1515
from haystack.tools import (
16-
Tool,
17-
Toolset,
16+
ToolsType,
1817
_check_duplicate_tool_names,
1918
deserialize_tools_or_toolset_inplace,
19+
flatten_tools_or_toolsets,
2020
serialize_tools_or_toolset,
2121
)
2222
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -84,7 +84,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
8484
max_retries: Optional[int] = None,
8585
generation_kwargs: Optional[dict[str, Any]] = None,
8686
default_headers: Optional[dict[str, str]] = None,
87-
tools: Optional[Union[list[Tool], Toolset]] = None,
87+
tools: Optional[ToolsType] = None,
8888
tools_strict: bool = False,
8989
*,
9090
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
@@ -138,8 +138,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
138138
the `response_format` must be a JSON schema and not a Pydantic model.
139139
:param default_headers: Default headers to use for the AzureOpenAI client.
140140
:param tools:
141-
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
142-
list of `Tool` objects or a `Toolset` instance.
141+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
143142
:param tools_strict:
144143
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
145144
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@@ -179,7 +178,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
179178
self.default_headers = default_headers or {}
180179
self.azure_ad_token_provider = azure_ad_token_provider
181180
self.http_client_kwargs = http_client_kwargs
182-
_check_duplicate_tool_names(list(tools or []))
181+
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
183182
self.tools = tools
184183
self.tools_strict = tools_strict
185184

haystack/components/generators/chat/fallback.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from __future__ import annotations
66

77
import asyncio
8-
from typing import Any, Union
8+
from typing import Any, Optional, Union
99

1010
from haystack import component, default_from_dict, default_to_dict, logging
1111
from haystack.components.generators.chat.types import ChatGenerator
1212
from haystack.dataclasses import ChatMessage, StreamingCallbackT
13-
from haystack.tools import Tool, Toolset
13+
from haystack.tools import ToolsType
1414
from haystack.utils.deserialization import deserialize_component_inplace
1515

1616
logger = logging.getLogger(__name__)
@@ -86,7 +86,7 @@ def _run_single_sync( # pylint: disable=too-many-positional-arguments
8686
gen: Any,
8787
messages: list[ChatMessage],
8888
generation_kwargs: Union[dict[str, Any], None],
89-
tools: Union[list[Tool], Toolset, None],
89+
tools: Optional[ToolsType],
9090
streaming_callback: Union[StreamingCallbackT, None],
9191
) -> dict[str, Any]:
9292
return gen.run(
@@ -98,7 +98,7 @@ async def _run_single_async( # pylint: disable=too-many-positional-arguments
9898
gen: Any,
9999
messages: list[ChatMessage],
100100
generation_kwargs: Union[dict[str, Any], None],
101-
tools: Union[list[Tool], Toolset, None],
101+
tools: Optional[ToolsType],
102102
streaming_callback: Union[StreamingCallbackT, None],
103103
) -> dict[str, Any]:
104104
if hasattr(gen, "run_async") and callable(gen.run_async):
@@ -121,15 +121,15 @@ def run(
121121
self,
122122
messages: list[ChatMessage],
123123
generation_kwargs: Union[dict[str, Any], None] = None,
124-
tools: Union[list[Tool], Toolset, None] = None,
124+
tools: Optional[ToolsType] = None,
125125
streaming_callback: Union[StreamingCallbackT, None] = None,
126126
) -> dict[str, Any]:
127127
"""
128128
Execute chat generators sequentially until one succeeds.
129129
130130
:param messages: The conversation history as a list of ChatMessage instances.
131131
:param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
132-
:param tools: Optional Tool instances or Toolset for function calling capabilities.
132+
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for function calling capabilities.
133133
:param streaming_callback: Optional callable for handling streaming responses.
134134
:returns: A dictionary with:
135135
- "replies": Generated ChatMessage instances from the first successful generator.
@@ -174,15 +174,15 @@ async def run_async(
174174
self,
175175
messages: list[ChatMessage],
176176
generation_kwargs: Union[dict[str, Any], None] = None,
177-
tools: Union[list[Tool], Toolset, None] = None,
177+
tools: Optional[ToolsType] = None,
178178
streaming_callback: Union[StreamingCallbackT, None] = None,
179179
) -> dict[str, Any]:
180180
"""
181181
Asynchronously execute chat generators sequentially until one succeeds.
182182
183183
:param messages: The conversation history as a list of ChatMessage instances.
184184
:param generation_kwargs: Optional parameters for the chat generator (e.g., temperature, max_tokens).
185-
:param tools: Optional Tool instances or Toolset for function calling capabilities.
185+
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for function calling capabilities.
186186
:param streaming_callback: Optional callable for handling streaming responses.
187187
:returns: A dictionary with:
188188
- "replies": Generated ChatMessage instances from the first successful generator.

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
from haystack.dataclasses.streaming_chunk import FinishReason
2222
from haystack.lazy_imports import LazyImport
2323
from haystack.tools import (
24-
Tool,
25-
Toolset,
24+
ToolsType,
2625
_check_duplicate_tool_names,
2726
deserialize_tools_or_toolset_inplace,
27+
flatten_tools_or_toolsets,
2828
serialize_tools_or_toolset,
2929
)
3030
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@@ -93,17 +93,15 @@ def _convert_hfapi_tool_calls(hfapi_tool_calls: Optional[list["ChatCompletionOut
9393
return tool_calls
9494

9595

96-
def _convert_tools_to_hfapi_tools(
97-
tools: Optional[Union[list[Tool], Toolset]],
98-
) -> Optional[list["ChatCompletionInputTool"]]:
96+
def _convert_tools_to_hfapi_tools(tools: Optional[ToolsType]) -> Optional[list["ChatCompletionInputTool"]]:
9997
if not tools:
10098
return None
10199

102100
# huggingface_hub<0.31.0 uses "arguments", huggingface_hub>=0.31.0 uses "parameters"
103101
parameters_name = "arguments" if hasattr(ChatCompletionInputFunctionDefinition, "arguments") else "parameters"
104102

105103
hf_tools = []
106-
for tool in tools:
104+
for tool in flatten_tools_or_toolsets(tools):
107105
hf_tools_args = {"name": tool.name, "description": tool.description, parameters_name: tool.parameters}
108106

109107
hf_tools.append(
@@ -298,7 +296,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
298296
generation_kwargs: Optional[dict[str, Any]] = None,
299297
stop_words: Optional[list[str]] = None,
300298
streaming_callback: Optional[StreamingCallbackT] = None,
301-
tools: Optional[Union[list[Tool], Toolset]] = None,
299+
tools: Optional[ToolsType] = None,
302300
):
303301
"""
304302
Initialize the HuggingFaceAPIChatGenerator instance.
@@ -328,10 +326,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
328326
:param streaming_callback:
329327
An optional callable for handling streaming responses.
330328
:param tools:
331-
A list of tools or a Toolset for which the model can prepare calls.
329+
A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
332330
The chosen model should support tool/function calling, according to the model card.
333331
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
334-
unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
332+
unexpected behavior.
335333
"""
336334

337335
huggingface_hub_import.check()
@@ -364,7 +362,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
364362

365363
if tools and streaming_callback is not None:
366364
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
367-
_check_duplicate_tool_names(list(tools or []))
365+
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
368366

369367
# handle generation kwargs setup
370368
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -423,7 +421,7 @@ def run(
423421
self,
424422
messages: list[ChatMessage],
425423
generation_kwargs: Optional[dict[str, Any]] = None,
426-
tools: Optional[Union[list[Tool], Toolset]] = None,
424+
tools: Optional[ToolsType] = None,
427425
streaming_callback: Optional[StreamingCallbackT] = None,
428426
):
429427
"""
@@ -452,7 +450,8 @@ def run(
452450
tools = tools or self.tools
453451
if tools and self.streaming_callback:
454452
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
455-
_check_duplicate_tool_names(list(tools or []))
453+
flat_tools = flatten_tools_or_toolsets(tools)
454+
_check_duplicate_tool_names(flat_tools)
456455

457456
# validate and select the streaming callback
458457
streaming_callback = select_streaming_callback(
@@ -462,9 +461,6 @@ def run(
462461
if streaming_callback:
463462
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)
464463

465-
if tools and isinstance(tools, Toolset):
466-
tools = list(tools)
467-
468464
hf_tools = _convert_tools_to_hfapi_tools(tools)
469465

470466
return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
@@ -474,7 +470,7 @@ async def run_async(
474470
self,
475471
messages: list[ChatMessage],
476472
generation_kwargs: Optional[dict[str, Any]] = None,
477-
tools: Optional[Union[list[Tool], Toolset]] = None,
473+
tools: Optional[ToolsType] = None,
478474
streaming_callback: Optional[StreamingCallbackT] = None,
479475
):
480476
"""
@@ -506,17 +502,15 @@ async def run_async(
506502
tools = tools or self.tools
507503
if tools and self.streaming_callback:
508504
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
509-
_check_duplicate_tool_names(list(tools or []))
505+
flat_tools = flatten_tools_or_toolsets(tools)
506+
_check_duplicate_tool_names(flat_tools)
510507

511508
# validate and select the streaming callback
512509
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
513510

514511
if streaming_callback:
515512
return await self._run_streaming_async(formatted_messages, generation_kwargs, streaming_callback)
516513

517-
if tools and isinstance(tools, Toolset):
518-
tools = list(tools)
519-
520514
hf_tools = _convert_tools_to_hfapi_tools(tools)
521515

522516
return await self._run_non_streaming_async(formatted_messages, generation_kwargs, hf_tools)

0 commit comments

Comments
 (0)