@@ -46,11 +46,7 @@ import type {
4646 JumpTo ,
4747 UserInput ,
4848} from "./types.js" ;
49- import type {
50- PrivateState ,
51- InvokeConfiguration ,
52- StreamConfiguration ,
53- } from "./runtime.js" ;
49+ import type { InvokeConfiguration , StreamConfiguration } from "./runtime.js" ;
5450import type {
5551 AgentMiddleware ,
5652 InferMiddlewareContextInputs ,
@@ -249,10 +245,20 @@ export class ReactAgent<
249245 throw new Error ( `Middleware ${ m . name } is defined multiple times` ) ;
250246 }
251247
248+ const getState = ( ) => {
249+ return {
250+ ...beforeAgentNode ?. getState ( ) ,
251+ ...beforeModelNode ?. getState ( ) ,
252+ ...afterModelNode ?. getState ( ) ,
253+ ...afterAgentNode ?. getState ( ) ,
254+ ...this . #agentNode. getState ( ) ,
255+ } ;
256+ } ;
257+
252258 middlewareNames . add ( m . name ) ;
253259 if ( m . beforeAgent ) {
254260 beforeAgentNode = new BeforeAgentNode ( m , {
255- getPrivateState : ( ) => this . #agentNode . getState ( ) . _privateState ,
261+ getState,
256262 } ) ;
257263 const name = `${ m . name } .before_agent` ;
258264 beforeAgentNodes . push ( {
@@ -268,7 +274,7 @@ export class ReactAgent<
268274 }
269275 if ( m . beforeModel ) {
270276 beforeModelNode = new BeforeModelNode ( m , {
271- getPrivateState : ( ) => this . #agentNode . getState ( ) . _privateState ,
277+ getState,
272278 } ) ;
273279 const name = `${ m . name } .before_model` ;
274280 beforeModelNodes . push ( {
@@ -284,7 +290,7 @@ export class ReactAgent<
284290 }
285291 if ( m . afterModel ) {
286292 afterModelNode = new AfterModelNode ( m , {
287- getPrivateState : ( ) => this . #agentNode . getState ( ) . _privateState ,
293+ getState,
288294 } ) ;
289295 const name = `${ m . name } .after_model` ;
290296 afterModelNodes . push ( {
@@ -300,7 +306,7 @@ export class ReactAgent<
300306 }
301307 if ( m . afterAgent ) {
302308 afterAgentNode = new AfterAgentNode ( m , {
303- getPrivateState : ( ) => this . #agentNode . getState ( ) . _privateState ,
309+ getState,
304310 } ) ;
305311 const name = `${ m . name } .after_agent` ;
306312 afterAgentNodes . push ( {
@@ -316,15 +322,7 @@ export class ReactAgent<
316322 }
317323
318324 if ( m . wrapModelCall ) {
319- wrapModelCallHookMiddleware . push ( [
320- m ,
321- ( ) => ( {
322- ...beforeAgentNode ?. getState ( ) ,
323- ...beforeModelNode ?. getState ( ) ,
324- ...afterModelNode ?. getState ( ) ,
325- ...afterAgentNode ?. getState ( ) ,
326- } ) ,
327- ] ) ;
325+ wrapModelCallHookMiddleware . push ( [ m , getState ] ) ;
328326 }
329327 }
330328
@@ -350,7 +348,6 @@ export class ReactAgent<
350348 const toolNode = new ToolNode ( toolClasses . filter ( isClientTool ) , {
351349 signal : this . options . signal ,
352350 wrapToolCall : wrapToolCallHandler ,
353- getPrivateState : ( ) => this . #agentNode. getState ( ) . _privateState ,
354351 } ) ;
355352 allNodeWorkflows . addNode ( "tools" , toolNode ) ;
356353 }
@@ -944,7 +941,8 @@ export class ReactAgent<
944941 * Initialize middleware states if not already present in the input state.
945942 */
946943 async #initializeMiddlewareStates(
947- state : InvokeStateParameter < StateSchema , TMiddleware >
944+ state : InvokeStateParameter < StateSchema , TMiddleware > ,
945+ config : RunnableConfig
948946 ) : Promise < InvokeStateParameter < StateSchema , TMiddleware > > {
949947 if (
950948 ! this . options . middleware ||
@@ -959,10 +957,13 @@ export class ReactAgent<
959957 this . options . middleware ,
960958 state
961959 ) ;
962- const updatedState = { ...state } as InvokeStateParameter <
963- StateSchema ,
964- TMiddleware
965- > ;
960+ const threadState = await this . #graph
961+ . getState ( config )
962+ . catch ( ( ) => ( { values : { } } ) ) ;
963+ const updatedState = {
964+ ...threadState . values ,
965+ ...state ,
966+ } as InvokeStateParameter < StateSchema , TMiddleware > ;
966967 if ( ! updatedState ) {
967968 return updatedState ;
968969 }
@@ -977,35 +978,6 @@ export class ReactAgent<
977978 return updatedState ;
978979 }
979980
980- /**
981- * Populate the private state of the agent node from the previous state.
982- */
983- async #populatePrivateState( config ?: RunnableConfig ) {
984- /**
985- * not needed if thread_id is not provided
986- */
987- if ( ! config ?. configurable ?. thread_id ) {
988- return ;
989- }
990- const prevState = ( await this . #graph. getState ( config as any ) ) as {
991- values : {
992- _privateState : PrivateState ;
993- } ;
994- } ;
995-
996- /**
997- * not need if state is empty
998- */
999- if ( ! prevState . values . _privateState ) {
1000- return ;
1001- }
1002-
1003- this . #agentNode. setState ( {
1004- structuredResponse : undefined ,
1005- _privateState : prevState . values . _privateState ,
1006- } ) ;
1007- }
1008-
1009981 /**
1010982 * Executes the agent with the given state and returns the final state after all processing.
1011983 *
@@ -1061,8 +1033,10 @@ export class ReactAgent<
10611033 StructuredResponseFormat ,
10621034 TMiddleware
10631035 > ;
1064- const initializedState = await this . #initializeMiddlewareStates( state ) ;
1065- await this . #populatePrivateState( config ) ;
1036+ const initializedState = await this . #initializeMiddlewareStates(
1037+ state ,
1038+ config as RunnableConfig
1039+ ) ;
10661040
10671041 return this . #graph. invoke (
10681042 initializedState ,
@@ -1120,7 +1094,10 @@ export class ReactAgent<
11201094 InferMiddlewareContextInputs < TMiddleware >
11211095 >
11221096 ) : Promise < IterableReadableStream < any > > {
1123- const initializedState = await this . #initializeMiddlewareStates( state ) ;
1097+ const initializedState = await this . #initializeMiddlewareStates(
1098+ state ,
1099+ config as RunnableConfig
1100+ ) ;
11241101 return this . #graph. stream ( initializedState , config as Record < string , any > ) ;
11251102 }
11261103
0 commit comments