Skip to content

Commit edb3a1f

Browse files
committed
wip
1 parent ce29ad8 commit edb3a1f

File tree

2 files changed

+368
-2
lines changed

2 files changed

+368
-2
lines changed
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
import { Middleware, EventWithState } from "../middleware";
2+
import { AbstractAgent } from "@/agent";
3+
import {
4+
RunAgentInput,
5+
BaseEvent,
6+
EventType,
7+
TextMessageStartEvent,
8+
TextMessageContentEvent,
9+
StateSnapshotEvent,
10+
StateDeltaEvent,
11+
MessagesSnapshotEvent,
12+
ToolCallStartEvent,
13+
ToolCallArgsEvent,
14+
ToolCallEndEvent,
15+
} from "@ag-ui/core";
16+
import { Observable, from } from "rxjs";
17+
import { map, toArray } from "rxjs/operators";
18+
19+
// Mock agent for testing
20+
class MockAgent extends AbstractAgent {
21+
constructor(private events: BaseEvent[]) {
22+
super();
23+
}
24+
25+
run(input: RunAgentInput): Observable<BaseEvent> {
26+
return from(this.events);
27+
}
28+
}
29+
30+
// Test middleware that uses runNextWithState
31+
class TestMiddleware extends Middleware {
32+
run(input: RunAgentInput, next: AbstractAgent): Observable<BaseEvent> {
33+
return this.runNextWithState(input, next).pipe(
34+
map(({ event }) => event)
35+
);
36+
}
37+
38+
// Expose for testing
39+
testRunNextWithState(
40+
input: RunAgentInput,
41+
next: AbstractAgent
42+
): Observable<EventWithState> {
43+
return this.runNextWithState(input, next);
44+
}
45+
}
46+
47+
describe("Middleware.runNextWithState", () => {
48+
it("should track messages as they are built", async () => {
49+
const events: BaseEvent[] = [
50+
{
51+
type: EventType.TEXT_MESSAGE_START,
52+
messageId: "msg1",
53+
role: "assistant",
54+
} as TextMessageStartEvent,
55+
{
56+
type: EventType.TEXT_MESSAGE_CONTENT,
57+
messageId: "msg1",
58+
delta: "Hello",
59+
} as TextMessageContentEvent,
60+
{
61+
type: EventType.TEXT_MESSAGE_CONTENT,
62+
messageId: "msg1",
63+
delta: " world",
64+
} as TextMessageContentEvent,
65+
];
66+
67+
const agent = new MockAgent(events);
68+
const middleware = new TestMiddleware();
69+
const input: RunAgentInput = { messages: [], state: {} };
70+
71+
const results = await middleware
72+
.testRunNextWithState(input, agent)
73+
.pipe(toArray())
74+
.toPromise();
75+
76+
expect(results).toHaveLength(3);
77+
78+
// After TEXT_MESSAGE_START, should have one empty message
79+
expect(results![0].messages).toHaveLength(1);
80+
expect(results![0].messages[0].id).toBe("msg1");
81+
expect(results![0].messages[0].role).toBe("assistant");
82+
expect(results![0].messages[0].content).toBe("");
83+
84+
// After first content chunk
85+
expect(results![1].messages).toHaveLength(1);
86+
expect(results![1].messages[0].content).toBe("Hello");
87+
88+
// After second content chunk
89+
expect(results![2].messages).toHaveLength(1);
90+
expect(results![2].messages[0].content).toBe("Hello world");
91+
});
92+
93+
it("should track state changes", async () => {
94+
const events: BaseEvent[] = [
95+
{
96+
type: EventType.STATE_SNAPSHOT,
97+
snapshot: { counter: 0, name: "test" },
98+
} as StateSnapshotEvent,
99+
{
100+
type: EventType.STATE_DELTA,
101+
delta: [{ op: "replace", path: "/counter", value: 1 }],
102+
} as StateDeltaEvent,
103+
{
104+
type: EventType.STATE_DELTA,
105+
delta: [{ op: "add", path: "/newField", value: "added" }],
106+
} as StateDeltaEvent,
107+
];
108+
109+
const agent = new MockAgent(events);
110+
const middleware = new TestMiddleware();
111+
const input: RunAgentInput = { messages: [], state: {} };
112+
113+
const results = await middleware
114+
.testRunNextWithState(input, agent)
115+
.pipe(toArray())
116+
.toPromise();
117+
118+
expect(results).toHaveLength(3);
119+
120+
// After STATE_SNAPSHOT
121+
expect(results![0].state).toEqual({ counter: 0, name: "test" });
122+
123+
// After first STATE_DELTA
124+
expect(results![1].state).toEqual({ counter: 1, name: "test" });
125+
126+
// After second STATE_DELTA
127+
expect(results![2].state).toEqual({
128+
counter: 1,
129+
name: "test",
130+
newField: "added",
131+
});
132+
});
133+
134+
it("should handle MESSAGES_SNAPSHOT", async () => {
135+
const events: BaseEvent[] = [
136+
{
137+
type: EventType.TEXT_MESSAGE_START,
138+
messageId: "msg1",
139+
role: "user",
140+
} as TextMessageStartEvent,
141+
{
142+
type: EventType.TEXT_MESSAGE_CONTENT,
143+
messageId: "msg1",
144+
delta: "First",
145+
} as TextMessageContentEvent,
146+
{
147+
type: EventType.MESSAGES_SNAPSHOT,
148+
messages: [
149+
{ id: "old1", role: "assistant", content: "Previous message" },
150+
{ id: "old2", role: "user", content: "Another message" },
151+
],
152+
} as MessagesSnapshotEvent,
153+
];
154+
155+
const agent = new MockAgent(events);
156+
const middleware = new TestMiddleware();
157+
const input: RunAgentInput = { messages: [], state: {} };
158+
159+
const results = await middleware
160+
.testRunNextWithState(input, agent)
161+
.pipe(toArray())
162+
.toPromise();
163+
164+
expect(results).toHaveLength(3);
165+
166+
// After building a message
167+
expect(results![1].messages).toHaveLength(1);
168+
expect(results![1].messages[0].content).toBe("First");
169+
170+
// After MESSAGES_SNAPSHOT - replaces all messages
171+
expect(results![2].messages).toHaveLength(2);
172+
expect(results![2].messages[0].id).toBe("old1");
173+
expect(results![2].messages[1].id).toBe("old2");
174+
});
175+
176+
it("should track tool calls", async () => {
177+
const events: BaseEvent[] = [
178+
{
179+
type: EventType.TOOL_CALL_START,
180+
toolCallId: "tool1",
181+
toolCallName: "calculator",
182+
parentMessageId: "msg1",
183+
} as ToolCallStartEvent,
184+
{
185+
type: EventType.TOOL_CALL_ARGS,
186+
toolCallId: "tool1",
187+
delta: '{"operation": "add"',
188+
} as ToolCallArgsEvent,
189+
{
190+
type: EventType.TOOL_CALL_ARGS,
191+
toolCallId: "tool1",
192+
delta: ', "values": [1, 2]}',
193+
} as ToolCallArgsEvent,
194+
{
195+
type: EventType.TOOL_CALL_END,
196+
toolCallId: "tool1",
197+
} as ToolCallEndEvent,
198+
];
199+
200+
const agent = new MockAgent(events);
201+
const middleware = new TestMiddleware();
202+
const input: RunAgentInput = { messages: [], state: {} };
203+
204+
const results = await middleware
205+
.testRunNextWithState(input, agent)
206+
.pipe(toArray())
207+
.toPromise();
208+
209+
expect(results).toHaveLength(4);
210+
211+
// After TOOL_CALL_START
212+
expect(results![0].messages).toHaveLength(1);
213+
expect(results![0].messages[0].role).toBe("assistant");
214+
const msg1 = results![0].messages[0] as any;
215+
expect(msg1.toolCalls).toHaveLength(1);
216+
expect(msg1.toolCalls[0].id).toBe("tool1");
217+
expect(msg1.toolCalls[0].type).toBe("function");
218+
expect(msg1.toolCalls[0].function.name).toBe("calculator");
219+
220+
// After args accumulation
221+
const msg3 = results![2].messages[0] as any;
222+
expect(msg3.toolCalls[0].function.arguments).toBe('{"operation": "add", "values": [1, 2]}');
223+
224+
// After TOOL_CALL_END - args remain as string (defaultApplyEvents doesn't parse them)
225+
const msg4 = results![3].messages[0] as any;
226+
expect(msg4.toolCalls[0].function.arguments).toBe('{"operation": "add", "values": [1, 2]}');
227+
});
228+
229+
it("should preserve initial state and messages", async () => {
230+
const events: BaseEvent[] = [
231+
{
232+
type: EventType.TEXT_MESSAGE_START,
233+
messageId: "new1",
234+
role: "assistant",
235+
} as TextMessageStartEvent,
236+
{
237+
type: EventType.STATE_DELTA,
238+
delta: [{ op: "add", path: "/newField", value: 42 }],
239+
} as StateDeltaEvent,
240+
];
241+
242+
const agent = new MockAgent(events);
243+
const middleware = new TestMiddleware();
244+
245+
const input: RunAgentInput = {
246+
messages: [
247+
{ id: "existing1", role: "user", content: "Existing message" },
248+
],
249+
state: { existingField: "hello" },
250+
};
251+
252+
const results = await middleware
253+
.testRunNextWithState(input, agent)
254+
.pipe(toArray())
255+
.toPromise();
256+
257+
expect(results).toHaveLength(2);
258+
259+
// Should preserve existing message and add new one
260+
expect(results![0].messages).toHaveLength(2);
261+
expect(results![0].messages[0].id).toBe("existing1");
262+
expect(results![0].messages[1].id).toBe("new1");
263+
264+
// Should preserve existing state and add new field
265+
expect(results![1].state).toEqual({
266+
existingField: "hello",
267+
newField: 42,
268+
});
269+
});
270+
271+
it("should provide immutable snapshots", async () => {
272+
const events: BaseEvent[] = [
273+
{
274+
type: EventType.TEXT_MESSAGE_START,
275+
messageId: "msg1",
276+
role: "assistant",
277+
} as TextMessageStartEvent,
278+
{
279+
type: EventType.STATE_SNAPSHOT,
280+
snapshot: { value: 1 },
281+
} as StateSnapshotEvent,
282+
];
283+
284+
const agent = new MockAgent(events);
285+
const middleware = new TestMiddleware();
286+
const input: RunAgentInput = { messages: [], state: {} };
287+
288+
const results = await middleware
289+
.testRunNextWithState(input, agent)
290+
.pipe(toArray())
291+
.toPromise();
292+
293+
// Modify returned state/messages - should not affect next results
294+
results![0].messages[0].content = "MODIFIED";
295+
results![0].state.hacked = true;
296+
297+
// Second result should not be affected
298+
expect(results![1].messages[0].content).toBe("");
299+
expect(results![1].state).toEqual({ value: 1 });
300+
expect(results![1].state.hacked).toBeUndefined();
301+
});
302+
});

