@@ -24,6 +24,8 @@ import {
24
24
RunMetadata ,
25
25
PredictStateTool ,
26
26
LangGraphReasoning ,
27
+ StateEnrichment ,
28
+ LangGraphTool ,
27
29
} from "./types" ;
28
30
import {
29
31
AbstractAgent ,
@@ -181,7 +183,7 @@ export class LangGraphAgent extends AbstractAgent {
181
183
}
182
184
183
185
async prepareRegenerateStream ( input : RegenerateInput , streamMode : StreamMode | StreamMode [ ] ) {
184
- const { threadId, messageCheckpoint, tools } = input ;
186
+ const { threadId, messageCheckpoint } = input ;
185
187
186
188
const timeTravelCheckpoint = await this . getCheckpointByMessage (
187
189
messageCheckpoint ! . id ! ,
@@ -196,7 +198,7 @@ export class LangGraphAgent extends AbstractAgent {
196
198
}
197
199
198
200
const fork = await this . client . threads . updateState ( threadId , {
199
- values : this . langGraphDefaultMergeState ( timeTravelCheckpoint . values , [ ] , tools ) ,
201
+ values : this . langGraphDefaultMergeState ( timeTravelCheckpoint . values , [ ] , input ) ,
200
202
checkpointId : timeTravelCheckpoint . checkpoint . checkpoint_id ! ,
201
203
asNode : timeTravelCheckpoint . next ?. [ 0 ] ?? "__start__" ,
202
204
} ) ;
@@ -206,7 +208,7 @@ export class LangGraphAgent extends AbstractAgent {
206
208
input : this . langGraphDefaultMergeState (
207
209
timeTravelCheckpoint . values ,
208
210
[ messageCheckpoint ] ,
209
- tools ,
211
+ input ,
210
212
) ,
211
213
// @ts -ignore
212
214
checkpointId : fork . checkpoint . checkpoint_id ! ,
@@ -255,14 +257,14 @@ export class LangGraphAgent extends AbstractAgent {
255
257
const stateValuesDiff = this . langGraphDefaultMergeState (
256
258
{ ...inputState , messages : agentStateMessages } ,
257
259
inputMessagesToLangchain ,
258
- tools ,
260
+ input ,
259
261
) ;
260
262
// Messages are a combination of existing messages in state + everything that was newly sent
261
263
let threadState = {
262
264
...agentState ,
263
265
values : {
264
266
...stateValuesDiff ,
265
- messages : [ ...agentStateMessages , ...stateValuesDiff . messages ] ,
267
+ messages : [ ...agentStateMessages , ...( stateValuesDiff . messages ?? [ ] ) ] ,
266
268
} ,
267
269
} ;
268
270
let stateValues = threadState . values ;
@@ -968,7 +970,7 @@ export class LangGraphAgent extends AbstractAgent {
968
970
}
969
971
}
970
972
971
- langGraphDefaultMergeState ( state : State , messages : LangGraphMessage [ ] , tools : any ) : State {
973
+ langGraphDefaultMergeState ( state : State , messages : LangGraphMessage [ ] , input : RunAgentExtendedInput ) : State < StateEnrichment > {
972
974
if ( messages . length > 0 && "role" in messages [ 0 ] && messages [ 0 ] . role === "system" ) {
973
975
// remove system message
974
976
messages = messages . slice ( 1 ) ;
@@ -980,7 +982,7 @@ export class LangGraphAgent extends AbstractAgent {
980
982
981
983
const newMessages = messages . filter ( ( message ) => ! existingMessageIds . has ( message . id ) ) ;
982
984
983
- const langGraphTools = [ ...( state . tools ?? [ ] ) , ...( tools ?? [ ] ) ] . map ( ( tool ) => {
985
+ const langGraphTools : LangGraphTool [ ] = [ ...( state . tools ?? [ ] ) , ...( input . tools ?? [ ] ) ] . map ( ( tool ) => {
984
986
if ( tool . type ) {
985
987
return tool ;
986
988
}
@@ -999,6 +1001,10 @@ export class LangGraphAgent extends AbstractAgent {
999
1001
...state ,
1000
1002
messages : newMessages ,
1001
1003
tools : langGraphTools ,
1004
+ 'ag-ui' : {
1005
+ tools : langGraphTools ,
1006
+ context : input . context ,
1007
+ }
1002
1008
} ;
1003
1009
}
1004
1010
0 commit comments