From 086e302e5b91672984c59d42741533fb5cac8f2f Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 16 Jul 2025 21:16:25 +0100 Subject: [PATCH 1/9] add generation_kwargs to run methods of Agent --- haystack/components/agents/agent.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 5df6549159..51d8008f25 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -210,9 +210,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent": return default_from_dict(cls, data) - def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]: + def _prepare_generator_inputs( + self, + streaming_callback: Optional[StreamingCallbackT] = None, + generation_kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Prepare inputs for the chat generator.""" generator_inputs: Dict[str, Any] = {"tools": self.tools} + if generation_kwargs is not None: + generator_inputs["generation_kwargs"] = generation_kwargs if streaming_callback is not None: generator_inputs["streaming_callback"] = streaming_callback return generator_inputs @@ -230,7 +236,11 @@ def _create_agent_span(self) -> Any: ) def run( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + self, + messages: List[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any ) -> Dict[str, Any]: """ Process messages and execute tools until an exit condition is met. @@ -239,6 +249,8 @@ def run( If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -267,7 +279,7 @@ def run( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generator_inputs=generation_kwargs) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", @@ -328,7 +340,11 @@ def run( return result async def run_async( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + self, + messages: List[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any ) -> Dict[str, Any]: """ Asynchronously process messages and execute tools until the exit condition is met. @@ -341,6 +357,8 @@ async def run_async( :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -369,7 +387,7 @@ async def run_async( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", From 182f2de3527d4db5529af1e2b055a1e972f2911c Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 16 Jul 2025 21:23:27 +0100 Subject: [PATCH 2/9] include test generation_kwargs in agent unit tests --- test/components/agents/test_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 61e45e9610..221f4c2f08 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -746,7 +746,7 @@ def test_run(self, weather_tool): chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3) agent.warm_up() - response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) + response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")], generation_kwargs={"temperature": 0.0}) assert isinstance(response, dict) assert "messages" in response From 20995d7979553ac359d5ed45c3e26ac81e3d18c5 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Thu, 17 Jul 2025 08:57:31 +0100 Subject: [PATCH 3/9] fix typo Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/components/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 51d8008f25..8cea8cc16a 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -279,7 +279,7 @@ def run( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generator_inputs=generation_kwargs) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", From b57ead117a6fe70548614c98608f6bc59570c3fb Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 1 Oct 2025 10:01:52 +0100 Subject: [PATCH 4/9] Update haystack/components/agents/agent.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/components/agents/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 8cea8cc16a..818690e364 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -239,6 +239,7 @@ def run( self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, + *, generation_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> Dict[str, Any]: From 86de835a524cdd5db4e456ce3138e1d5fbd25cb0 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 1 Oct 2025 10:02:00 +0100 Subject: [PATCH 5/9] Update haystack/components/agents/agent.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/components/agents/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 818690e364..abffa93fc1 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -344,6 +344,7 @@ async def run_async( self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, + *, generation_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> Dict[str, Any]: From fbcbacaeb04013f678ea9931622eb0c5b6d6a58f Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 1 Oct 2025 10:21:17 +0100 Subject: [PATCH 6/9] create separate unit test for generation kwargs with agent --- test/components/agents/test_agent.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 221f4c2f08..c470b692d1 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -746,7 +746,7 @@ def test_run(self, weather_tool): chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3) agent.warm_up() - response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")], generation_kwargs={"temperature": 0.0}) + response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) assert isinstance(response, dict) assert "messages" in response @@ -798,6 +798,24 @@ async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async( assert isinstance(result["last_message"], ChatMessage) assert result["messages"][-1] == result["last_message"] + @pytest.mark.asyncio + async def test_generation_kwargs(self): + chat_generator = MockChatGeneratorWithoutRunAsync() + + agent = Agent(chat_generator=chat_generator) + agent.warm_up() + + chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]}) + + await agent.run_async([ChatMessage.from_user("Hello")], generation_kwargs={"temperature": 0.0}) + + expected_messages = [ + ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={}) + ] + chat_generator.run.assert_called_once_with( + messages=expected_messages, generation_kwargs={"temperature": 0.0}, tools=[] + ) + @pytest.mark.asyncio async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool): chat_generator = MockChatGeneratorWithRunAsync() From 59e95da0ca99e86f6096b930d337fb60ec5f2378 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 1 Oct 2025 10:25:27 +0100 Subject: [PATCH 7/9] fix formatting --- haystack/components/agents/agent.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index abffa93fc1..a8a3d68959 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -213,7 +213,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent": def _prepare_generator_inputs( self, streaming_callback: Optional[StreamingCallbackT] = None, - generation_kwargs: Optional[Dict[str, Any]] = None + generation_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Prepare inputs for the chat generator.""" generator_inputs: Dict[str, Any] = {"tools": self.tools} @@ -241,7 +241,7 @@ def run( streaming_callback: Optional[StreamingCallbackT] = None, *, generation_kwargs: Optional[Dict[str, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> Dict[str, Any]: """ Process messages and execute tools until an exit condition is met. @@ -280,7 +280,9 @@ def run( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) + generator_inputs = self._prepare_generator_inputs( + streaming_callback=streaming_callback, generation_kwargs=generation_kwargs + ) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", @@ -346,7 +348,7 @@ async def run_async( streaming_callback: Optional[StreamingCallbackT] = None, *, generation_kwargs: Optional[Dict[str, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> Dict[str, Any]: """ Asynchronously process messages and execute tools until the exit condition is met. @@ -389,7 +391,9 @@ async def run_async( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) + generator_inputs = self._prepare_generator_inputs( + streaming_callback=streaming_callback, generation_kwargs=generation_kwargs + ) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", From 11d8957469102631e521249d205782e9c10c28f5 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 1 Oct 2025 10:47:36 +0100 Subject: [PATCH 8/9] add back generation_kwargs --- haystack/components/agents/agent.py | 30 +++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index faa019671e..8b06ff2be7 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -259,6 +259,7 @@ def _initialize_fresh_execution( requires_async: bool, *, system_prompt: Optional[str] = None, + generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[list[Tool], Toolset, list[str]]] = None, **kwargs, ) -> _ExecutionContext: @@ -269,6 +270,8 @@ def _initialize_fresh_execution( :param streaming_callback: Optional callback for streaming responses. :param requires_async: Whether the agent run requires asynchronous execution. :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt. + :param generation_kwargs: Additional keyword arguments for chat generator. These parameters will + override the parameters passed during component initialization. :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. When passing tool names, tools are selected from the Agent's originally configured tools. :param kwargs: Additional data to pass to the State used by the Agent. @@ -293,6 +296,8 @@ def _initialize_fresh_execution( if streaming_callback is not None: tool_invoker_inputs["streaming_callback"] = streaming_callback generator_inputs["streaming_callback"] = streaming_callback + if generation_kwargs is not None: + generator_inputs["generation_kwargs"] = generation_kwargs return _ExecutionContext( state=state, @@ -339,6 +344,7 @@ def _initialize_from_snapshot( streaming_callback: Optional[StreamingCallbackT], requires_async: bool, *, + generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[list[Tool], Toolset, list[str]]] = None, ) -> _ExecutionContext: """ @@ -347,6 +353,8 @@ def _initialize_from_snapshot( :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. :param streaming_callback: Optional callback for streaming responses. :param requires_async: Whether the agent run requires asynchronous execution. + :param generation_kwargs: Additional keyword arguments for chat generator. These parameters will + override the parameters passed during component initialization. :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. When passing tool names, tools are selected from the Agent's originally configured tools. """ @@ -371,6 +379,8 @@ def _initialize_from_snapshot( if streaming_callback is not None: tool_invoker_inputs["streaming_callback"] = streaming_callback generator_inputs["streaming_callback"] = streaming_callback + if generation_kwargs is not None: + generator_inputs["generation_kwargs"] = generation_kwargs return _ExecutionContext( state=state, @@ -461,6 +471,7 @@ def run( # noqa: PLR0915 messages: list[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, + generation_kwargs: Optional[dict[str, Any]] = None, break_point: Optional[AgentBreakpoint] = None, snapshot: Optional[AgentSnapshot] = None, system_prompt: Optional[str] = None, @@ -473,6 +484,8 @@ def run( # noqa: PLR0915 :param messages: List of Haystack ChatMessage objects to process. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains @@ -503,7 +516,11 @@ def run( # noqa: PLR0915 if snapshot: exe_context = self._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=False, + tools=tools, + generation_kwargs=generation_kwargs, ) else: exe_context = self._initialize_fresh_execution( @@ -512,6 +529,7 @@ def run( # noqa: PLR0915 requires_async=False, system_prompt=system_prompt, tools=tools, + generation_kwargs=generation_kwargs, **kwargs, ) @@ -618,6 +636,7 @@ async def run_async( messages: list[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, + generation_kwargs: Optional[dict[str, Any]] = None, break_point: Optional[AgentBreakpoint] = None, snapshot: Optional[AgentSnapshot] = None, system_prompt: Optional[str] = None, @@ -634,6 +653,8 @@ async def run_async( :param messages: List of Haystack ChatMessage objects to process. :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains @@ -663,7 +684,11 @@ async def run_async( if snapshot: exe_context = self._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=True, + tools=tools, + generation_kwargs=generation_kwargs, ) else: exe_context = self._initialize_fresh_execution( @@ -672,6 +697,7 @@ async def run_async( requires_async=True, system_prompt=system_prompt, tools=tools, + generation_kwargs=generation_kwargs, **kwargs, ) From afc7b2709c1cbce64b643d25133d537905b54685 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 1 Oct 2025 10:53:54 +0100 Subject: [PATCH 9/9] add release notes --- .../add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml diff --git a/releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml b/releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml new file mode 100644 index 0000000000..19531b07a0 --- /dev/null +++ b/releasenotes/notes/add-generation-kwargs-to-agent-9600a6c1e6bc069b.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Adds `generation_kwargs` to the `Agent` component, allowing for more fine-grained control at run-time over the chat generation.