@@ -24,6 +24,8 @@ import {
2424 RunMetadata ,
2525 PredictStateTool ,
2626 LangGraphReasoning ,
27+ StateEnrichment ,
28+ LangGraphTool ,
2729} from "./types" ;
2830import {
2931 AbstractAgent ,
@@ -181,7 +183,7 @@ export class LangGraphAgent extends AbstractAgent {
181183 }
182184
183185 async prepareRegenerateStream ( input : RegenerateInput , streamMode : StreamMode | StreamMode [ ] ) {
184- const { threadId, messageCheckpoint, tools } = input ;
186+ const { threadId, messageCheckpoint } = input ;
185187
186188 const timeTravelCheckpoint = await this . getCheckpointByMessage (
187189 messageCheckpoint ! . id ! ,
@@ -196,7 +198,7 @@ export class LangGraphAgent extends AbstractAgent {
196198 }
197199
198200 const fork = await this . client . threads . updateState ( threadId , {
199- values : this . langGraphDefaultMergeState ( timeTravelCheckpoint . values , [ ] , tools ) ,
201+ values : this . langGraphDefaultMergeState ( timeTravelCheckpoint . values , [ ] , input ) ,
200202 checkpointId : timeTravelCheckpoint . checkpoint . checkpoint_id ! ,
201203 asNode : timeTravelCheckpoint . next ?. [ 0 ] ?? "__start__" ,
202204 } ) ;
@@ -206,7 +208,7 @@ export class LangGraphAgent extends AbstractAgent {
206208 input : this . langGraphDefaultMergeState (
207209 timeTravelCheckpoint . values ,
208210 [ messageCheckpoint ] ,
209- tools ,
211+ input ,
210212 ) ,
211213 // @ts -ignore
212214 checkpointId : fork . checkpoint . checkpoint_id ! ,
@@ -255,14 +257,14 @@ export class LangGraphAgent extends AbstractAgent {
255257 const stateValuesDiff = this . langGraphDefaultMergeState (
256258 { ...inputState , messages : agentStateMessages } ,
257259 inputMessagesToLangchain ,
258- tools ,
260+ input ,
259261 ) ;
260262 // Messages are a combination of existing messages in state + everything that was newly sent
261263 let threadState = {
262264 ...agentState ,
263265 values : {
264266 ...stateValuesDiff ,
265- messages : [ ...agentStateMessages , ...stateValuesDiff . messages ] ,
267+ messages : [ ...agentStateMessages , ...( stateValuesDiff . messages ?? [ ] ) ] ,
266268 } ,
267269 } ;
268270 let stateValues = threadState . values ;
@@ -968,7 +970,7 @@ export class LangGraphAgent extends AbstractAgent {
968970 }
969971 }
970972
971- langGraphDefaultMergeState ( state : State , messages : LangGraphMessage [ ] , tools : any ) : State {
973+ langGraphDefaultMergeState ( state : State , messages : LangGraphMessage [ ] , input : RunAgentExtendedInput ) : State < StateEnrichment > {
972974 if ( messages . length > 0 && "role" in messages [ 0 ] && messages [ 0 ] . role === "system" ) {
973975 // remove system message
974976 messages = messages . slice ( 1 ) ;
@@ -980,7 +982,7 @@ export class LangGraphAgent extends AbstractAgent {
980982
981983 const newMessages = messages . filter ( ( message ) => ! existingMessageIds . has ( message . id ) ) ;
982984
983- const langGraphTools = [ ...( state . tools ?? [ ] ) , ...( tools ?? [ ] ) ] . map ( ( tool ) => {
985+ const langGraphTools : LangGraphTool [ ] = [ ...( state . tools ?? [ ] ) , ...( input . tools ?? [ ] ) ] . map ( ( tool ) => {
984986 if ( tool . type ) {
985987 return tool ;
986988 }
@@ -999,6 +1001,10 @@ export class LangGraphAgent extends AbstractAgent {
9991001 ...state ,
10001002 messages : newMessages ,
10011003 tools : langGraphTools ,
1004+ 'ag-ui' : {
1005+ tools : langGraphTools ,
1006+ context : input . context ,
1007+ }
10021008 } ;
10031009 }
10041010
0 commit comments