Skip to content

Commit 8d8e7ca

Browse files
committed
feat(langgraph-py): add reasoning infrastructure
1 parent 6dbb4a8 commit 8d8e7ca

File tree

3 files changed

+211
-29
lines changed

3 files changed

+211
-29
lines changed

python-sdk/ag_ui/integrations/langgraph/agent/agent.py

Lines changed: 139 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
import json
3-
from typing import Optional, List, Any, Union, AsyncGenerator
3+
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
44

55
from fastapi.responses import StreamingResponse
66

@@ -10,10 +10,52 @@
1010
from langchain_core.messages import HumanMessage
1111
from langgraph.types import Command
1212

13-
from .types import State, LangGraphPlatformMessage, MessagesInProgressRecord, SchemaKeys, MessageInProgress, RunMetadata, LangGraphEventTypes, CustomEventNames
14-
from .utils import agui_messages_to_langchain, DEFAULT_SCHEMA_KEYS, filter_object_by_schema_keys, get_stream_payload_input, langchain_messages_to_agui
15-
16-
from ag_ui.core import EventType, CustomEvent, MessagesSnapshotEvent, RawEvent, RunAgentInput, RunErrorEvent, RunFinishedEvent, RunStartedEvent, StateDeltaEvent, StateSnapshotEvent, StepFinishedEvent, StepStartedEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent
13+
from .types import (
14+
State,
15+
LangGraphPlatformMessage,
16+
MessagesInProgressRecord,
17+
SchemaKeys,
18+
MessageInProgress,
19+
RunMetadata,
20+
LangGraphEventTypes,
21+
CustomEventNames,
22+
LangGraphReasoning
23+
)
24+
from .utils import (
25+
agui_messages_to_langchain,
26+
DEFAULT_SCHEMA_KEYS,
27+
filter_object_by_schema_keys,
28+
get_stream_payload_input,
29+
langchain_messages_to_agui,
30+
resolve_reasoning_content,
31+
resolve_message_content
32+
)
33+
34+
from ag_ui.core import (
35+
EventType,
36+
CustomEvent,
37+
MessagesSnapshotEvent,
38+
RawEvent,
39+
RunAgentInput,
40+
RunErrorEvent,
41+
RunFinishedEvent,
42+
RunStartedEvent,
43+
StateDeltaEvent,
44+
StateSnapshotEvent,
45+
StepFinishedEvent,
46+
StepStartedEvent,
47+
TextMessageContentEvent,
48+
TextMessageEndEvent,
49+
TextMessageStartEvent,
50+
ToolCallArgsEvent,
51+
ToolCallEndEvent,
52+
ToolCallStartEvent,
53+
ThinkingTextMessageStartEvent,
54+
ThinkingTextMessageContentEvent,
55+
ThinkingTextMessageEndEvent,
56+
ThinkingStartEvent,
57+
ThinkingEndEvent,
58+
)
1759
from ag_ui.encoder import EventEncoder
1860

