
Commit 8888989

fix: fix time travel on fastapi implementation
1 parent ab0f0a8 commit 8888989

2 files changed, +20 -24 lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 19 additions & 22 deletions
@@ -106,7 +106,6 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
             "thinking_process": None,
         }
 
-        messages = input.messages or []
         forwarded_props = input.forwarded_props
         node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None
 
@@ -120,30 +119,13 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
 
         agent_state = await self.graph.aget_state(config)
         self.active_run["mode"] = "continue" if thread_id and self.active_run.get("node_name") != "__end__" and self.active_run.get("node_name") else "start"
+
         prepared_stream_response = await self.prepare_stream(input=input, agent_state=agent_state, config=config)
 
         yield self._dispatch_event(
             RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
         )
 
-        langchain_messages = agui_messages_to_langchain(messages)
-        non_system_messages = [msg for msg in langchain_messages if not isinstance(msg, SystemMessage)]
-
-        if len(agent_state.values.get("messages", [])) > len(non_system_messages):
-            # Find the last user message by working backwards from the last message
-            last_user_message = None
-            for i in range(len(langchain_messages) - 1, -1, -1):
-                if isinstance(langchain_messages[i], HumanMessage):
-                    last_user_message = langchain_messages[i]
-                    break
-
-            if last_user_message:
-                prepared_stream_response = await self.prepare_regenerate_stream(
-                    input=input,
-                    message_checkpoint=last_user_message,
-                    config=config
-                )
-
         state = prepared_stream_response["state"]
         stream = prepared_stream_response["stream"]
         config = prepared_stream_response["config"]
@@ -274,14 +256,31 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
         thread_id = input.thread_id
 
         state_input["messages"] = agent_state.values.get("messages", [])
-        self.active_run["current_graph_state"] = agent_state.values
+        self.active_run["current_graph_state"] = agent_state.values.copy()
         langchain_messages = agui_messages_to_langchain(messages)
         state = self.langgraph_default_merge_state(state_input, langchain_messages, tools)
         self.active_run["current_graph_state"].update(state)
         config["configurable"]["thread_id"] = thread_id
         interrupts = agent_state.tasks[0].interrupts if agent_state.tasks and len(agent_state.tasks) > 0 else []
         has_active_interrupts = len(interrupts) > 0
         resume_input = forwarded_props.get('command', {}).get('resume', None)
+        self.active_run["schema_keys"] = self.get_schema_keys(config)
+
+        non_system_messages = [msg for msg in langchain_messages if not isinstance(msg, SystemMessage)]
+        if len(agent_state.values.get("messages", [])) > len(non_system_messages):
+            # Find the last user message by working backwards from the last message
+            last_user_message = None
+            for i in range(len(langchain_messages) - 1, -1, -1):
+                if isinstance(langchain_messages[i], HumanMessage):
+                    last_user_message = langchain_messages[i]
+                    break
+
+            if last_user_message:
+                return await self.prepare_regenerate_stream(
+                    input=input,
+                    message_checkpoint=last_user_message,
+                    config=config
+                )
 
         events_to_dispatch = []
         if has_active_interrupts and not resume_input:
@@ -312,8 +311,6 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
         if self.active_run["mode"] == "continue":
             await self.graph.aupdate_state(config, state, as_node=self.active_run.get("node_name"))
 
-        self.active_run["schema_keys"] = self.get_schema_keys(config)
-
         if resume_input:
             stream_input = Command(resume=resume_input)
         else:
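
The regeneration check that this commit consolidates into prepare_stream boils down to: if the checkpointed thread already holds more messages than the client just sent (ignoring system messages), the client has rewound history, so the run should restart from the last user message instead of appending to the thread. Below is a minimal, self-contained sketch of that detection using langchain_core message types; detect_time_travel and the sample messages are hypothetical and not part of the SDK.

from typing import Optional, Sequence

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage


def detect_time_travel(
    checkpointed: Sequence[BaseMessage],
    incoming: Sequence[BaseMessage],
) -> Optional[HumanMessage]:
    """Return the user message to regenerate from, or None for a normal run."""
    non_system = [m for m in incoming if not isinstance(m, SystemMessage)]
    # A checkpoint holding more messages than the request means the client rewound.
    if len(checkpointed) <= len(non_system):
        return None
    # Work backwards to find the most recent user message, as the diff above does.
    for msg in reversed(list(incoming)):
        if isinstance(msg, HumanMessage):
            return msg
    return None


checkpoint = [HumanMessage("hi"), AIMessage("hello"), HumanMessage("more?"), AIMessage("sure")]
# The client edited the last turn and resent a shorter history.
request = [SystemMessage("be brief"), HumanMessage("hi"), AIMessage("hello"), HumanMessage("summarize instead")]
print(detect_time_travel(checkpoint, request))  # -> the "summarize instead" HumanMessage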

typescript-sdk/integrations/langgraph/src/agent.ts

Lines changed: 1 addition & 2 deletions
@@ -236,6 +236,7 @@ export class LangGraphAgent extends AbstractAgent {
       values: { ...stateValuesDiff, messages: [...agentStateMessages, ...stateValuesDiff.messages] },
     }
     let stateValues = threadState.values;
+    this.activeRun!.schemaKeys = await this.getSchemaKeys();
 
     if ((agentState.values.messages ?? []).length > messages.filter((m) => m.role !== "system").length) {
       let lastUserMessage: LangGraphMessage | null = null;
@@ -268,8 +269,6 @@ export class LangGraphAgent extends AbstractAgent {
       });
     }
 
-    this.activeRun!.schemaKeys = await this.getSchemaKeys();
-
     const payloadInput = getStreamPayloadInput({
       mode,
       state: stateValues,
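
Both agents build their regenerate path on LangGraph's checkpoint-based time travel. As background only, here is a sketch of that mechanism in plain LangGraph, assuming a recent langgraph release; the echo node, thread id, and questions are made up for the example, and none of the SDK helpers above are used.

from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph


def chat(state: MessagesState) -> dict:
    # Stand-in node; a real agent would call a model here.
    return {"messages": [AIMessage(content=f"echo: {state['messages'][-1].content}")]}


builder = StateGraph(MessagesState)
builder.add_node("chat", chat)
builder.add_edge(START, "chat")
builder.add_edge("chat", END)
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "demo"}}
graph.invoke({"messages": [HumanMessage(content="first question")]}, config)
graph.invoke({"messages": [HumanMessage(content="second question")]}, config)

# State history comes back newest-first; pick the checkpoint taken just before
# the second "chat" step executed.
to_replay = next(s for s in graph.get_state_history(config) if s.next == ("chat",))

# Invoking from that checkpoint's config forks the thread there and re-executes
# the remaining steps: the rewind that the regenerate path builds on.
replayed = graph.invoke(None, to_replay.config)
print([m.content for m in replayed["messages"]])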
