Skip to content

Commit 2a87e8e

Browse files
committed
feat: provide backwards compatibility for the new runtime context property
1 parent 59a09a5 commit 2a87e8e

File tree

1 file changed

+54
-10
lines changed
  • typescript-sdk/integrations/langgraph/python/ag_ui_langgraph

1 file changed

+54
-10
lines changed

typescript-sdk/integrations/langgraph/python/ag_ui_langgraph/agent.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import uuid
22
import json
3-
from typing import Optional, List, Any, Union, AsyncGenerator, Generator
3+
from typing import Optional, List, Any, Union, AsyncGenerator, Generator, Literal, Dict
4+
import inspect
45

56
from langgraph.graph.state import CompiledStateGraph
67
from langchain.schema import BaseMessage, SystemMessage
@@ -335,13 +336,15 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
335336

336337
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
337338

338-
stream = self.graph.astream_events(
339-
stream_input,
339+
kwargs = self.get_stream_kwargs(
340+
input=stream_input,
340341
config=config,
341-
subgraps=bool(subgraphs_stream_enabled),
342-
version="v2"
342+
subgraphs=bool(subgraphs_stream_enabled),
343+
version="v2",
343344
)
344345

346+
stream = self.graph.astream_events(**kwargs)
347+
345348
return {
346349
"stream": stream,
347350
"state": state,
@@ -369,12 +372,14 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
369372

370373
stream_input = self.langgraph_default_merge_state(time_travel_checkpoint.values, [message_checkpoint], input)
371374
subgraphs_stream_enabled = input.forwarded_props.get('stream_subgraphs') if input.forwarded_props else False
372-
stream = self.graph.astream_events(
373-
stream_input,
374-
fork,
375-
subgraps=bool(subgraphs_stream_enabled),
376-
version="v2"
375+
376+
kwargs = self.get_stream_kwargs(
377+
input=stream_input,
378+
fork=fork,
379+
subgraphs=bool(subgraphs_stream_enabled),
380+
version="v2",
377381
)
382+
stream = self.graph.astream_events(**kwargs)
378383

379384
return {
380385
"stream": stream,
@@ -397,21 +402,25 @@ def get_schema_keys(self, config) -> SchemaKeys:
397402
input_schema = self.graph.get_input_jsonschema(config)
398403
output_schema = self.graph.get_output_jsonschema(config)
399404
config_schema = self.graph.config_schema().schema()
405+
context_schema = self.graph.context_schema().schema()
400406

401407
input_schema_keys = list(input_schema["properties"].keys()) if "properties" in input_schema else []
402408
output_schema_keys = list(output_schema["properties"].keys()) if "properties" in output_schema else []
403409
config_schema_keys = list(config_schema["properties"].keys()) if "properties" in config_schema else []
410+
context_schema_keys = list(context_schema["properties"].keys()) if "properties" in context_schema else []
404411

405412
return {
406413
"input": [*input_schema_keys, *self.constant_schema_keys],
407414
"output": [*output_schema_keys, *self.constant_schema_keys],
408415
"config": config_schema_keys,
416+
"context": context_schema_keys,
409417
}
410418
except Exception:
411419
return {
412420
"input": self.constant_schema_keys,
413421
"output": self.constant_schema_keys,
414422
"config": [],
423+
"context": {},
415424
}
416425

417426
def langgraph_default_merge_state(self, state: State, messages: List[BaseMessage], input: RunAgentInput) -> State:
@@ -744,3 +753,38 @@ def end_step(self):
744753
self.active_run["node_name"] = None
745754
self.active_step = None
746755
return dispatch
756+
757+
# Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility
758+
def get_stream_kwargs(
759+
self,
760+
input: Any,
761+
subgraphs: bool = False,
762+
version: Literal["v1", "v2"] = "v2",
763+
config: Optional[RunnableConfig] = None,
764+
context: Optional[Dict[str, Any]] = None,
765+
fork: Optional[Any] = None,
766+
):
767+
kwargs = dict(
768+
input=input,
769+
subgraphs=subgraphs,
770+
version=version,
771+
)
772+
773+
# Only add context if supported
774+
sig = inspect.signature(self.graph.astream_events)
775+
if 'context' in sig.parameters:
776+
base_context = {}
777+
if isinstance(config, dict) and 'configurable' in config and isinstance(config['configurable'], dict):
778+
base_context.update(config['configurable'])
779+
if context: # context might be None or {}
780+
base_context.update(context)
781+
if base_context: # only add if there's something to pass
782+
kwargs['context'] = base_context
783+
784+
if config:
785+
kwargs['config'] = config
786+
787+
if fork:
788+
kwargs.update(fork)
789+
790+
return kwargs

0 commit comments

Comments
 (0)