typescript-sdk/packages/client/src/middleware/middleware.ts

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,75 @@
11
import { AbstractAgent } from "@/agent";
2-
import { RunAgentInput, BaseEvent } from "@ag-ui/core";
3-
import { Observable } from "rxjs";
2+
import { RunAgentInput, BaseEvent, Message } from "@ag-ui/core";
3+
import { Observable, ReplaySubject } from "rxjs";
4+
import { concatMap } from "rxjs/operators";
5+
import { transformChunks } from "@/chunks";
6+
import { defaultApplyEvents } from "@/apply";
7+
import { structuredClone_ } from "@/utils";
48

59
export type MiddlewareFunction = (input: RunAgentInput, next: AbstractAgent) => Observable<BaseEvent>;
610

11+
export interface EventWithState {
12+
event: BaseEvent;
13+
messages: Message[];
14+
state: any;
15+
}
16+
717
export abstract class Middleware {
818
abstract run(input: RunAgentInput, next: AbstractAgent): Observable<BaseEvent>;
19+
20+
/**
21+
* Runs the next agent in the chain with automatic chunk transformation.
22+
*/
23+
protected runNext(input: RunAgentInput, next: AbstractAgent): Observable<BaseEvent> {
24+
return next.run(input).pipe(
25+
transformChunks(false) // Always transform chunks to full events
26+
);
27+
}
28+
29+
/**
30+
* Runs the next agent and tracks state, providing current messages and state with each event.
31+
* The messages and state represent the state AFTER the event has been applied.
32+
*/
33+
protected runNextWithState(
34+
input: RunAgentInput,
35+
next: AbstractAgent
36+
): Observable<EventWithState> {
37+
let currentMessages = structuredClone_(input.messages || []);
38+
let currentState = structuredClone_(input.state || {});
39+
40+
// Use a ReplaySubject to feed events one by one
41+
const eventSubject = new ReplaySubject<BaseEvent>();
42+
43+
// Set up defaultApplyEvents to process events
44+
const mutations$ = defaultApplyEvents(input, eventSubject, next, []);
45+
46+
// Subscribe to track state changes
47+
mutations$.subscribe(mutation => {
48+
if (mutation.messages !== undefined) {
49+
currentMessages = mutation.messages;
50+
}
51+
if (mutation.state !== undefined) {
52+
currentState = mutation.state;
53+
}
54+
});
55+
56+
return this.runNext(input, next).pipe(
57+
concatMap(async event => {
58+
// Feed the event to defaultApplyEvents and wait for it to process
59+
eventSubject.next(event);
60+
61+
// Give defaultApplyEvents a chance to process
62+
await new Promise(resolve => setTimeout(resolve, 0));
63+
64+
// Return event with current state
65+
return {
66+
event,
67+
messages: structuredClone_(currentMessages),
68+
state: structuredClone_(currentState)
69+
};
70+
})
71+
);
72+
}
973
}
1074

1175
// Wrapper class to convert a function into a Middleware instance

0 commit comments

Comments
 (0)