1961
ProcessedEvents = Union[
@@ -60,6 +102,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
60102
self.active_run = {
61103
"id": input.run_id,
62104
"thread_id": thread_id,
105+
"thinking_process": None,
63106
}
64107

65108
messages = input.messages or []
@@ -390,10 +433,28 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
390433
is_tool_call_args_event = has_current_stream and current_stream.get("tool_call_id") and tool_call_data and tool_call_data.get("args")
391434
is_tool_call_end_event = has_current_stream and current_stream.get("tool_call_id") and not tool_call_data
392435

393-
is_message_start_event = not has_current_stream and not tool_call_data
394-
is_message_content_event = has_current_stream and not tool_call_data
436+
reasoning_data = resolve_reasoning_content(event["data"]["chunk"]) if event["data"]["chunk"] else None
437+
message_content = resolve_message_content(event["data"]["chunk"].content) if event["data"]["chunk"] and event["data"]["chunk"].content else None
438+
is_message_content_event = tool_call_data is None and message_content
395439
is_message_end_event = has_current_stream and not current_stream.get("tool_call_id") and not is_message_content_event
396440

441+
if reasoning_data:
    # handle_thinking_event is a generator: it must be iterated, not merely
    # called, or no thinking events are ever emitted.
    for thinking_event in self.handle_thinking_event(reasoning_data):
        yield thinking_event
    return
444+
445+
if reasoning_data is None and self.active_run.get('thinking_process', None) is not None:
446+
yield self._dispatch_event(
447+
ThinkingTextMessageEndEvent(
448+
type=EventType.THINKING_TEXT_MESSAGE_END,
449+
)
450+
)
451+
yield self._dispatch_event(
452+
ThinkingEndEvent(
453+
type=EventType.THINKING_END,
454+
)
455+
)
456+
self.active_run["thinking_process"] = None
457+
397458
if tool_call_used_to_predict_state:
398459
yield self._dispatch_event(
399460
CustomEvent(
@@ -442,27 +503,35 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
442503

443504
if is_tool_call_args_event and should_emit_tool_calls:
444505
yield self._dispatch_event(
445-
ToolCallArgsEvent(type=EventType.TOOL_CALL_ARGS, tool_call_id=current_stream["tool_call_id"], delta=tool_call_data["args"], raw_event=event)
506+
ToolCallArgsEvent(
507+
type=EventType.TOOL_CALL_ARGS,
508+
tool_call_id=current_stream["tool_call_id"],
509+
delta=tool_call_data["args"],
510+
raw_event=event
511+
)
446512
)
447513
return
448514

449-
if is_message_start_event and should_emit_messages:
450-
resolved = self._dispatch_event(
451-
TextMessageStartEvent(
452-
type=EventType.TEXT_MESSAGE_START,
453-
role="assistant",
454-
message_id=event["data"]["chunk"].id,
455-
raw_event=event,
515+
if is_message_content_event and should_emit_messages:
516+
if bool(current_stream and current_stream.get("id")) == False:
517+
yield self._dispatch_event(
518+
TextMessageStartEvent(
519+
type=EventType.TEXT_MESSAGE_START,
520+
role="assistant",
521+
message_id=event["data"]["chunk"].id,
522+
raw_event=event,
523+
)
456524
)
457-
)
458-
if resolved:
459525
self.set_message_in_progress(
460-
self.active_run["id"], MessageInProgress(id=event["data"]["chunk"].id)
526+
self.active_run["id"],
527+
MessageInProgress(
528+
id=event["data"]["chunk"].id,
529+
tool_call_id=None,
530+
tool_call_name=None
531+
)
461532
)
462-
yield resolved
463-
return
533+
current_stream = self.get_message_in_progress(self.active_run["id"])
464534

465-
if is_message_content_event and should_emit_messages:
466535
yield self._dispatch_event(
467536
TextMessageContentEvent(
468537
type=EventType.TEXT_MESSAGE_CONTENT,
@@ -533,6 +602,55 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
533602
CustomEvent(type=EventType.CUSTOM, name=event["name"], value=event["data"], raw_event=event)
534603
)
535604

605+
def handle_thinking_event(self, reasoning_data: LangGraphReasoning) -> Generator[str, Any, None]:
    """Translate one chunk of model reasoning into AG-UI "thinking" events.

    Yields encoded ThinkingStart / ThinkingTextMessageStart / Content events
    as needed, closing the previous thinking step first when the provider's
    reasoning step index changes. Progress is tracked in
    ``self.active_run["thinking_process"]`` (a dict with "index" and,
    once text has started, "type").

    Args:
        reasoning_data: normalized reasoning payload with "type", "text"
            and "index" keys (see resolve_reasoning_content).
    """
    # Ignore malformed reasoning payloads.
    if not reasoning_data or "type" not in reasoning_data or "text" not in reasoning_data:
        return

    thinking_step_index = reasoning_data.get("index")
    thinking_process = self.active_run.get("thinking_process")

    # A new reasoning step started: close out the previous one.
    # NOTE: compare the stored index against None, not truthiness —
    # a provider step index of 0 is valid and must still trigger a close.
    if (thinking_process is not None and
            thinking_process.get("index") is not None and
            thinking_process["index"] != thinking_step_index):

        if thinking_process.get("type"):
            yield self._dispatch_event(
                ThinkingTextMessageEndEvent(
                    type=EventType.THINKING_TEXT_MESSAGE_END,
                )
            )
        yield self._dispatch_event(
            ThinkingEndEvent(
                type=EventType.THINKING_END,
            )
        )
        self.active_run["thinking_process"] = None

    # First reasoning chunk of a step: open a thinking block.
    if not self.active_run.get("thinking_process"):
        yield self._dispatch_event(
            ThinkingStartEvent(
                type=EventType.THINKING_START,
            )
        )
        self.active_run["thinking_process"] = {
            "index": thinking_step_index
        }

    # Emit a text-message start once per (step, type) transition.
    if self.active_run["thinking_process"].get("type") != reasoning_data["type"]:
        yield self._dispatch_event(
            ThinkingTextMessageStartEvent(
                type=EventType.THINKING_TEXT_MESSAGE_START,
            )
        )
        self.active_run["thinking_process"]["type"] = reasoning_data["type"]

    if self.active_run["thinking_process"].get("type"):
        yield self._dispatch_event(
            ThinkingTextMessageContentEvent(
                type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
                delta=reasoning_data["text"]
            )
        )
653+
536654
async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
537655
if not thread_id:
538656
raise ValueError("Missing thread_id in config")

python-sdk/ag_ui/integrations/langgraph/agent/types.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypedDict, Optional, List, Any, Dict, Union
1+
from typing import TypedDict, Optional, List, Any, Dict, Union, Literal
22
from typing_extensions import NotRequired
33
from enum import Enum
44

@@ -28,6 +28,11 @@ class CustomEventNames(str, Enum):
2828
"config": NotRequired[Optional[List[str]]]
2929
})
3030

31+
# Tracks the in-flight reasoning ("thinking") step of the active run.
# "index" is the provider-reported reasoning step index; "type" is set once
# the first text delta of the step has been emitted (only 'text' is supported).
ThinkingProcess = TypedDict("ThinkingProcess", {
    "index": int,
    "type": NotRequired[Optional[Literal['text']]],
})
35+
3136
MessageInProgress = TypedDict("MessageInProgress", {
3237
"id": str,
3338
"tool_call_id": NotRequired[Optional[str]],
@@ -41,7 +46,8 @@ class CustomEventNames(str, Enum):
4146
"prev_node_name": NotRequired[Optional[str]],
4247
"exiting_node": NotRequired[bool],
4348
"manually_emitted_state": NotRequired[Optional[State]],
44-
"thread_id": NotRequired[Optional[str]]
49+
"thread_id": NotRequired[Optional[str]],
"thinking_process": NotRequired[Optional[ThinkingProcess]]
4551
})
4652

4753
MessagesInProgressRecord = Dict[str, Optional[MessageInProgress]]
@@ -77,3 +83,9 @@ class LangGraphPlatformActionExecutionMessage(BaseLangGraphPlatformMessage):
7783
"state_key": str,
7884
"tool_argument": str
7985
})
86+
87+
# Normalized reasoning payload extracted from a streamed LLM chunk
# (produced by utils.resolve_reasoning_content): the reasoning kind,
# its text delta, and the provider-reported step index.
LangGraphReasoning = TypedDict("LangGraphReasoning", {
    "type": str,
    "text": str,
    "index": int
})

python-sdk/ag_ui/integrations/langgraph/agent/utils.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ToolCall as AGUIToolCall,
1212
FunctionCall as AGUIFunctionCall,
1313
)
14-
from .types import State, SchemaKeys
14+
from .types import State, SchemaKeys, LangGraphReasoning
1515

