@@ -262,14 +262,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
262262 async def prepare_stream (self , input : RunAgentInput , agent_state : State , config : RunnableConfig ):
263263 state_input = input .state or {}
264264 messages = input .messages or []
265- tools = input .tools or []
266265 forwarded_props = input .forwarded_props or {}
267266 thread_id = input .thread_id
268267
269268 state_input ["messages" ] = agent_state .values .get ("messages" , [])
270269 self .active_run ["current_graph_state" ] = agent_state .values .copy ()
271270 langchain_messages = agui_messages_to_langchain (messages )
272- state = self .langgraph_default_merge_state (state_input , langchain_messages , tools )
271+ state = self .langgraph_default_merge_state (state_input , langchain_messages , input )
273272 self .active_run ["current_graph_state" ].update (state )
274273 config ["configurable" ]["thread_id" ] = thread_id
275274 interrupts = agent_state .tasks [0 ].interrupts if agent_state .tasks and len (agent_state .tasks ) > 0 else []
@@ -368,7 +367,7 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
368367 as_node = time_travel_checkpoint .next [0 ] if time_travel_checkpoint .next else "__start__"
369368 )
370369
371- stream_input = self .langgraph_default_merge_state (time_travel_checkpoint .values , [message_checkpoint ], tools )
370+ stream_input = self .langgraph_default_merge_state (time_travel_checkpoint .values , [message_checkpoint ], input )
372371 subgraphs_stream_enabled = input .forwarded_props .get ('stream_subgraphs' ) if input .forwarded_props else False
373372 stream = self .graph .astream_events (
374373 stream_input ,
@@ -415,7 +414,7 @@ def get_schema_keys(self, config) -> SchemaKeys:
415414 "config" : [],
416415 }
417416
418- def langgraph_default_merge_state (self , state : State , messages : List [BaseMessage ], tools : Any ) -> State :
417+ def langgraph_default_merge_state (self , state : State , messages : List [BaseMessage ], input : RunAgentInput ) -> State :
419418 if messages and isinstance (messages [0 ], SystemMessage ):
420419 messages = messages [1 :]
421420
@@ -424,6 +423,7 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
424423
425424 new_messages = [msg for msg in messages if msg .id not in existing_message_ids ]
426425
426+ tools = input .tools or []
427427 tools_as_dicts = []
428428 if tools :
429429 for tool in tools :
@@ -438,6 +438,10 @@ def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage
438438 ** state ,
439439 "messages" : new_messages ,
440440 "tools" : [* state .get ("tools" , []), * tools_as_dicts ],
441+ "ag-ui" : {
442+ "tools" : [* state .get ("tools" , []), * tools_as_dicts ],
443+ "context" : input .context or []
444+ }
441445 }
442446
443447 def get_state_snapshot (self , state : State ) -> State :
0 commit comments