Skip to content

Commit 97a1cab

Browse files
committed
feat(langgraph-py): several fixes
1 parent 7135b39 commit 97a1cab

File tree

1 file changed

+68
-59
lines changed
  • typescript-sdk/integrations/langgraph/python/ag_ui_langgraph

1 file changed

+68
-59
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 68 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,18 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
108108

109109
messages = input.messages or []
110110
forwarded_props = input.forwarded_props
111-
111+
node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None
112+
112113
self.active_run["manually_emitted_state"] = None
113-
self.active_run["node_name"] = forwarded_props.get('node_name', None) if forwarded_props else None
114+
self.active_run["node_name"] = node_name_input
115+
if self.active_run["node_name"] == "__end__":
116+
self.active_run["node_name"] = None
114117

115118
config = ensure_config(self.config.copy() if self.config else {})
116119
config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id}
117120

118121
agent_state = await self.graph.aget_state(config)
122+
self.active_run["mode"] = "continue" if thread_id and self.active_run.get("node_name") != "__end__" and self.active_run.get("node_name") else "start"
119123
prepared_stream_response = await self.prepare_stream(input=input, agent_state=agent_state, config=config)
120124

121125
yield self._dispatch_event(
@@ -151,48 +155,48 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
151155
return
152156

153157
should_exit = False
154-
latest_state_values = state
158+
current_graph_state = state
155159
async for event in stream:
156160
if event["event"] == "error":
157161
yield self._dispatch_event(
158162
RunErrorEvent(type=EventType.RUN_ERROR, message=event["data"]["message"], raw_event=event)
159163
)
160164
break
161165

162-
if event["event"] == "values":
163-
latest_state_values = event["data"]
164-
continue
165-
166-
if event["event"] == "updates":
167-
continue
168-
169166
current_node_name = event.get("metadata", {}).get("langgraph_node")
170167
event_type = event.get("event")
171168
self.active_run["id"] = event.get("run_id")
169+
exiting_node = False
170+
171+
if event_type == "on_chain_end" and isinstance(
172+
event.get("data", {}).get("output"), dict
173+
):
174+
current_graph_state.update(event["data"]["output"])
175+
exiting_node = self.active_run["node_name"] == current_node_name
172176

173177
should_exit = should_exit or (
174178
event_type == "on_custom_event" and
175179
event["name"] == "exit"
176180
)
177181

178182
if current_node_name and current_node_name != self.active_run.get("node_name"):
179-
if self.active_run.get("node_name"):
183+
if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
180184
yield self._dispatch_event(
181185
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
182186
)
183-
184-
if current_node_name:
185-
yield self._dispatch_event(
186-
StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
187-
)
188-
self.active_run["node_name"] = current_node_name
187+
self.active_run["node_name"] = None
189188

190-
updated_state = self.active_run.get("manually_emitted_state") or latest_state_values
191-
has_state_diff = updated_state != state
189+
yield self._dispatch_event(
190+
StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
191+
)
192+
self.active_run["node_name"] = current_node_name
192193

193-
if has_state_diff and not self.get_message_in_progress(self.active_run["id"]):
194+
updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
195+
has_state_diff = updated_state != state
196+
if exiting_node or (has_state_diff and not self.get_message_in_progress(self.active_run["id"])):
194197
state = updated_state
195198
self.active_run["prev_node_name"] = self.active_run["node_name"]
199+
current_graph_state.update(updated_state)
196200
yield self._dispatch_event(
197201
StateSnapshotEvent(
198202
type=EventType.STATE_SNAPSHOT,
@@ -208,18 +212,17 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
208212
async for single_event in self._handle_single_event(event, state):
209213
yield single_event
210214

211-
state_after_run = await self.graph.aget_state(config)
212-
tasks = state_after_run.tasks
213-
interrupts = tasks[0].interrupts if tasks and len(tasks) > 0 else []
215+
state = await self.graph.aget_state(config)
214216

215-
if interrupts:
216-
self.active_run["node_name"] = self.active_run["node_name"]
217-
elif "writes" in state_after_run.metadata and state_after_run.metadata["writes"]:
218-
self.active_run["node_name"] = list(state_after_run.metadata["writes"].keys())[0]
219-
elif hasattr(state_after_run, "next") and state_after_run.next and state_after_run.next[0]:
220-
self.active_run["node_name"] = state_after_run.next[0]
221-
else:
222-
self.active_run["node_name"] = "__end__"
217+
tasks = state.tasks if len(state.tasks) > 0 else None
218+
interrupts = tasks[0].interrupts if tasks else []
219+
220+
writes = state.metadata.get("writes", {}) or {}
221+
node_name = self.active_run["node_name"] if interrupts else next(iter(writes), None)
222+
next_nodes = state.next or ()
223+
is_end_node = len(next_nodes) == 0 and not interrupts
224+
225+
node_name = "__end__" if is_end_node else node_name
223226

224227
for interrupt in interrupts:
225228
yield self._dispatch_event(
@@ -231,25 +234,38 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
231234
)
232235
)
233236

237+
if self.active_run.get("node_name") != node_name:
238+
yield self._dispatch_event(
239+
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
240+
)
241+
self.active_run["node_name"] = node_name
242+
yield self._dispatch_event(
243+
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
244+
)
245+
246+
# if tasks is None:
234247
yield self._dispatch_event(
235-
StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=self.get_state_snapshot(state_after_run.values))
248+
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
236249
)
237-
250+
self.active_run["node_name"] = None
251+
252+
state_values = state.values if state.values else state
253+
yield self._dispatch_event(
254+
StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot=self.get_state_snapshot(state_values))
255+
)
256+
238257
yield self._dispatch_event(
239258
MessagesSnapshotEvent(
240259
type=EventType.MESSAGES_SNAPSHOT,
241-
messages=langchain_messages_to_agui(state_after_run.values.get("messages", [])),
260+
messages=langchain_messages_to_agui(state_values.get("messages", [])),
242261
)
243262
)
244263