1616
DEFAULT_SCHEMA_KEYS = ["tools"]
1717

@@ -43,7 +43,7 @@ def langchain_messages_to_agui(messages: List[BaseMessage]) -> List[AGUIMessage]
4343
agui_messages.append(AGUIUserMessage(
4444
id=str(message.id),
4545
role="user",
46-
content=stringify_if_needed(message.content),
46+
content=stringify_if_needed(resolve_message_content(message.content)),
4747
name=message.name,
4848
))
4949
elif isinstance(message, AIMessage):
@@ -63,22 +63,22 @@ def langchain_messages_to_agui(messages: List[BaseMessage]) -> List[AGUIMessage]
6363
agui_messages.append(AGUIAssistantMessage(
6464
id=str(message.id),
6565
role="assistant",
66-
content=stringify_if_needed(message.content),
66+
content=stringify_if_needed(resolve_message_content(message.content)),
6767
tool_calls=tool_calls,
6868
name=message.name,
6969
))
7070
elif isinstance(message, SystemMessage):
7171
agui_messages.append(AGUISystemMessage(
7272
id=str(message.id),
7373
role="system",
74-
content=stringify_if_needed(message.content),
74+
content=stringify_if_needed(resolve_message_content(message.content)),
7575
name=message.name,
7676
))
7777
elif isinstance(message, ToolMessage):
7878
agui_messages.append(AGUIToolMessage(
7979
id=str(message.id),
8080
role="tool",
81-
content=stringify_if_needed(message.content),
81+
content=stringify_if_needed(resolve_message_content(message.content)),
8282
tool_call_id=message.tool_call_id,
8383
))
8484
else:
@@ -125,4 +125,56 @@ def agui_messages_to_langchain(messages: List[AGUIMessage]) -> List[BaseMessage]
125125
))
126126
else:
127127
raise ValueError(f"Unsupported message role: {role}")
128-
return langchain_messages
128+
return langchain_messages
129+
130+
def resolve_reasoning_content(chunk: Any) -> "LangGraphReasoning | None":
    """Extract reasoning/"thinking" content from a streamed LLM chunk.

    Handles three provider shapes:
      * Anthropic: ``chunk.content`` is a list whose first block carries
        a "thinking" key;
      * OpenAI: a reasoning summary under ``chunk.additional_kwargs``;
      * fallback: ``chunk.content`` is a JSON-encoded string.

    Returns:
        A LangGraphReasoning dict, or None when the chunk carries no
        reasoning content.
    """
    content = chunk.content
    if not content:
        return None

    # Anthropic reasoning response.
    if isinstance(content, list) and content and content[0]:
        block = content[0]
        # Guard: content lists may also hold plain strings or non-thinking
        # blocks — those carry no reasoning (a bare .get() would raise).
        if not isinstance(block, dict) or not block.get("thinking"):
            return None
        return LangGraphReasoning(
            text=block["thinking"],
            type="text",
            index=block.get("index", 0)
        )

    # OpenAI reasoning response.
    if hasattr(chunk, "additional_kwargs"):
        reasoning = chunk.additional_kwargs.get("reasoning", {})
        summary = reasoning.get("summary", [])
        if summary:
            data = summary[0]
            if not isinstance(data, dict) or not data.get("text"):
                return None
            return LangGraphReasoning(
                type="text",
                text=data["text"],
                index=data.get("index", 0)
            )

    # Fallback: some providers stream reasoning as a JSON-encoded string.
    # json.loads would raise TypeError on non-string content, so guard first.
    if not isinstance(content, str):
        return None
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        return None
    # json.loads may yield a non-dict (a bare string, number, or list).
    if not isinstance(parsed, dict):
        return None
    return LangGraphReasoning(
        type=parsed.get("type", "text"),
        text=parsed.get("text", ""),
        index=parsed.get("index", 0)
    )
168+
169+
def resolve_message_content(content: Any) -> str | None:
170+
if not content:
171+
return None
172+
173+
if isinstance(content, str):
174+
return content
175+
176+
if isinstance(content, list) and content:
177+
content_text = next((c.get("text") for c in content if isinstance(c, dict) and c.get("type") == "text"), None)
178+
return content_text
179+
180+
return None

0 commit comments

Comments
 (0)