1
1
import uuid
2
2
import json
3
- from typing import Optional , List , Any , Union , AsyncGenerator , Generator
3
+ from typing import Optional , List , Any , Union , AsyncGenerator , Generator , Literal , Dict
4
+ import inspect
4
5
5
6
from langgraph .graph .state import CompiledStateGraph
6
7
from langchain .schema import BaseMessage , SystemMessage
@@ -335,13 +336,15 @@ async def prepare_stream(self, input: RunAgentInput, agent_state: State, config:
335
336
336
337
subgraphs_stream_enabled = input .forwarded_props .get ('stream_subgraphs' ) if input .forwarded_props else False
337
338
338
- stream = self .graph . astream_events (
339
- stream_input ,
339
+ kwargs = self .get_stream_kwargs (
340
+ input = stream_input ,
340
341
config = config ,
341
- subgraps = bool (subgraphs_stream_enabled ),
342
- version = "v2"
342
+ subgraphs = bool (subgraphs_stream_enabled ),
343
+ version = "v2" ,
343
344
)
344
345
346
+ stream = self .graph .astream_events (** kwargs )
347
+
345
348
return {
346
349
"stream" : stream ,
347
350
"state" : state ,
@@ -369,12 +372,14 @@ async def prepare_regenerate_stream( # pylint: disable=too-many-arguments
369
372
370
373
stream_input = self .langgraph_default_merge_state (time_travel_checkpoint .values , [message_checkpoint ], input )
371
374
subgraphs_stream_enabled = input .forwarded_props .get ('stream_subgraphs' ) if input .forwarded_props else False
372
- stream = self .graph .astream_events (
373
- stream_input ,
374
- fork ,
375
- subgraps = bool (subgraphs_stream_enabled ),
376
- version = "v2"
375
+
376
+ kwargs = self .get_stream_kwargs (
377
+ input = stream_input ,
378
+ fork = fork ,
379
+ subgraphs = bool (subgraphs_stream_enabled ),
380
+ version = "v2" ,
377
381
)
382
+ stream = self .graph .astream_events (** kwargs )
378
383
379
384
return {
380
385
"stream" : stream ,
@@ -401,17 +406,25 @@ def get_schema_keys(self, config) -> SchemaKeys:
401
406
input_schema_keys = list (input_schema ["properties" ].keys ()) if "properties" in input_schema else []
402
407
output_schema_keys = list (output_schema ["properties" ].keys ()) if "properties" in output_schema else []
403
408
config_schema_keys = list (config_schema ["properties" ].keys ()) if "properties" in config_schema else []
409
+ context_schema_keys = []
410
+
411
+ if hasattr (self .graph , "context_schema" ) and self .graph .context_schema is not None :
412
+ context_schema = self .graph .context_schema ().schema ()
413
+ context_schema_keys = list (context_schema ["properties" ].keys ()) if "properties" in context_schema else []
414
+
404
415
405
416
return {
406
417
"input" : [* input_schema_keys , * self .constant_schema_keys ],
407
418
"output" : [* output_schema_keys , * self .constant_schema_keys ],
408
419
"config" : config_schema_keys ,
420
+ "context" : context_schema_keys ,
409
421
}
410
422
except Exception :
411
423
return {
412
424
"input" : self .constant_schema_keys ,
413
425
"output" : self .constant_schema_keys ,
414
426
"config" : [],
427
+ "context" : [],
415
428
}
416
429
417
430
def langgraph_default_merge_state (self , state : State , messages : List [BaseMessage ], input : RunAgentInput ) -> State :
@@ -744,3 +757,38 @@ def end_step(self):
744
757
self .active_run ["node_name" ] = None
745
758
self .active_step = None
746
759
return dispatch
760
+
761
+ # Check if some kwargs are enabled per LG version, to "catch all versions" and backwards compatibility
762
+ def get_stream_kwargs (
763
+ self ,
764
+ input : Any ,
765
+ subgraphs : bool = False ,
766
+ version : Literal ["v1" , "v2" ] = "v2" ,
767
+ config : Optional [RunnableConfig ] = None ,
768
+ context : Optional [Dict [str , Any ]] = None ,
769
+ fork : Optional [Any ] = None ,
770
+ ):
771
+ kwargs = dict (
772
+ input = input ,
773
+ subgraphs = subgraphs ,
774
+ version = version ,
775
+ )
776
+
777
+ # Only add context if supported
778
+ sig = inspect .signature (self .graph .astream_events )
779
+ if 'context' in sig .parameters :
780
+ base_context = {}
781
+ if isinstance (config , dict ) and 'configurable' in config and isinstance (config ['configurable' ], dict ):
782
+ base_context .update (config ['configurable' ])
783
+ if context : # context might be None or {}
784
+ base_context .update (context )
785
+ if base_context : # only add if there's something to pass
786
+ kwargs ['context' ] = base_context
787
+
788
+ if config :
789
+ kwargs ['config' ] = config
790
+
791
+ if fork :
792
+ kwargs .update (fork )
793
+
794
+ return kwargs
0 commit comments