|
1 | 1 | import { Subject } from "rxjs"; |
2 | 2 | import { toArray } from "rxjs/operators"; |
3 | | -import { firstValueFrom } from "rxjs"; |
| 3 | +import { firstValueFrom, of } from "rxjs"; |
4 | 4 | import { |
5 | 5 | BaseEvent, |
6 | 6 | EventType, |
|
9 | 9 | TextMessageContentEvent, |
10 | 10 | TextMessageEndEvent, |
11 | 11 | RunAgentInput, |
| 12 | + RunFinishedEvent, |
12 | 13 | } from "@ag-ui/core"; |
13 | 14 | import { defaultApplyEvents } from "../default"; |
14 | 15 | import { AbstractAgent } from "@/agent"; |
@@ -188,4 +189,59 @@ describe("defaultApplyEvents with text messages", () => { |
188 | 189 | // Verify no additional updates after either TEXT_MESSAGE_END |
189 | 190 | expect(stateUpdates.length).toBe(4); |
190 | 191 | }); |
| 192 | + |
| 193 | + it("should emit a messages snapshot when the run finishes", async () => { |
| 194 | + const initialState: RunAgentInput = { |
| 195 | + messages: [], |
| 196 | + state: {}, |
| 197 | + threadId: "test-thread", |
| 198 | + runId: "test-run", |
| 199 | + tools: [], |
| 200 | + context: [], |
| 201 | + }; |
| 202 | + |
| 203 | + const subscriber = { |
| 204 | + onMessagesSnapshotEvent: jest.fn(), |
| 205 | + onEvent: jest.fn(), |
| 206 | + }; |
| 207 | + |
| 208 | + const events: BaseEvent[] = [ |
| 209 | + { |
| 210 | + type: EventType.TEXT_MESSAGE_START, |
| 211 | + messageId: "msg-1", |
| 212 | + role: "assistant", |
| 213 | + } as TextMessageStartEvent, |
| 214 | + { |
| 215 | + type: EventType.TEXT_MESSAGE_CONTENT, |
| 216 | + messageId: "msg-1", |
| 217 | + delta: "Hello world!", |
| 218 | + } as TextMessageContentEvent, |
| 219 | + { |
| 220 | + type: EventType.TEXT_MESSAGE_END, |
| 221 | + messageId: "msg-1", |
| 222 | + } as TextMessageEndEvent, |
| 223 | + { |
| 224 | + type: EventType.RUN_FINISHED, |
| 225 | + threadId: "test-thread", |
| 226 | + runId: "test-run", |
| 227 | + } as RunFinishedEvent, |
| 228 | + ]; |
| 229 | + |
| 230 | + const result$ = defaultApplyEvents(initialState, of(...events), FAKE_AGENT, [subscriber as any]); |
| 231 | + const mutations = await firstValueFrom(result$.pipe(toArray())); |
| 232 | + |
| 233 | + expect(subscriber.onMessagesSnapshotEvent).toHaveBeenCalledTimes(1); |
| 234 | + const snapshotArgs = subscriber.onMessagesSnapshotEvent.mock.calls[0][0]; |
| 235 | + expect(snapshotArgs.event.type).toBe(EventType.MESSAGES_SNAPSHOT); |
| 236 | + expect(snapshotArgs.event.messages).toHaveLength(1); |
| 237 | + expect(snapshotArgs.event.messages?.[0]?.content).toBe("Hello world!"); |
| 238 | + |
| 239 | + const eventTypes = subscriber.onEvent.mock.calls.map((call) => call[0].event.type); |
| 240 | + expect(eventTypes[eventTypes.length - 2]).toBe(EventType.MESSAGES_SNAPSHOT); |
| 241 | + expect(eventTypes[eventTypes.length - 1]).toBe(EventType.RUN_FINISHED); |
| 242 | + |
| 243 | + const finalMutation = mutations[mutations.length - 1]; |
| 244 | + expect(finalMutation.messages).toHaveLength(1); |
| 245 | + expect(finalMutation.messages?.[0]?.content).toBe("Hello world!"); |
| 246 | + }); |
191 | 247 | }); |
0 commit comments