|
1 | 1 | import uuid
|
2 | 2 | import json
|
3 |
| -from typing import Optional, List, Any, Union, AsyncGenerator |
| 3 | +from typing import Optional, List, Any, Union, AsyncGenerator, Generator |
4 | 4 |
|
5 | 5 | from fastapi.responses import StreamingResponse
|
6 | 6 |
|
|
10 | 10 | from langchain_core.messages import HumanMessage
|
11 | 11 | from langgraph.types import Command
|
12 | 12 |
|
13 |
| -from .types import State, LangGraphPlatformMessage, MessagesInProgressRecord, SchemaKeys, MessageInProgress, RunMetadata, LangGraphEventTypes, CustomEventNames |
14 |
| -from .utils import agui_messages_to_langchain, DEFAULT_SCHEMA_KEYS, filter_object_by_schema_keys, get_stream_payload_input, langchain_messages_to_agui |
15 |
| - |
16 |
| -from ag_ui.core import EventType, CustomEvent, MessagesSnapshotEvent, RawEvent, RunAgentInput, RunErrorEvent, RunFinishedEvent, RunStartedEvent, StateDeltaEvent, StateSnapshotEvent, StepFinishedEvent, StepStartedEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent |
| 13 | +from .types import ( |
| 14 | + State, |
| 15 | + LangGraphPlatformMessage, |
| 16 | + MessagesInProgressRecord, |
| 17 | + SchemaKeys, |
| 18 | + MessageInProgress, |
| 19 | + RunMetadata, |
| 20 | + LangGraphEventTypes, |
| 21 | + CustomEventNames, |
| 22 | + LangGraphReasoning |
| 23 | +) |
| 24 | +from .utils import ( |
| 25 | + agui_messages_to_langchain, |
| 26 | + DEFAULT_SCHEMA_KEYS, |
| 27 | + filter_object_by_schema_keys, |
| 28 | + get_stream_payload_input, |
| 29 | + langchain_messages_to_agui, |
| 30 | + resolve_reasoning_content, |
| 31 | + resolve_message_content |
| 32 | +) |
| 33 | + |
| 34 | +from ag_ui.core import ( |
| 35 | + EventType, |
| 36 | + CustomEvent, |
| 37 | + MessagesSnapshotEvent, |
| 38 | + RawEvent, |
| 39 | + RunAgentInput, |
| 40 | + RunErrorEvent, |
| 41 | + RunFinishedEvent, |
| 42 | + RunStartedEvent, |
| 43 | + StateDeltaEvent, |
| 44 | + StateSnapshotEvent, |
| 45 | + StepFinishedEvent, |
| 46 | + StepStartedEvent, |
| 47 | + TextMessageContentEvent, |
| 48 | + TextMessageEndEvent, |
| 49 | + TextMessageStartEvent, |
| 50 | + ToolCallArgsEvent, |
| 51 | + ToolCallEndEvent, |
| 52 | + ToolCallStartEvent, |
| 53 | + ThinkingTextMessageStartEvent, |
| 54 | + ThinkingTextMessageContentEvent, |
| 55 | + ThinkingTextMessageEndEvent, |
| 56 | + ThinkingStartEvent, |
| 57 | + ThinkingEndEvent, |
| 58 | +) |
17 | 59 | from ag_ui.encoder import EventEncoder
|
18 | 60 |
|
19 | 61 | ProcessedEvents = Union[
|
@@ -60,6 +102,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
|
60 | 102 | self.active_run = {
|
61 | 103 | "id": input.run_id,
|
62 | 104 | "thread_id": thread_id,
|
| 105 | + "thinking_process": None, |
63 | 106 | }
|
64 | 107 |
|
65 | 108 | messages = input.messages or []
|
@@ -390,10 +433,28 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
|
390 | 433 | is_tool_call_args_event = has_current_stream and current_stream.get("tool_call_id") and tool_call_data and tool_call_data.get("args")
|
391 | 434 | is_tool_call_end_event = has_current_stream and current_stream.get("tool_call_id") and not tool_call_data
|
392 | 435 |
|
393 |
| - is_message_start_event = not has_current_stream and not tool_call_data |
394 |
| - is_message_content_event = has_current_stream and not tool_call_data |
| 436 | + reasoning_data = resolve_reasoning_content(event["data"]["chunk"]) if event["data"]["chunk"] else None |
| 437 | + message_content = resolve_message_content(event["data"]["chunk"].content) if event["data"]["chunk"] and event["data"]["chunk"].content else None |
| 438 | + is_message_content_event = tool_call_data is None and message_content |
395 | 439 | is_message_end_event = has_current_stream and not current_stream.get("tool_call_id") and not is_message_content_event
|
396 | 440 |
|
| 441 | + if reasoning_data: |
| 442 | + self.handle_thinking_event(reasoning_data) |
| 443 | + return |
| 444 | + |
| 445 | + if reasoning_data is None and self.active_run.get('thinking_process', None) is not None: |
| 446 | + yield self._dispatch_event( |
| 447 | + ThinkingTextMessageEndEvent( |
| 448 | + type=EventType.THINKING_TEXT_MESSAGE_END, |
| 449 | + ) |
| 450 | + ) |
| 451 | + yield self._dispatch_event( |
| 452 | + ThinkingEndEvent( |
| 453 | + type=EventType.THINKING_END, |
| 454 | + ) |
| 455 | + ) |
| 456 | + self.active_run["thinking_process"] = None |
| 457 | + |
397 | 458 | if tool_call_used_to_predict_state:
|
398 | 459 | yield self._dispatch_event(
|
399 | 460 | CustomEvent(
|
@@ -442,27 +503,35 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
|
442 | 503 |
|
443 | 504 | if is_tool_call_args_event and should_emit_tool_calls:
|
444 | 505 | yield self._dispatch_event(
|
445 |
| - ToolCallArgsEvent(type=EventType.TOOL_CALL_ARGS, tool_call_id=current_stream["tool_call_id"], delta=tool_call_data["args"], raw_event=event) |
| 506 | + ToolCallArgsEvent( |
| 507 | + type=EventType.TOOL_CALL_ARGS, |
| 508 | + tool_call_id=current_stream["tool_call_id"], |
| 509 | + delta=tool_call_data["args"], |
| 510 | + raw_event=event |
| 511 | + ) |
446 | 512 | )
|
447 | 513 | return
|
448 | 514 |
|
449 |
| - if is_message_start_event and should_emit_messages: |
450 |
| - resolved = self._dispatch_event( |
451 |
| - TextMessageStartEvent( |
452 |
| - type=EventType.TEXT_MESSAGE_START, |
453 |
| - role="assistant", |
454 |
| - message_id=event["data"]["chunk"].id, |
455 |
| - raw_event=event, |
| 515 | + if is_message_content_event and should_emit_messages: |
| 516 | + if bool(current_stream and current_stream.get("id")) == False: |
| 517 | + yield self._dispatch_event( |
| 518 | + TextMessageStartEvent( |
| 519 | + type=EventType.TEXT_MESSAGE_START, |
| 520 | + role="assistant", |
| 521 | + message_id=event["data"]["chunk"].id, |
| 522 | + raw_event=event, |
| 523 | + ) |
456 | 524 | )
|
457 |
| - ) |
458 |
| - if resolved: |
459 | 525 | self.set_message_in_progress(
|
460 |
| - self.active_run["id"], MessageInProgress(id=event["data"]["chunk"].id) |
| 526 | + self.active_run["id"], |
| 527 | + MessageInProgress( |
| 528 | + id=event["data"]["chunk"].id, |
| 529 | + tool_call_id=None, |
| 530 | + tool_call_name=None |
| 531 | + ) |
461 | 532 | )
|
462 |
| - yield resolved |
463 |
| - return |
| 533 | + current_stream = self.get_message_in_progress(self.active_run["id"]) |
464 | 534 |
|
465 |
| - if is_message_content_event and should_emit_messages: |
466 | 535 | yield self._dispatch_event(
|
467 | 536 | TextMessageContentEvent(
|
468 | 537 | type=EventType.TEXT_MESSAGE_CONTENT,
|
@@ -533,6 +602,55 @@ async def _handle_single_event(self, event: Any, state: State) -> AsyncGenerator
|
533 | 602 | CustomEvent(type=EventType.CUSTOM, name=event["name"], value=event["data"], raw_event=event)
|
534 | 603 | )
|
535 | 604 |
|
| 605 | + def handle_thinking_event(self, reasoning_data: LangGraphReasoning) -> Generator[str, Any, str | None]: |
| 606 | + if not reasoning_data or "type" not in reasoning_data or "text" not in reasoning_data: |
| 607 | + return "" |
| 608 | + |
| 609 | + thinking_step_index = reasoning_data.get("index") |
| 610 | + |
| 611 | + if (self.active_run.get("thinking_process") and |
| 612 | + self.active_run["thinking_process"].get("index") and |
| 613 | + self.active_run["thinking_process"]["index"] != thinking_step_index): |
| 614 | + |
| 615 | + if self.active_run["thinking_process"].get("type"): |
| 616 | + yield self._dispatch_event( |
| 617 | + ThinkingTextMessageEndEvent( |
| 618 | + type=EventType.THINKING_TEXT_MESSAGE_END, |
| 619 | + ) |
| 620 | + ) |
| 621 | + yield self._dispatch_event( |
| 622 | + ThinkingEndEvent( |
| 623 | + type=EventType.THINKING_END, |
| 624 | + ) |
| 625 | + ) |
| 626 | + self.active_run["thinking_process"] = None |
| 627 | + |
| 628 | + if not self.active_run.get("thinking_process"): |
| 629 | + yield self._dispatch_event( |
| 630 | + ThinkingStartEvent( |
| 631 | + type=EventType.THINKING_START, |
| 632 | + ) |
| 633 | + ) |
| 634 | + self.active_run["thinking_process"] = { |
| 635 | + "index": thinking_step_index |
| 636 | + } |
| 637 | + |
| 638 | + if self.active_run["thinking_process"].get("type") != reasoning_data["type"]: |
| 639 | + yield self._dispatch_event( |
| 640 | + ThinkingTextMessageStartEvent( |
| 641 | + type=EventType.THINKING_TEXT_MESSAGE_START, |
| 642 | + ) |
| 643 | + ) |
| 644 | + self.active_run["thinking_process"]["type"] = reasoning_data["type"] |
| 645 | + |
| 646 | + if self.active_run["thinking_process"].get("type"): |
| 647 | + yield self._dispatch_event( |
| 648 | + ThinkingTextMessageContentEvent( |
| 649 | + type=EventType.THINKING_TEXT_MESSAGE_CONTENT, |
| 650 | + delta=reasoning_data["text"] |
| 651 | + ) |
| 652 | + ) |
| 653 | + |
536 | 654 | async def get_checkpoint_before_message(self, message_id: str, thread_id: str):
|
537 | 655 | if not thread_id:
|
538 | 656 | raise ValueError("Missing thread_id in config")
|
|
0 commit comments