245-
if self.active_run.get("node_name"):
246-
yield self._dispatch_event(
247-
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
248-
)
249-
250264
yield self._dispatch_event(
251265
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
252266
)
267+
self.active_run = None
268+
253269

254270
async def prepare_stream(self, input: RunAgentInput, agent_state: State, config: RunnableConfig):
255271
state_input = input.state or {}
@@ -259,7 +275,6 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
259275
thread_id = input.thread_id
260276

261277
state_input["messages"] = agent_state.values.get("messages", [])
262-
# TODO: validate if we need current graph state
263278
self.active_run["current_graph_state"] = agent_state.values
264279
langchain_messages = agui_messages_to_langchain(messages)
265280
state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
@@ -295,9 +310,7 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
295310
"events_to_dispatch": events_to_dispatch,
296311
}
297312

298-
mode = "continue" if thread_id and self.active_run.get("node_name") != "__end__" and self.active_run.get("node_name") else "start"
299-
300-
if mode == "continue":
313+
if self.active_run["mode"] == "continue":
301314
await self.graph.aupdate_state(config, state, as_node=self.active_run.get("node_name"))
302315

303316
self.active_run["schema_keys"] = self.get_schema_keys(config)
@@ -306,7 +319,7 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
306319
stream_input = Command(resume=resume_input)
307320
else:
308321
payload_input = get_stream_payload_input(
309-
mode=mode,
322+
mode=self.active_run["mode"],
310323
state=state,
311324
schema_keys=self.active_run["schema_keys"],
312325
)
@@ -466,25 +479,22 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
466479
)
467480

468481
if is_tool_call_end_event:
469-
resolved = self._dispatch_event(
482+
yield self._dispatch_event(
470483
ToolCallEndEvent(type=EventType.TOOL_CALL_END, tool_call_id=current_stream["tool_call_id"], raw_event=event)
471484
)
472-
if resolved:
473-
self.messages_in_process[self.active_run["id"]] = None
474-
yield resolved
485+
self.messages_in_process[self.active_run["id"]] = None
475486
return
476487

488+
477489
if is_message_end_event:
478-
resolved = self._dispatch_event(
490+
yield self._dispatch_event(
479491
TextMessageEndEvent(type=EventType.TEXT_MESSAGE_END, message_id=current_stream["id"], raw_event=event)
480492
)
481-
if resolved:
482-
self.messages_in_process[self.active_run["id"]] = None
483-
yield resolved
493+
self.messages_in_process[self.active_run["id"]] = None
484494
return
485495

486496
if is_tool_call_start_event and should_emit_tool_calls:
487-
resolved = self._dispatch_event(
497+
yield self._dispatch_event(
488498
ToolCallStartEvent(
489499
type=EventType.TOOL_CALL_START,
490500
tool_call_id=tool_call_data["id"],
@@ -493,12 +503,10 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
493503
raw_event=event,
494504
)
495505
)
496-
if resolved:
497-
self.set_message_in_progress(
498-
self.active_run["id"],
499-
MessageInProgress(id=event["data"]["chunk"].id, tool_call_id=tool_call_data["id"], tool_call_name=tool_call_data["name"])
500-
)
501-
yield resolved
506+
self.set_message_in_progress(
507+
self.active_run["id"],
508+
MessageInProgress(id=event["data"]["chunk"].id, tool_call_id=tool_call_data["id"], tool_call_name=tool_call_data["name"])
509+
)
502510
return
503511

504512
if is_tool_call_args_event and should_emit_tool_calls:
@@ -672,3 +680,4 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
672680
return history_list[idx - 1] # return one snapshot *before* the one that includes the message
673681

674682
raise ValueError("Message ID not found in history")
683+

0 commit comments

Comments
 (0)