@@ -21,7 +21,7 @@ langchain-experimental = ">=0.0.11"
langchain-google-genai = ">=2.1.9"
langchain-openai = ">=0.0.1"
langgraph = "^0.6.1"
ag-ui-langgraph = { version = "0.0.12a1", extras = ["fastapi"] }
ag-ui-langgraph = { version = "0.0.12a3", extras = ["fastapi"] }
python-dotenv = "^1.0.0"
fastapi = "^0.115.12"

@@ -88,7 +88,6 @@ def __init__(self, *, name: str, graph: CompiledStateGraph, description: Optiona
self.messages_in_process: MessagesInProgressRecord = {}
self.active_run: Optional[RunMetadata] = None
self.constant_schema_keys = ['messages', 'tools']
self.active_step = None

def _dispatch_event(self, event: ProcessedEvents) -> str:
if event.type == EventType.RAW:
@@ -121,9 +120,6 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
node_name_input = forwarded_props.get('node_name', None) if forwarded_props else None

self.active_run["manually_emitted_state"] = None
self.active_run["node_name"] = node_name_input
if self.active_run["node_name"] == "__end__":
self.active_run["node_name"] = None

config = ensure_config(self.config.copy() if self.config else {})
config["configurable"] = {**(config.get('configurable', {})), "thread_id": thread_id}
@@ -141,10 +137,11 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
yield self._dispatch_event(
RunStartedEvent(type=EventType.RUN_STARTED, thread_id=thread_id, run_id=self.active_run["id"])
)
self.handle_node_change(node_name_input)

# In case of resume (interrupt), re-start resumed step
if resume_input and self.active_run.get("node_name"):
for ev in self.start_step(self.active_run.get("node_name")):
for ev in self.handle_node_change(self.active_run.get("node_name")):
yield ev

state = prepared_stream_response["state"]
@@ -189,7 +186,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
)

if current_node_name and current_node_name != self.active_run.get("node_name"):
for ev in self.start_step(current_node_name):
for ev in self.handle_node_change(current_node_name):
yield ev

updated_state = self.active_run.get("manually_emitted_state") or current_graph_state
@@ -236,7 +233,7 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
)

if self.active_run.get("node_name") != node_name:
for ev in self.start_step(node_name):
for ev in self.handle_node_change(node_name):
yield ev

state_values = state.values if state.values else state
@@ -251,7 +248,8 @@ async def _handle_stream_events(self, input: RunAgentInput) -> AsyncGenerator[st
)
)

yield self.end_step()
for ev in self.handle_node_change(None):
yield ev

yield self._dispatch_event(
RunFinishedEvent(type=EventType.RUN_FINISHED, thread_id=thread_id, run_id=self.active_run["id"])
@@ -730,34 +728,47 @@ async def get_checkpoint_before_message(self, message_id: str, thread_id: str):

raise ValueError("Message ID not found in history")

def start_step(self, step_name: str):
if self.active_step:
yield self.end_step()
def handle_node_change(self, node_name: Optional[str]):
"""
Centralized method to handle node name changes and step transitions.
Automatically manages step start/end events based on node name changes.
"""
if node_name == "__end__":
node_name = None

if node_name != self.active_run.get("node_name"):
# End current step if we have one
if self.active_run.get("node_name"):
yield self.end_step()

# Start new step if we have a node name
if node_name:
for event in self.start_step(node_name):
yield event

self.active_run["node_name"] = node_name

def start_step(self, step_name: str):
"""Simple step start event dispatcher - node_name management handled by handle_node_change"""
yield self._dispatch_event(
StepStartedEvent(
type=EventType.STEP_STARTED,
step_name=step_name
)
)
self.active_run["node_name"] = step_name
self.active_step = step_name

def end_step(self):
if self.active_step is None:
"""Simple step end event dispatcher - node_name management handled by handle_node_change"""
if not self.active_run.get("node_name"):
raise ValueError("No active step to end")

dispatch = self._dispatch_event(
return self._dispatch_event(
StepFinishedEvent(
type=EventType.STEP_FINISHED,
step_name=self.active_run["node_name"] or self.active_step
step_name=self.active_run["node_name"]
)
)

self.active_run["node_name"] = None
self.active_step = None
return dispatch

# Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility
def get_stream_kwargs(
self,
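The Python refactor above removes the separate active_step field and routes all step bookkeeping through handle_node_change, which finishes the previous step and starts a new one whenever the active node changes (the TypeScript agent further down applies the same pattern). A minimal, self-contained sketch of that transition logic, using a hypothetical StepTracker class and plain event strings in place of the real agent and AG-UI event types:

# Sketch of the centralized step-transition pattern introduced in this PR.
# StepTracker, the event strings, and the demo loop are illustrative only;
# they are not the ag-ui-langgraph API.
from typing import Iterator, Optional

class StepTracker:
    def __init__(self) -> None:
        self.node_name: Optional[str] = None

    def handle_node_change(self, node_name: Optional[str]) -> Iterator[str]:
        # "__end__" is treated as "no active node", as in the diff above.
        if node_name == "__end__":
            node_name = None
        if node_name != self.node_name:
            if self.node_name:          # finish the step we were in
                yield f"STEP_FINISHED:{self.node_name}"
            if node_name:               # start the step we are entering
                yield f"STEP_STARTED:{node_name}"
        self.node_name = node_name

if __name__ == "__main__":
    tracker = StepTracker()
    for node in ["plan", "plan", "act", "__end__"]:
        for event in tracker.handle_node_change(node):
            print(event)
    # Prints: STEP_STARTED:plan, STEP_FINISHED:plan, STEP_STARTED:act, STEP_FINISHED:act
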
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ag-ui-langgraph"
version = "0.0.12-alpha.1"
version = "0.0.12-alpha.3"
description = "Implementation of the AG-UI protocol for LangGraph."
authors = ["Ran Shem Tov <[email protected]>"]
readme = "README.md"
59 changes: 26 additions & 33 deletions typescript-sdk/integrations/langgraph/src/agent.ts
@@ -125,7 +125,6 @@ export class LangGraphAgent extends AbstractAgent {
// @ts-expect-error no need to initialize subscriber right now
subscriber: Subscriber<ProcessedEvents>;
constantSchemaKeys: string[] = DEFAULT_SCHEMA_KEYS;
activeStep?: string;
config: LangGraphAgentConfig;

constructor(config: LangGraphAgentConfig) {
@@ -235,11 +234,6 @@ export class LangGraphAgent extends AbstractAgent {
this.activeRun!.manuallyEmittedState = null;

const nodeNameInput = forwardedProps?.nodeName;
this.activeRun!.nodeName = nodeNameInput;
if (this.activeRun!.nodeName === "__end__") {
this.activeRun!.nodeName = undefined;
}

const threadId = inputThreadId ?? randomUUID();

if (!this.assistant) {
@@ -347,6 +341,7 @@ export class LangGraphAgent extends AbstractAgent {
threadId,
runId: input.runId,
});
this.handleNodeChange(nodeNameInput)

interrupts.forEach((interrupt) => {
this.dispatchEvent({
@@ -400,11 +395,7 @@ export class LangGraphAgent extends AbstractAgent {
threadId,
runId: this.activeRun!.id,
});

// In case of resume (interrupt), re-start resumed step
if (forwardedProps?.command?.resume && this.activeRun!.nodeName) {
this.startStep(this.activeRun!.nodeName);
}
this.handleNodeChange(nodeNameInput)

for await (let streamResponseChunk of streamResponse) {
const subgraphsStreamEnabled = input.forwardedProps?.streamSubgraphs;
@@ -460,11 +451,7 @@ export class LangGraphAgent extends AbstractAgent {
this.activeRun!.id = metadata.run_id;

if (currentNodeName && currentNodeName !== this.activeRun!.nodeName) {
if (this.activeRun!.nodeName && this.activeRun!.nodeName !== nodeNameInput) {
this.endStep();
}

this.startStep(currentNodeName);
this.handleNodeChange(currentNodeName)
}

shouldExit =
@@ -482,7 +469,7 @@ export class LangGraphAgent extends AbstractAgent {
// we only want to update the node name under certain conditions
// since we don't need any internal node names to be sent to the frontend
if (this.activeRun!.graphInfo?.["nodes"].some((node) => node.id === currentNodeName)) {
this.activeRun!.nodeName = currentNodeName;
this.handleNodeChange(currentNodeName)
}

updatedState.values = this.activeRun!.manuallyEmittedState ?? latestStateValues;
@@ -523,6 +510,7 @@ export class LangGraphAgent extends AbstractAgent {
const isEndNode = state.next.length === 0;
const writes = state.metadata?.writes ?? {};

// Initialize a new node name to use in the next if block
let newNodeName = this.activeRun!.nodeName!;

if (!interrupts?.length) {
@@ -539,12 +527,10 @@ export class LangGraphAgent extends AbstractAgent {
});
});

if (this.activeRun!.nodeName != newNodeName) {
this.endStep();
this.startStep(newNodeName);
}
this.handleNodeChange(newNodeName);
// Immediately turn off new step
this.handleNodeChange(undefined);

this.endStep();
this.dispatchEvent({
type: EventType.STATE_SNAPSHOT,
snapshot: this.getStateSnapshot(state),
@@ -1017,28 +1003,35 @@ export class LangGraphAgent extends AbstractAgent {
};
}

startStep(nodeName: string) {
if (this.activeStep) {
this.endStep();
handleNodeChange(nodeName: string | undefined) {
if (nodeName === "__end__") {
nodeName = undefined;
}
if (nodeName !== this.activeRun?.nodeName) {
// End current step
if (this.activeRun?.nodeName) {
this.endStep();
}
// If we actually got a node name, start a new step
if (nodeName) {
this.startStep(nodeName);
}
}
this.activeRun!.nodeName = nodeName;
}

startStep(nodeName: string) {
this.dispatchEvent({
type: EventType.STEP_STARTED,
stepName: nodeName,
});
this.activeRun!.nodeName = nodeName;
this.activeStep = nodeName;
}

endStep() {
if (!this.activeStep) {
throw new Error("No active step to end");
}
this.dispatchEvent({
type: EventType.STEP_FINISHED,
stepName: this.activeRun!.nodeName! ?? this.activeStep,
stepName: this.activeRun!.nodeName!,
});
this.activeRun!.nodeName = undefined;
this.activeStep = undefined;
}

async getCheckpointByMessage(