diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index e90ff3eb18..b8b7739eea 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -268,6 +268,7 @@ def _initialize_fresh_execution( requires_async: bool, *, system_prompt: Optional[str] = None, + generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[ToolsType, list[str]]] = None, **kwargs, ) -> _ExecutionContext: @@ -278,6 +279,8 @@ def _initialize_fresh_execution( :param streaming_callback: Optional callback for streaming responses. :param requires_async: Whether the agent run requires asynchronous execution. :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt. + :param generation_kwargs: Additional keyword arguments for the chat generator. These parameters will + override the parameters passed during component initialization. :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. When passing tool names, tools are selected from the Agent's originally configured tools. :param kwargs: Additional data to pass to the State used by the Agent. @@ -302,6 +305,8 @@ def _initialize_fresh_execution( if streaming_callback is not None: tool_invoker_inputs["streaming_callback"] = streaming_callback generator_inputs["streaming_callback"] = streaming_callback + if generation_kwargs is not None: + generator_inputs["generation_kwargs"] = generation_kwargs return _ExecutionContext( state=state, @@ -354,6 +359,7 @@ def _initialize_from_snapshot( streaming_callback: Optional[StreamingCallbackT], requires_async: bool, *, + generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[ToolsType, list[str]]] = None, ) -> _ExecutionContext: """ @@ -362,6 +368,8 @@ def _initialize_from_snapshot( :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. 
:param streaming_callback: Optional callback for streaming responses. :param requires_async: Whether the agent run requires asynchronous execution. + :param generation_kwargs: Additional keyword arguments for the chat generator. These parameters will + override the parameters passed during component initialization. :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. When passing tool names, tools are selected from the Agent's originally configured tools. """ @@ -386,6 +394,8 @@ def _initialize_from_snapshot( if streaming_callback is not None: tool_invoker_inputs["streaming_callback"] = streaming_callback generator_inputs["streaming_callback"] = streaming_callback + if generation_kwargs is not None: + generator_inputs["generation_kwargs"] = generation_kwargs return _ExecutionContext( state=state, @@ -471,6 +481,7 @@ def run( # noqa: PLR0915 messages: list[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, + generation_kwargs: Optional[dict[str, Any]] = None, break_point: Optional[AgentBreakpoint] = None, snapshot: Optional[AgentSnapshot] = None, system_prompt: Optional[str] = None, @@ -483,6 +494,8 @@ def run( # noqa: PLR0915 :param messages: List of Haystack ChatMessage objects to process. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for the LLM. These parameters will + override the parameters passed during component initialization. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. 
The snapshot contains @@ -513,7 +526,11 @@ def run( # noqa: PLR0915 if snapshot: exe_context = self._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=False, + tools=tools, + generation_kwargs=generation_kwargs, ) else: exe_context = self._initialize_fresh_execution( @@ -522,6 +539,7 @@ def run( # noqa: PLR0915 requires_async=False, system_prompt=system_prompt, tools=tools, + generation_kwargs=generation_kwargs, **kwargs, ) @@ -628,6 +646,7 @@ async def run_async( messages: list[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, + generation_kwargs: Optional[dict[str, Any]] = None, break_point: Optional[AgentBreakpoint] = None, snapshot: Optional[AgentSnapshot] = None, system_prompt: Optional[str] = None, @@ -644,6 +663,8 @@ async def run_async( :param messages: List of Haystack ChatMessage objects to process. :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for the LLM. These parameters will + override the parameters passed during component initialization. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. 
The snapshot contains @@ -673,7 +694,11 @@ async def run_async( if snapshot: exe_context = self._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=True, + tools=tools, + generation_kwargs=generation_kwargs, ) else: exe_context = self._initialize_fresh_execution( @@ -682,6 +707,7 @@ async def run_async( requires_async=True, system_prompt=system_prompt, tools=tools, + generation_kwargs=generation_kwargs, **kwargs, ) diff --git a/releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml b/releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml new file mode 100644 index 0000000000..19531b07a0 --- /dev/null +++ b/releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Adds `generation_kwargs` to the `Agent` component, allowing for more fine-grained control at run-time over the chat generation. 
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index bec1406e79..24b9773043 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -884,6 +884,24 @@ async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async( assert isinstance(result["last_message"], ChatMessage) assert result["messages"][-1] == result["last_message"] + @pytest.mark.asyncio + async def test_generation_kwargs(self): + chat_generator = MockChatGeneratorWithoutRunAsync() + + agent = Agent(chat_generator=chat_generator) + agent.warm_up() + + chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]}) + + await agent.run_async([ChatMessage.from_user("Hello")], generation_kwargs={"temperature": 0.0}) + + expected_messages = [ + ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={}) + ] + chat_generator.run.assert_called_once_with( + messages=expected_messages, generation_kwargs={"temperature": 0.0}, tools=[] + ) + @pytest.mark.asyncio async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool): chat_generator = MockChatGenerator()