44
55import inspect
66from dataclasses import dataclass
7- from typing import Any , Optional , Union
7+ from typing import Any , Optional , Union , cast
88
99from haystack import logging , tracing
1010from haystack .components .generators .chat .types import ChatGenerator
2525from haystack .dataclasses import ChatMessage , ChatRole
2626from haystack .dataclasses .breakpoints import AgentBreakpoint , AgentSnapshot , PipelineSnapshot , ToolBreakpoint
2727from haystack .dataclasses .streaming_chunk import StreamingCallbackT , select_streaming_callback
28- from haystack .tools import Tool , Toolset , deserialize_tools_or_toolset_inplace , serialize_tools_or_toolset
28+ from haystack .tools import (
29+ Tool ,
30+ Toolset ,
31+ ToolsType ,
32+ deserialize_tools_or_toolset_inplace ,
33+ flatten_tools_or_toolsets ,
34+ serialize_tools_or_toolset ,
35+ )
2936from haystack .utils import _deserialize_value_with_schema
3037from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
3138from haystack .utils .deserialization import deserialize_chatgenerator_inplace
@@ -97,7 +104,7 @@ def __init__(
97104 self ,
98105 * ,
99106 chat_generator : ChatGenerator ,
100- tools : Optional [Union [ list [ Tool ], Toolset ] ] = None ,
107+ tools : Optional [ToolsType ] = None ,
101108 system_prompt : Optional [str ] = None ,
102109 exit_conditions : Optional [list [str ]] = None ,
103110 state_schema : Optional [dict [str , Any ]] = None ,
@@ -110,7 +117,7 @@ def __init__(
110117 Initialize the agent component.
111118
112119 :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
113- :param tools: List of Tool objects or a Toolset that the agent can use.
120+ :param tools: A list of Tool and/or Toolset objects, or a single Toolset that the agent can use.
114121 :param system_prompt: System prompt for the agent.
115122 :param exit_conditions: List of conditions that will cause the agent to return.
116123 Can include "text" if the agent should return when it generates a message without tool calls,
@@ -134,7 +141,7 @@ def __init__(
134141 "The Agent component requires a chat generator that supports tools."
135142 )
136143
137- valid_exits = ["text" ] + [tool .name for tool in tools or [] ]
144+ valid_exits = ["text" ] + [tool .name for tool in flatten_tools_or_toolsets ( tools ) ]
138145 if exit_conditions is None :
139146 exit_conditions = ["text" ]
140147 if not all (condition in valid_exits for condition in exit_conditions ):
@@ -259,7 +266,7 @@ def _initialize_fresh_execution(
259266 requires_async : bool ,
260267 * ,
261268 system_prompt : Optional [str ] = None ,
262- tools : Optional [Union [list [ Tool ], Toolset , list [str ]]] = None ,
269+ tools : Optional [Union [ToolsType , list [str ]]] = None ,
263270 ** kwargs ,
264271 ) -> _ExecutionContext :
265272 """
@@ -301,9 +308,7 @@ def _initialize_fresh_execution(
301308 tool_invoker_inputs = tool_invoker_inputs ,
302309 )
303310
304- def _select_tools (
305- self , tools : Optional [Union [list [Tool ], Toolset , list [str ]]] = None
306- ) -> Union [list [Tool ], Toolset ]:
311+ def _select_tools (self , tools : Optional [Union [ToolsType , list [str ]]] = None ) -> ToolsType :
307312 """
308313 Select tools for the current run based on the provided tools parameter.
309314
@@ -314,32 +319,40 @@ def _select_tools(
314319 or if any provided tool name is not valid.
315320 :raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings).
316321 """
317- selected_tools : Union [ list [ Tool ], Toolset ] = self . tools
318- if isinstance ( tools , Toolset ) or isinstance ( tools , list ) and all ( isinstance ( t , Tool ) for t in tools ):
319- selected_tools = tools # type: ignore[assignment] # mypy thinks this could still be list[str]
320- elif isinstance (tools , list ) and all (isinstance (t , str ) for t in tools ):
322+ if tools is None :
323+ return self . tools
324+
325+ if isinstance (tools , list ) and all (isinstance (t , str ) for t in tools ):
321326 if not self .tools :
322327 raise ValueError ("No tools were configured for the Agent at initialization." )
323- selected_tool_names : list [str ] = tools # type: ignore[assignment] # mypy thinks this could still be list[Tool] or Toolset
324- valid_tool_names = {tool .name for tool in self .tools }
328+ available_tools = flatten_tools_or_toolsets (self .tools )
329+ selected_tool_names = cast (list [str ], tools ) # mypy thinks this could still be list[Tool] or Toolset
330+ valid_tool_names = {tool .name for tool in available_tools }
325331 invalid_tool_names = {name for name in selected_tool_names if name not in valid_tool_names }
326332 if invalid_tool_names :
327333 raise ValueError (
328334 f"The following tool names are not valid: { invalid_tool_names } . "
329335 f"Valid tool names are: { valid_tool_names } ."
330336 )
331- selected_tools = [tool for tool in self .tools if tool .name in selected_tool_names ]
332- elif tools is not None :
333- raise TypeError ("tools must be a list of Tool objects, a Toolset, or a list of tool names (strings)." )
334- return selected_tools
337+ return [tool for tool in available_tools if tool .name in selected_tool_names ]
338+
339+ if isinstance (tools , Toolset ):
340+ return tools
341+
342+ if isinstance (tools , list ):
343+ return cast (list [Union [Tool , Toolset ]], tools ) # mypy can't narrow the Union type from isinstance check
344+
345+ raise TypeError (
346+ "tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)."
347+ )
335348
336349 def _initialize_from_snapshot (
337350 self ,
338351 snapshot : AgentSnapshot ,
339352 streaming_callback : Optional [StreamingCallbackT ],
340353 requires_async : bool ,
341354 * ,
342- tools : Optional [Union [list [ Tool ], Toolset , list [str ]]] = None ,
355+ tools : Optional [Union [ToolsType , list [str ]]] = None ,
343356 ) -> _ExecutionContext :
344357 """
345358 Initialize execution context from an AgentSnapshot.
@@ -459,7 +472,7 @@ def run( # noqa: PLR0915
459472 break_point : Optional [AgentBreakpoint ] = None ,
460473 snapshot : Optional [AgentSnapshot ] = None ,
461474 system_prompt : Optional [str ] = None ,
462- tools : Optional [Union [list [ Tool ], Toolset , list [str ]]] = None ,
475+ tools : Optional [Union [ToolsType , list [str ]]] = None ,
463476 ** kwargs : Any ,
464477 ) -> dict [str , Any ]:
465478 """
@@ -616,7 +629,7 @@ async def run_async(
616629 break_point : Optional [AgentBreakpoint ] = None ,
617630 snapshot : Optional [AgentSnapshot ] = None ,
618631 system_prompt : Optional [str ] = None ,
619- tools : Optional [Union [list [ Tool ], Toolset , list [str ]]] = None ,
632+ tools : Optional [Union [ToolsType , list [str ]]] = None ,
620633 ** kwargs : Any ,
621634 ) -> dict [str , Any ]:
622635 """
0 commit comments