1+ import { Middleware , EventWithState } from "../middleware" ;
2+ import { AbstractAgent } from "@/agent" ;
3+ import {
4+ RunAgentInput ,
5+ BaseEvent ,
6+ EventType ,
7+ TextMessageStartEvent ,
8+ TextMessageContentEvent ,
9+ StateSnapshotEvent ,
10+ StateDeltaEvent ,
11+ MessagesSnapshotEvent ,
12+ ToolCallStartEvent ,
13+ ToolCallArgsEvent ,
14+ ToolCallEndEvent ,
15+ } from "@ag-ui/core" ;
16+ import { Observable , from } from "rxjs" ;
17+ import { map , toArray } from "rxjs/operators" ;
18+
19+ // Mock agent for testing
20+ class MockAgent extends AbstractAgent {
21+ constructor ( private events : BaseEvent [ ] ) {
22+ super ( ) ;
23+ }
24+
25+ run ( input : RunAgentInput ) : Observable < BaseEvent > {
26+ return from ( this . events ) ;
27+ }
28+ }
29+
30+ // Test middleware that uses runNextWithState
31+ class TestMiddleware extends Middleware {
32+ run ( input : RunAgentInput , next : AbstractAgent ) : Observable < BaseEvent > {
33+ return this . runNextWithState ( input , next ) . pipe (
34+ map ( ( { event } ) => event )
35+ ) ;
36+ }
37+
38+ // Expose for testing
39+ testRunNextWithState (
40+ input : RunAgentInput ,
41+ next : AbstractAgent
42+ ) : Observable < EventWithState > {
43+ return this . runNextWithState ( input , next ) ;
44+ }
45+ }
46+
47+ describe ( "Middleware.runNextWithState" , ( ) => {
48+ it ( "should track messages as they are built" , async ( ) => {
49+ const events : BaseEvent [ ] = [
50+ {
51+ type : EventType . TEXT_MESSAGE_START ,
52+ messageId : "msg1" ,
53+ role : "assistant" ,
54+ } as TextMessageStartEvent ,
55+ {
56+ type : EventType . TEXT_MESSAGE_CONTENT ,
57+ messageId : "msg1" ,
58+ delta : "Hello" ,
59+ } as TextMessageContentEvent ,
60+ {
61+ type : EventType . TEXT_MESSAGE_CONTENT ,
62+ messageId : "msg1" ,
63+ delta : " world" ,
64+ } as TextMessageContentEvent ,
65+ ] ;
66+
67+ const agent = new MockAgent ( events ) ;
68+ const middleware = new TestMiddleware ( ) ;
69+ const input : RunAgentInput = { messages : [ ] , state : { } } ;
70+
71+ const results = await middleware
72+ . testRunNextWithState ( input , agent )
73+ . pipe ( toArray ( ) )
74+ . toPromise ( ) ;
75+
76+ expect ( results ) . toHaveLength ( 3 ) ;
77+
78+ // After TEXT_MESSAGE_START, should have one empty message
79+ expect ( results ! [ 0 ] . messages ) . toHaveLength ( 1 ) ;
80+ expect ( results ! [ 0 ] . messages [ 0 ] . id ) . toBe ( "msg1" ) ;
81+ expect ( results ! [ 0 ] . messages [ 0 ] . role ) . toBe ( "assistant" ) ;
82+ expect ( results ! [ 0 ] . messages [ 0 ] . content ) . toBe ( "" ) ;
83+
84+ // After first content chunk
85+ expect ( results ! [ 1 ] . messages ) . toHaveLength ( 1 ) ;
86+ expect ( results ! [ 1 ] . messages [ 0 ] . content ) . toBe ( "Hello" ) ;
87+
88+ // After second content chunk
89+ expect ( results ! [ 2 ] . messages ) . toHaveLength ( 1 ) ;
90+ expect ( results ! [ 2 ] . messages [ 0 ] . content ) . toBe ( "Hello world" ) ;
91+ } ) ;
92+
93+ it ( "should track state changes" , async ( ) => {
94+ const events : BaseEvent [ ] = [
95+ {
96+ type : EventType . STATE_SNAPSHOT ,
97+ snapshot : { counter : 0 , name : "test" } ,
98+ } as StateSnapshotEvent ,
99+ {
100+ type : EventType . STATE_DELTA ,
101+ delta : [ { op : "replace" , path : "/counter" , value : 1 } ] ,
102+ } as StateDeltaEvent ,
103+ {
104+ type : EventType . STATE_DELTA ,
105+ delta : [ { op : "add" , path : "/newField" , value : "added" } ] ,
106+ } as StateDeltaEvent ,
107+ ] ;
108+
109+ const agent = new MockAgent ( events ) ;
110+ const middleware = new TestMiddleware ( ) ;
111+ const input : RunAgentInput = { messages : [ ] , state : { } } ;
112+
113+ const results = await middleware
114+ . testRunNextWithState ( input , agent )
115+ . pipe ( toArray ( ) )
116+ . toPromise ( ) ;
117+
118+ expect ( results ) . toHaveLength ( 3 ) ;
119+
120+ // After STATE_SNAPSHOT
121+ expect ( results ! [ 0 ] . state ) . toEqual ( { counter : 0 , name : "test" } ) ;
122+
123+ // After first STATE_DELTA
124+ expect ( results ! [ 1 ] . state ) . toEqual ( { counter : 1 , name : "test" } ) ;
125+
126+ // After second STATE_DELTA
127+ expect ( results ! [ 2 ] . state ) . toEqual ( {
128+ counter : 1 ,
129+ name : "test" ,
130+ newField : "added" ,
131+ } ) ;
132+ } ) ;
133+
134+ it ( "should handle MESSAGES_SNAPSHOT" , async ( ) => {
135+ const events : BaseEvent [ ] = [
136+ {
137+ type : EventType . TEXT_MESSAGE_START ,
138+ messageId : "msg1" ,
139+ role : "user" ,
140+ } as TextMessageStartEvent ,
141+ {
142+ type : EventType . TEXT_MESSAGE_CONTENT ,
143+ messageId : "msg1" ,
144+ delta : "First" ,
145+ } as TextMessageContentEvent ,
146+ {
147+ type : EventType . MESSAGES_SNAPSHOT ,
148+ messages : [
149+ { id : "old1" , role : "assistant" , content : "Previous message" } ,
150+ { id : "old2" , role : "user" , content : "Another message" } ,
151+ ] ,
152+ } as MessagesSnapshotEvent ,
153+ ] ;
154+
155+ const agent = new MockAgent ( events ) ;
156+ const middleware = new TestMiddleware ( ) ;
157+ const input : RunAgentInput = { messages : [ ] , state : { } } ;
158+
159+ const results = await middleware
160+ . testRunNextWithState ( input , agent )
161+ . pipe ( toArray ( ) )
162+ . toPromise ( ) ;
163+
164+ expect ( results ) . toHaveLength ( 3 ) ;
165+
166+ // After building a message
167+ expect ( results ! [ 1 ] . messages ) . toHaveLength ( 1 ) ;
168+ expect ( results ! [ 1 ] . messages [ 0 ] . content ) . toBe ( "First" ) ;
169+
170+ // After MESSAGES_SNAPSHOT - replaces all messages
171+ expect ( results ! [ 2 ] . messages ) . toHaveLength ( 2 ) ;
172+ expect ( results ! [ 2 ] . messages [ 0 ] . id ) . toBe ( "old1" ) ;
173+ expect ( results ! [ 2 ] . messages [ 1 ] . id ) . toBe ( "old2" ) ;
174+ } ) ;
175+
176+ it ( "should track tool calls" , async ( ) => {
177+ const events : BaseEvent [ ] = [
178+ {
179+ type : EventType . TOOL_CALL_START ,
180+ toolCallId : "tool1" ,
181+ toolCallName : "calculator" ,
182+ parentMessageId : "msg1" ,
183+ } as ToolCallStartEvent ,
184+ {
185+ type : EventType . TOOL_CALL_ARGS ,
186+ toolCallId : "tool1" ,
187+ delta : '{"operation": "add"' ,
188+ } as ToolCallArgsEvent ,
189+ {
190+ type : EventType . TOOL_CALL_ARGS ,
191+ toolCallId : "tool1" ,
192+ delta : ', "values": [1, 2]}' ,
193+ } as ToolCallArgsEvent ,
194+ {
195+ type : EventType . TOOL_CALL_END ,
196+ toolCallId : "tool1" ,
197+ } as ToolCallEndEvent ,
198+ ] ;
199+
200+ const agent = new MockAgent ( events ) ;
201+ const middleware = new TestMiddleware ( ) ;
202+ const input : RunAgentInput = { messages : [ ] , state : { } } ;
203+
204+ const results = await middleware
205+ . testRunNextWithState ( input , agent )
206+ . pipe ( toArray ( ) )
207+ . toPromise ( ) ;
208+
209+ expect ( results ) . toHaveLength ( 4 ) ;
210+
211+ // After TOOL_CALL_START
212+ expect ( results ! [ 0 ] . messages ) . toHaveLength ( 1 ) ;
213+ expect ( results ! [ 0 ] . messages [ 0 ] . role ) . toBe ( "assistant" ) ;
214+ const msg1 = results ! [ 0 ] . messages [ 0 ] as any ;
215+ expect ( msg1 . toolCalls ) . toHaveLength ( 1 ) ;
216+ expect ( msg1 . toolCalls [ 0 ] . id ) . toBe ( "tool1" ) ;
217+ expect ( msg1 . toolCalls [ 0 ] . type ) . toBe ( "function" ) ;
218+ expect ( msg1 . toolCalls [ 0 ] . function . name ) . toBe ( "calculator" ) ;
219+
220+ // After args accumulation
221+ const msg3 = results ! [ 2 ] . messages [ 0 ] as any ;
222+ expect ( msg3 . toolCalls [ 0 ] . function . arguments ) . toBe ( '{"operation": "add", "values": [1, 2]}' ) ;
223+
224+ // After TOOL_CALL_END - args remain as string (defaultApplyEvents doesn't parse them)
225+ const msg4 = results ! [ 3 ] . messages [ 0 ] as any ;
226+ expect ( msg4 . toolCalls [ 0 ] . function . arguments ) . toBe ( '{"operation": "add", "values": [1, 2]}' ) ;
227+ } ) ;
228+
229+ it ( "should preserve initial state and messages" , async ( ) => {
230+ const events : BaseEvent [ ] = [
231+ {
232+ type : EventType . TEXT_MESSAGE_START ,
233+ messageId : "new1" ,
234+ role : "assistant" ,
235+ } as TextMessageStartEvent ,
236+ {
237+ type : EventType . STATE_DELTA ,
238+ delta : [ { op : "add" , path : "/newField" , value : 42 } ] ,
239+ } as StateDeltaEvent ,
240+ ] ;
241+
242+ const agent = new MockAgent ( events ) ;
243+ const middleware = new TestMiddleware ( ) ;
244+
245+ const input : RunAgentInput = {
246+ messages : [
247+ { id : "existing1" , role : "user" , content : "Existing message" } ,
248+ ] ,
249+ state : { existingField : "hello" } ,
250+ } ;
251+
252+ const results = await middleware
253+ . testRunNextWithState ( input , agent )
254+ . pipe ( toArray ( ) )
255+ . toPromise ( ) ;
256+
257+ expect ( results ) . toHaveLength ( 2 ) ;
258+
259+ // Should preserve existing message and add new one
260+ expect ( results ! [ 0 ] . messages ) . toHaveLength ( 2 ) ;
261+ expect ( results ! [ 0 ] . messages [ 0 ] . id ) . toBe ( "existing1" ) ;
262+ expect ( results ! [ 0 ] . messages [ 1 ] . id ) . toBe ( "new1" ) ;
263+
264+ // Should preserve existing state and add new field
265+ expect ( results ! [ 1 ] . state ) . toEqual ( {
266+ existingField : "hello" ,
267+ newField : 42 ,
268+ } ) ;
269+ } ) ;
270+
271+ it ( "should provide immutable snapshots" , async ( ) => {
272+ const events : BaseEvent [ ] = [
273+ {
274+ type : EventType . TEXT_MESSAGE_START ,
275+ messageId : "msg1" ,
276+ role : "assistant" ,
277+ } as TextMessageStartEvent ,
278+ {
279+ type : EventType . STATE_SNAPSHOT ,
280+ snapshot : { value : 1 } ,
281+ } as StateSnapshotEvent ,
282+ ] ;
283+
284+ const agent = new MockAgent ( events ) ;
285+ const middleware = new TestMiddleware ( ) ;
286+ const input : RunAgentInput = { messages : [ ] , state : { } } ;
287+
288+ const results = await middleware
289+ . testRunNextWithState ( input , agent )
290+ . pipe ( toArray ( ) )
291+ . toPromise ( ) ;
292+
293+ // Modify returned state/messages - should not affect next results
294+ results ! [ 0 ] . messages [ 0 ] . content = "MODIFIED" ;
295+ results ! [ 0 ] . state . hacked = true ;
296+
297+ // Second result should not be affected
298+ expect ( results ! [ 1 ] . messages [ 0 ] . content ) . toBe ( "" ) ;
299+ expect ( results ! [ 1 ] . state ) . toEqual ( { value : 1 } ) ;
300+ expect ( results ! [ 1 ] . state . hacked ) . toBeUndefined ( ) ;
301+ } ) ;
302+ } ) ;
0 commit comments