Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import uuid
import json
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
from dataclasses import is_dataclass, asdict
from datetime import date, datetime

from langgraph.graph.state import CompiledStateGraph
from langchain.schema import BaseMessage, SystemMessage
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
self.messages_in_process: MessagesInProgressRecord = {}
self.active_run: Optional[RunMetadata] = None
self.constant_schema_keys = ['messages', 'tools']
self.active_step = None

def _dispatch_event(self, event: ProcessedEvents) -> str:
return event # Fallback if no encoder
Expand Down Expand Up @@ -135,9 +138,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st

# In case of resume (interrupt), re-start resumed step
if resume_input and self.active_run.get("node_name"):
yield self._dispatch_event(
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run.get("node_name"))
)
for ev in self.start_step(self.active_run.get("node_name")):
yield ev

state = prepared_stream_response["state"]
stream = prepared_stream_response["stream"]
Expand All @@ -151,6 +153,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st

should_exit = False
current_graph_state = state

async for event in stream:
if event["event"] == "error":
yield self._dispatch_event(
Expand All @@ -175,16 +178,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
)

if current_node_name and current_node_name != self.active_run.get("node_name"):
if self.active_run["node_name"] and self.active_run["node_name"] != node_name_input:
yield self._dispatch_event(
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
)
self.active_run["node_name"] = None

yield self._dispatch_event(
StepStartedEvent(type=EventType.STEP_STARTED, step_name=current_node_name)
)
self.active_run["node_name"] = current_node_name
for ev in self.start_step(current_node_name):
yield ev

updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
has_state_diff = updated_state != state
Expand Down Expand Up @@ -224,19 +219,14 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
CustomEvent(
type=EventType.CUSTOM,
name=LangGraphEventTypes.OnInterrupt.value,
value=json.dumps(interrupt.value) if not isinstance(interrupt.value, str) else interrupt.value,
value=json.dumps(interrupt.value, default=make_json_safe) if not isinstance(interrupt.value, str) else interrupt.value,
raw_event=interrupt,
)
)

if self.active_run.get("node_name") != node_name:
yield self._dispatch_event(
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
)
self.active_run["node_name"] = node_name
yield self._dispatch_event(
StepStartedEvent(type=EventType.STEP_STARTED, step_name=self.active_run["node_name"])
)
for ev in self.start_step(node_name):
yield ev

state_values = state.values if state.values else state
yield self._dispatch_event(
Expand All @@ -250,10 +240,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
)
)

yield self._dispatch_event(
StepFinishedEvent(type=EventType.STEP_FINISHED, step_name=self.active_run["node_name"])
)
self.active_run["node_name"] = None
yield self.end_step()

yield self._dispatch_event(
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
Expand Down Expand Up @@ -700,3 +687,43 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):

raise ValueError("Message ID not found in history")

def start_step(self, step_name: str):
if self.active_step:
yield self.end_step()

yield self._dispatch_event(
StepStartedEvent(
type=EventType.STEP_STARTED,
step_name=step_name
)
)
self.active_run["node_name"] = step_name
self.active_step = step_name

def end_step(self):
if self.active_step is None:
raise ValueError("No active step to end")

dispatch = self._dispatch_event(
StepFinishedEvent(
type=EventType.STEP_FINISHED,
step_name=self.active_run["node_name"]
)
)

self.active_run["node_name"] = None
self.active_step = None
return dispatch

def make_json_safe(o):
if is_dataclass(o): # dataclasses like Flight(...)
return asdict(o)
if hasattr(o, "model_dump"): # pydantic v2
return o.model_dump()
if hasattr(o, "dict"): # pydantic v1
return o.dict()
if hasattr(o, "__dict__"): # plain objects
return vars(o)
if isinstance(o, (datetime, date)):
return o.isoformat()
return str(o) # last resort
13 changes: 10 additions & 3 deletions typescript-sdk/integrations/langgraph/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ export class LangGraphAgent extends AbstractAgent {
// @ts-expect-error no need to initialize subscriber right now
subscriber: Subscriber<ProcessedEvents>;
constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS;
activeStep?: string;

constructor(config: LangGraphAgentConfig) {
super(config);
Expand Down Expand Up @@ -383,7 +384,9 @@ export class LangGraphAgent extends AbstractAgent {
break;
}

if (streamResponseChunk.event === "updates") continue;
if (streamResponseChunk.event === "updates") {
continue;
}

if (streamResponseChunk.event === "values") {
latestStateValues = chunk.data;
Expand Down Expand Up @@ -467,7 +470,6 @@ export class LangGraphAgent extends AbstractAgent {
newNodeName = isEndNode ? '__end__' : (state.next[0] ?? Object.keys(writes)[0]);
}


interrupts.forEach((interrupt) => {
this.dispatchEvent({
type: EventType.CUSTOM,
Expand Down Expand Up @@ -944,22 +946,27 @@ export class LangGraphAgent extends AbstractAgent {
}

startStep(nodeName: string) {
if (this.activeStep) {
this.endStep()
}
this.dispatchEvent({
type: EventType.STEP_STARTED,
stepName: nodeName,
});
this.activeRun!.nodeName = nodeName;
this.activeStep = nodeName;
}

endStep() {
if (!this.activeRun!.nodeName) {
if (!this.activeStep) {
throw new Error("No active step to end");
}
this.dispatchEvent({
type: EventType.STEP_FINISHED,
stepName: this.activeRun!.nodeName!,
});
this.activeRun!.nodeName = undefined;
this.activeStep = undefined;
}

async getCheckpointByMessage(
Expand Down
Loading