diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts index e6506d17..63b91202 100644 --- a/packages/agent/src/index.ts +++ b/packages/agent/src/index.ts @@ -460,6 +460,8 @@ export interface BasicAgentConfiguration { } export class BasicAgent extends AbstractAgent { + private abortController?: AbortController; + constructor(private config: BasicAgentConfiguration) { super(); } @@ -620,6 +622,10 @@ export class BasicAgent extends AbstractAgent { const mcpClients: Array<{ close: () => Promise }> = []; (async () => { + const abortController = new AbortController(); + this.abortController = abortController; + let terminalEventEmitted = false; + try { // Add AG-UI state update tools streamTextParams.tools = { @@ -681,7 +687,7 @@ export class BasicAgent extends AbstractAgent { } // Call streamText and process the stream - const response = streamText(streamTextParams); + const response = streamText({ ...streamTextParams, abortSignal: abortController.signal }); let messageId = randomUUID(); @@ -848,32 +854,60 @@ export class BasicAgent extends AbstractAgent { runId: input.runId, }; subscriber.next(finishedEvent); + terminalEventEmitted = true; // Complete the observable subscriber.complete(); break; - case "error": + case "error": { + if (abortController.signal.aborted) { + break; + } const runErrorEvent: RunErrorEvent = { type: EventType.RUN_ERROR, message: part.error + "", }; subscriber.next(runErrorEvent); + terminalEventEmitted = true; // Handle error subscriber.error(part.error); break; + } } } - } catch (error) { - const runErrorEvent: RunErrorEvent = { - type: EventType.RUN_ERROR, - message: error + "", - }; - subscriber.next(runErrorEvent); - subscriber.error(error); + if (!terminalEventEmitted) { + if (abortController.signal.aborted) { + // Let the runner finalize the stream on stop requests so it can + // inject consistent closing events and a RUN_FINISHED marker. + } else { + const finishedEvent: RunFinishedEvent = { + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }; + subscriber.next(finishedEvent); + } + + terminalEventEmitted = true; + subscriber.complete(); + } + } catch (error) { + if (abortController.signal.aborted) { + subscriber.complete(); + } else { + const runErrorEvent: RunErrorEvent = { + type: EventType.RUN_ERROR, + message: error + "", + }; + subscriber.next(runErrorEvent); + terminalEventEmitted = true; + subscriber.error(error); + } } finally { + this.abortController = undefined; await Promise.all(mcpClients.map((client) => client.close())); } })(); @@ -891,4 +925,8 @@ export class BasicAgent extends AbstractAgent { clone() { return new BasicAgent(this.config); } + + abortRun(): void { + this.abortController?.abort(); + } } diff --git a/packages/core/src/agent.ts b/packages/core/src/agent.ts index e8ca6930..ac4b2b5a 100644 --- a/packages/core/src/agent.ts +++ b/packages/core/src/agent.ts @@ -8,8 +8,7 @@ import { } from "@ag-ui/client"; import { Observable } from "rxjs"; -export interface ProxiedCopilotRuntimeAgentConfig - extends Omit { +export interface ProxiedCopilotRuntimeAgentConfig extends Omit { runtimeUrl?: string; } @@ -24,11 +23,33 @@ export class ProxiedCopilotRuntimeAgent extends HttpAgent { this.runtimeUrl = config.runtimeUrl; } + abortRun(): void { + if (!this.runtimeUrl || !this.agentId || !this.threadId) { + return; + } + + if (typeof fetch === "undefined") { + return; + } + + const stopPath = `${this.runtimeUrl}/agent/${encodeURIComponent(this.agentId)}/stop/${encodeURIComponent(this.threadId)}`; + const origin = typeof window !== "undefined" && window.location ? window.location.origin : "http://localhost"; + const base = new URL(this.runtimeUrl, origin); + const stopUrl = new URL(stopPath, base); + + void fetch(stopUrl.toString(), { + method: "POST", + headers: { + "Content-Type": "application/json", + ...this.headers, + }, + }).catch((error) => { + console.error("ProxiedCopilotRuntimeAgent: stop request failed", error); + }); + } + connect(input: RunAgentInput): Observable { - const httpEvents = runHttpRequest( - `${this.runtimeUrl}/agent/${this.agentId}/connect`, - this.requestInit(input) - ); + const httpEvents = runHttpRequest(`${this.runtimeUrl}/agent/${this.agentId}/connect`, this.requestInit(input)); return transformHttpEventStream(httpEvents); } } diff --git a/packages/core/src/core/core.ts b/packages/core/src/core/core.ts index e02c465b..dad2bcc6 100644 --- a/packages/core/src/core/core.ts +++ b/packages/core/src/core/core.ts @@ -28,7 +28,15 @@ export interface CopilotKitCoreConfig { } export type { CopilotKitCoreAddAgentParams }; -export type { CopilotKitCoreRunAgentParams, CopilotKitCoreConnectAgentParams, CopilotKitCoreGetToolParams }; +export type { + CopilotKitCoreRunAgentParams, + CopilotKitCoreConnectAgentParams, + CopilotKitCoreGetToolParams, +}; + +export interface CopilotKitCoreStopAgentParams { + agent: AbstractAgent; +} export type CopilotKitCoreGetSuggestionsResult = { suggestions: Suggestion[]; @@ -395,6 +403,10 @@ export class CopilotKitCore { return this.runHandler.connectAgent(params); } + stopAgent(params: CopilotKitCoreStopAgentParams): void { + params.agent.abortRun(); + } + async runAgent(params: CopilotKitCoreRunAgentParams): Promise { return this.runHandler.runAgent(params); } diff --git a/packages/react/src/components/chat/CopilotChat.tsx b/packages/react/src/components/chat/CopilotChat.tsx index 37c6588e..b91eae74 100644 --- a/packages/react/src/components/chat/CopilotChat.tsx +++ b/packages/react/src/components/chat/CopilotChat.tsx @@ -1,6 +1,7 @@ import { useAgent } from "@/hooks/use-agent"; import { useSuggestions } from "@/hooks/use-suggestions"; import { CopilotChatView, CopilotChatViewProps } from "./CopilotChatView"; +import CopilotChatInput, { CopilotChatInputProps } from "./CopilotChatInput"; import { CopilotChatConfigurationProvider, CopilotChatLabels, @@ -104,6 +105,23 @@ export function CopilotChat({ agentId, threadId, labels, chatView, isModalDefaul [agent, copilotkit], ); + const stopCurrentRun = useCallback(() => { + if (!agent) { + return; + } + + try { + copilotkit.stopAgent({ agent }); + } catch (error) { + console.error("CopilotChat: stopAgent failed", error); + try { + agent.abortRun(); + } catch (abortError) { + console.error("CopilotChat: abortRun fallback failed", abortError); + } + } + }, [agent, copilotkit]); + const mergedProps = merge( { isRunning: agent?.isRunning ?? false, @@ -121,12 +139,23 @@ export function CopilotChat({ agentId, threadId, labels, chatView, isModalDefaul }, ); + const providedStopHandler = providedInputProps?.onStop; + const hasMessages = (agent?.messages?.length ?? 0) > 0; + const shouldAllowStop = (agent?.isRunning ?? false) && hasMessages; + const effectiveStopHandler = shouldAllowStop ? providedStopHandler ?? stopCurrentRun : providedStopHandler; + + const finalInputProps = { + ...providedInputProps, + onSubmitMessage: onSubmitInput, + onStop: effectiveStopHandler, + isRunning: agent?.isRunning ?? false, + } as Partial & { onSubmitMessage: (value: string) => void }; + + finalInputProps.mode = agent?.isRunning ? "processing" : finalInputProps.mode ?? "input"; + const finalProps = merge(mergedProps, { messages: agent?.messages ?? [], - inputProps: { - onSubmitMessage: onSubmitInput, - ...providedInputProps, - }, + inputProps: finalInputProps, }) as CopilotChatViewProps; // Always create a provider with merged values diff --git a/packages/react/src/components/chat/CopilotChatInput.tsx b/packages/react/src/components/chat/CopilotChatInput.tsx index 56b6d79a..065316aa 100644 --- a/packages/react/src/components/chat/CopilotChatInput.tsx +++ b/packages/react/src/components/chat/CopilotChatInput.tsx @@ -11,7 +11,7 @@ import React, { useMemo, } from "react"; import { twMerge } from "tailwind-merge"; -import { Plus, Mic, ArrowUp, X, Check } from "lucide-react"; +import { Plus, Mic, ArrowUp, X, Check, Square } from "lucide-react"; import { CopilotChatLabels, @@ -64,6 +64,8 @@ type CopilotChatInputRestProps = { toolsMenu?: (ToolsMenuItem | "-")[]; autoFocus?: boolean; onSubmitMessage?: (value: string) => void; + onStop?: () => void; + isRunning?: boolean; onStartTranscribe?: () => void; onCancelTranscribe?: () => void; onFinishTranscribe?: () => void; @@ -90,6 +92,8 @@ const SLASH_MENU_ITEM_HEIGHT_PX = 40; export function CopilotChatInput({ mode = "input", onSubmitMessage, + onStop, + isRunning = false, onStartTranscribe, onCancelTranscribe, onFinishTranscribe, @@ -390,7 +394,11 @@ export function CopilotChatInput({ if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); - send(); + if (isProcessing) { + onStop?.(); + } else { + send(); + } } }; @@ -427,13 +435,26 @@ export function CopilotChatInput({ ), }); + const isProcessing = mode !== "transcribe" && isRunning; + const canSend = resolvedValue.trim().length > 0 && !!onSubmitMessage; + const canStop = !!onStop; + + const handleSendButtonClick = () => { + if (isProcessing) { + onStop?.(); + return; + } + send(); + }; + const BoundAudioRecorder = renderSlot(audioRecorder, CopilotChatAudioRecorder, { ref: audioRecorderRef, }); const BoundSendButton = renderSlot(sendButton, CopilotChatInput.SendButton, { - onClick: send, - disabled: !resolvedValue.trim() || !onSubmitMessage, + onClick: handleSendButtonClick, + disabled: isProcessing ? !canStop : !canSend, + children: isProcessing && canStop ? : undefined, }); const BoundStartTranscribeButton = renderSlot(startTranscribeButton, CopilotChatInput.StartTranscribeButton, { @@ -464,6 +485,8 @@ export function CopilotChatInput({ finishTranscribeButton: BoundFinishTranscribeButton, addMenuButton: BoundAddMenuButton, onSubmitMessage, + onStop, + isRunning, onStartTranscribe, onCancelTranscribe, onFinishTranscribe, @@ -833,7 +856,7 @@ export function CopilotChatInput({ // eslint-disable-next-line @typescript-eslint/no-namespace export namespace CopilotChatInput { - export const SendButton: React.FC> = ({ className, ...props }) => ( + export const SendButton: React.FC> = ({ className, children, ...props }) => (
); diff --git a/packages/runtime/src/__tests__/in-process-agent-runner-messages.test.ts b/packages/runtime/src/__tests__/in-process-agent-runner-messages.test.ts index 073713d8..16ea2e6f 100644 --- a/packages/runtime/src/__tests__/in-process-agent-runner-messages.test.ts +++ b/packages/runtime/src/__tests__/in-process-agent-runner-messages.test.ts @@ -96,18 +96,24 @@ describe("InMemoryAgentRunner – run started inputs", () => { runner.run({ threadId, agent, input }).pipe(toArray()), ); - expect(runEvents).toHaveLength(1); + expect(runEvents[0].type).toBe(EventType.RUN_STARTED); const runStarted = runEvents[0] as RunStartedEvent; - expect(runStarted.type).toBe(EventType.RUN_STARTED); expect(runStarted.input?.messages).toEqual(messages); + const terminalTypes = runEvents.slice(1).map((event) => event.type); + expect(terminalTypes.every((type) => type === EventType.RUN_ERROR || type === EventType.RUN_FINISHED)).toBe(true); + const connectEvents = await firstValueFrom( runner.connect({ threadId }).pipe(toArray()), ); - expect(connectEvents).toHaveLength(1); + expect(connectEvents[0].type).toBe(EventType.RUN_STARTED); const connectRunStarted = connectEvents[0] as RunStartedEvent; expect(connectRunStarted.input?.messages).toEqual(messages); + const connectTerminalTypes = connectEvents.slice(1).map((event) => event.type); + expect( + connectTerminalTypes.every((type) => type === EventType.RUN_ERROR || type === EventType.RUN_FINISHED), + ).toBe(true); }); it("only includes new messages on subsequent runs", async () => { @@ -149,8 +155,13 @@ describe("InMemoryAgentRunner – run started inputs", () => { .pipe(toArray()), ); + expect(secondRunEvents[0].type).toBe(EventType.RUN_STARTED); const runStarted = secondRunEvents[0] as RunStartedEvent; expect(runStarted.input?.messages).toEqual([newMessage]); + const secondTerminalTypes = secondRunEvents.slice(1).map((event) => event.type); + expect( + secondTerminalTypes.every((type) => type === EventType.RUN_ERROR || type === EventType.RUN_FINISHED), + ).toBe(true); const connectEvents = await firstValueFrom( runner.connect({ threadId }).pipe(toArray()), diff --git a/packages/runtime/src/__tests__/in-process-agent-runner.test.ts b/packages/runtime/src/__tests__/in-process-agent-runner.test.ts index eb4cbbda..fd4d0e6b 100644 --- a/packages/runtime/src/__tests__/in-process-agent-runner.test.ts +++ b/packages/runtime/src/__tests__/in-process-agent-runner.test.ts @@ -1,9 +1,14 @@ import { describe, it, expect, beforeEach } from "vitest"; import { InMemoryAgentRunner } from "../runner/in-memory"; -import { AbstractAgent, BaseEvent, RunAgentInput } from "@ag-ui/client"; +import { AbstractAgent, BaseEvent, EventType, RunAgentInput } from "@ag-ui/client"; import { firstValueFrom } from "rxjs"; import { toArray } from "rxjs/operators"; +const stripTerminalEvents = (events: BaseEvent[]) => + events.filter( + (event) => event.type !== EventType.RUN_FINISHED && event.type !== EventType.RUN_ERROR, + ); + // Mock agent implementations for testing class MockAgent extends AbstractAgent { private events: BaseEvent[]; @@ -98,6 +103,94 @@ class ErrorThrowingAgent extends AbstractAgent { } } +class StoppableAgent extends AbstractAgent { + private shouldStop = false; + private eventDelay: number; + + constructor(eventDelay = 5) { + super(); + this.eventDelay = eventDelay; + } + + async runAgent( + input: RunAgentInput, + options: { onEvent: (event: { event: BaseEvent }) => void } + ): Promise { + this.shouldStop = false; + let counter = 0; + + while (!this.shouldStop && counter < 10_000) { + await new Promise((resolve) => setTimeout(resolve, this.eventDelay)); + const event: BaseEvent = { + type: "message", + id: `stoppable-${counter}`, + timestamp: new Date().toISOString(), + data: { counter }, + } as BaseEvent; + options.onEvent({ event }); + counter += 1; + } + } + + abortRun(): void { + this.shouldStop = true; + } + + clone(): AbstractAgent { + return new StoppableAgent(this.eventDelay); + } +} + +class OpenEventsAgent extends AbstractAgent { + private shouldStop = false; + + async runAgent( + input: RunAgentInput, + options: { onEvent: (event: { event: BaseEvent }) => void } + ): Promise { + this.shouldStop = false; + const messageId = "open-message"; + const toolCallId = "open-tool"; + + options.onEvent({ + event: { + type: EventType.TEXT_MESSAGE_START, + messageId, + role: "assistant", + } as BaseEvent, + }); + + options.onEvent({ + event: { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId, + delta: "Partial content", + } as BaseEvent, + }); + + options.onEvent({ + event: { + type: EventType.TOOL_CALL_START, + toolCallId, + toolCallName: "testTool", + parentMessageId: messageId, + } as BaseEvent, + }); + + while (!this.shouldStop) { + await new Promise((resolve) => setTimeout(resolve, 5)); + } + } + + abortRun(): void { + this.shouldStop = true; + } + + clone(): AbstractAgent { + return new OpenEventsAgent(); + } +} + class MultiEventAgent extends AbstractAgent { private runId: string; @@ -161,8 +254,8 @@ describe("InMemoryAgentRunner", () => { const runObservable = runner.run({ threadId, agent, input }); const collectedEvents = await firstValueFrom(runObservable.pipe(toArray())); - expect(collectedEvents).toHaveLength(3); - expect(collectedEvents).toEqual(events); + const agentEvents = stripTerminalEvents(collectedEvents); + expect(agentEvents).toEqual(events); }); it("should allow connecting after run completes and receive all past events", async () => { @@ -188,8 +281,8 @@ describe("InMemoryAgentRunner", () => { const connectObservable = runner.connect({ threadId }); const collectedEvents = await firstValueFrom(connectObservable.pipe(toArray())); - expect(collectedEvents).toHaveLength(2); - expect(collectedEvents).toEqual(events); + const storedAgentEvents = stripTerminalEvents(collectedEvents); + expect(storedAgentEvents).toEqual(events); }); }); @@ -237,22 +330,23 @@ describe("InMemoryAgentRunner", () => { const connectObservable = runner.connect({ threadId }); const allEvents = await firstValueFrom(connectObservable.pipe(toArray())); - expect(allEvents).toHaveLength(15); // 5 events per run × 3 runs - + const agentEvents = stripTerminalEvents(allEvents); + expect(agentEvents).toHaveLength(15); // 5 events per run × 3 runs + // Verify events from all runs are present - const run1Events = allEvents.filter(e => e.id?.startsWith("run-1")); - const run2Events = allEvents.filter(e => e.id?.startsWith("run-2")); - const run3Events = allEvents.filter(e => e.id?.startsWith("run-3")); + const run1Events = agentEvents.filter((e) => e.id?.startsWith("run-1")); + const run2Events = agentEvents.filter((e) => e.id?.startsWith("run-2")); + const run3Events = agentEvents.filter((e) => e.id?.startsWith("run-3")); expect(run1Events).toHaveLength(5); expect(run2Events).toHaveLength(5); expect(run3Events).toHaveLength(5); // Verify order preservation - const runOrder = allEvents.map(e => e.id?.split("-")[0] + "-" + e.id?.split("-")[1]); - expect(runOrder.slice(0, 5).every(id => id.startsWith("run-1"))).toBe(true); - expect(runOrder.slice(5, 10).every(id => id.startsWith("run-2"))).toBe(true); - expect(runOrder.slice(10, 15).every(id => id.startsWith("run-3"))).toBe(true); + const runOrder = agentEvents.map((e) => e.id?.split("-")[0] + "-" + e.id?.split("-")[1]); + expect(runOrder.slice(0, 5).every((id) => id?.startsWith("run-1"))).toBe(true); + expect(runOrder.slice(5, 10).every((id) => id?.startsWith("run-2"))).toBe(true); + expect(runOrder.slice(10, 15).every((id) => id?.startsWith("run-3"))).toBe(true); }); it("should handle connect during multiple runs", async () => { @@ -283,8 +377,9 @@ describe("InMemoryAgentRunner", () => { const allEvents = await eventCollector; // Connect only receives events from the first run since it completes - expect(allEvents).toHaveLength(5); // Only events from first run - const firstRunEvents = allEvents.filter(e => e.id?.startsWith("first")); + const firstRunAgentEvents = stripTerminalEvents(allEvents); + expect(firstRunAgentEvents).toHaveLength(5); + const firstRunEvents = firstRunAgentEvents.filter((e) => e.id?.startsWith("first")); expect(firstRunEvents).toHaveLength(5); @@ -302,7 +397,7 @@ describe("InMemoryAgentRunner", () => { // Connect after both runs to verify all events are accumulated const allEventsAfter = await firstValueFrom(runner.connect({ threadId }).pipe(toArray())); - expect(allEventsAfter).toHaveLength(8); // 5 from first + 3 from second + expect(stripTerminalEvents(allEventsAfter)).toHaveLength(8); // 5 from first + 3 from second }); it("should preserve event order across different agent types", async () => { @@ -334,13 +429,14 @@ describe("InMemoryAgentRunner", () => { const connectObservable = runner.connect({ threadId }); const allEvents = await firstValueFrom(connectObservable.pipe(toArray())); - expect(allEvents).toHaveLength(9); // 2 + 5 + 2 - + const agentEvents = stripTerminalEvents(allEvents); + expect(agentEvents).toHaveLength(9); // 2 + 5 + 2 + // Verify event groups are in order - expect(allEvents[0].id).toBe("mock-1"); - expect(allEvents[1].id).toBe("mock-2"); - expect(allEvents[2].id).toContain("multi"); - expect(allEvents[7].id).toContain("delayed"); + expect(agentEvents[0].id).toBe("mock-1"); + expect(agentEvents[1].id).toBe("mock-2"); + expect(agentEvents[2].id).toContain("multi"); + expect(agentEvents[7].id).toContain("delayed"); }); }); @@ -370,10 +466,13 @@ describe("InMemoryAgentRunner", () => { firstValueFrom(connect3.pipe(toArray())), ]); - // All should receive same events - expect(events1).toHaveLength(5); - expect(events2).toHaveLength(5); - expect(events3).toHaveLength(5); + // All should receive same events including RUN_FINISHED + const agentEvents1 = stripTerminalEvents(events1); + const agentEvents2 = stripTerminalEvents(events2); + const agentEvents3 = stripTerminalEvents(events3); + expect(agentEvents1).toHaveLength(5); + expect(agentEvents2).toHaveLength(5); + expect(agentEvents3).toHaveLength(5); expect(events1).toEqual(events2); expect(events2).toEqual(events3); }); @@ -407,10 +506,13 @@ describe("InMemoryAgentRunner", () => { firstValueFrom(connect3.pipe(toArray())), ]); - // All subscribers should eventually receive all events - expect(events1).toHaveLength(10); - expect(events2).toHaveLength(10); - expect(events3).toHaveLength(10); + // All subscribers should eventually receive all events plus RUN_FINISHED + const agentEvents1 = stripTerminalEvents(events1); + const agentEvents2 = stripTerminalEvents(events2); + const agentEvents3 = stripTerminalEvents(events3); + expect(agentEvents1).toHaveLength(10); + expect(agentEvents2).toHaveLength(10); + expect(agentEvents3).toHaveLength(10); // Verify they all have the same events expect(events1.map(e => e.id)).toEqual(events2.map(e => e.id)); @@ -452,8 +554,10 @@ describe("InMemoryAgentRunner", () => { const events = await firstValueFrom(runObservable.pipe(toArray())); // Should still receive events emitted before error - expect(events).toHaveLength(3); - expect(events.every(e => e.id?.startsWith("error-agent"))).toBe(true); + expect(events.at(-1)?.type).toBe(EventType.RUN_ERROR); + const preErrorEvents = events.slice(0, -1); + expect(preErrorEvents).toHaveLength(3); + expect(preErrorEvents.every((e) => e.id?.startsWith("error-agent"))).toBe(true); // Should be able to run again after error const agent2 = new MockAgent([ @@ -470,12 +574,15 @@ describe("InMemoryAgentRunner", () => { const run2 = runner.run({ threadId, agent: agent2, input: input2 }); const events2 = await firstValueFrom(run2.pipe(toArray())); - expect(events2).toHaveLength(1); // Only events from current run - expect(events2[0].id).toBe("recovery-1"); + const recoveryEvents = stripTerminalEvents(events2); + expect(recoveryEvents).toHaveLength(1); // Only events from current run + expect(recoveryEvents[0].id).toBe("recovery-1"); // Connect should have all events including from errored run const allEvents = await firstValueFrom(runner.connect({ threadId }).pipe(toArray())); - expect(allEvents).toHaveLength(4); // 3 from error run + 1 from recovery + expect(allEvents.filter((event) => event.type === EventType.RUN_ERROR).length).toBeGreaterThanOrEqual(1); + const storedAgentEvents = stripTerminalEvents(allEvents); + expect(storedAgentEvents).toHaveLength(4); // 3 from error run + 1 from recovery }); it("should properly set isRunning to false after agent error", async () => { @@ -566,9 +673,10 @@ describe("InMemoryAgentRunner", () => { const connectObservable = runner.connect({ threadId }); const collectedEvents = await firstValueFrom(connectObservable.pipe(toArray())); - expect(collectedEvents).toHaveLength(eventCount); - expect(collectedEvents[0].id).toBe("bulk-0"); - expect(collectedEvents[eventCount - 1].id).toBe(`bulk-${eventCount - 1}`); + const bulkEvents = stripTerminalEvents(collectedEvents); + expect(bulkEvents).toHaveLength(eventCount); + expect(bulkEvents[0].id).toBe("bulk-0"); + expect(bulkEvents[eventCount - 1].id).toBe(`bulk-${eventCount - 1}`); }); it("should return false for isRunning on non-existent thread", async () => { @@ -602,10 +710,60 @@ describe("InMemoryAgentRunner", () => { expect(await runner.isRunning({ threadId })).toBe(false); }); - it("should throw error for stop method (not implemented)", async () => { - expect(() => { - runner.stop({ threadId: "any-thread" }); - }).toThrow("Method not implemented"); + it("should return false when stopping a non-existent thread", async () => { + await expect(runner.stop({ threadId: "missing-thread" })).resolves.toBe(false); + }); + + it("should stop an active run and complete streams", async () => { + const threadId = "test-thread-stop"; + const agent = new StoppableAgent(2); + const input: RunAgentInput = { + messages: [], + state: {}, + threadId, + runId: "run-stop", + }; + + const run$ = runner.run({ threadId, agent, input }); + const collected = firstValueFrom(run$.pipe(toArray())); + + // Allow the run loop to start and emit a couple of events + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(await runner.isRunning({ threadId })).toBe(true); + + const stopped = await runner.stop({ threadId }); + expect(stopped).toBe(true); + + const events = await collected; + expect(events.length).toBeGreaterThan(0); + expect(events[events.length - 1].type).toBe(EventType.RUN_FINISHED); + expect(await runner.isRunning({ threadId })).toBe(false); + }); + + it("should close open text and tool events when stopping", async () => { + const threadId = "test-thread-open-events"; + const agent = new OpenEventsAgent(); + const input: RunAgentInput = { + messages: [], + state: {}, + threadId, + runId: "run-open", + }; + + const run$ = runner.run({ threadId, agent, input }); + const collected = firstValueFrom(run$.pipe(toArray())); + + await new Promise((resolve) => setTimeout(resolve, 20)); + await runner.stop({ threadId }); + + const events = await collected; + const endingTypes = events.slice(-4).map((event) => event.type); + expect(endingTypes).toEqual([ + EventType.TEXT_MESSAGE_END, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + EventType.RUN_FINISHED, + ]); }); it("should handle thread isolation correctly", async () => { @@ -641,11 +799,13 @@ describe("InMemoryAgentRunner", () => { const events2 = await firstValueFrom(runner.connect({ threadId: thread2 }).pipe(toArray())); // Verify isolation - expect(events1).toHaveLength(1); - expect(events1[0].id).toBe("t1-event"); + const thread1Events = stripTerminalEvents(events1); + const thread2Events = stripTerminalEvents(events2); + expect(thread1Events).toHaveLength(1); + expect(thread1Events[0].id).toBe("t1-event"); - expect(events2).toHaveLength(1); - expect(events2[0].id).toBe("t2-event"); + expect(thread2Events).toHaveLength(1); + expect(thread2Events[0].id).toBe("t2-event"); }); }); @@ -673,15 +833,16 @@ describe("InMemoryAgentRunner", () => { } const allEvents = await firstValueFrom(runner.connect({ threadId }).pipe(toArray())); - - expect(allEvents).toHaveLength(12); // 1 + 3 + 1 + 5 + 2 - + + const agentEvents = stripTerminalEvents(allEvents); + expect(agentEvents).toHaveLength(12); // 1 + 3 + 1 + 5 + 2 + // Verify event ordering - expect(allEvents[0].id).toBe("instant-1"); - expect(allEvents[1].id).toContain("delayed-0"); - expect(allEvents[4].id).toBe("instant-2"); - expect(allEvents[5].id).toContain("multi-start"); - expect(allEvents[10].id).toContain("slow-0"); + expect(agentEvents[0].id).toBe("instant-1"); + expect(agentEvents[1].id).toContain("delayed-0"); + expect(agentEvents[4].id).toBe("instant-2"); + expect(agentEvents[5].id).toContain("multi-start"); + expect(agentEvents[10].id).toContain("slow-0"); }); it("should handle subscriber that connects between runs", async () => { @@ -700,8 +861,9 @@ describe("InMemoryAgentRunner", () => { const midConnectObservable = runner.connect({ threadId }); const midEvents = await firstValueFrom(midConnectObservable.pipe(toArray())); - expect(midEvents).toHaveLength(5); // Only events from first run - const firstRunEvents = midEvents.filter(e => e.id?.includes("first")); + const midAgentEvents = stripTerminalEvents(midEvents); + expect(midAgentEvents).toHaveLength(5); // Only events from first run + const firstRunEvents = midAgentEvents.filter((e) => e.id?.includes("first")); expect(firstRunEvents).toHaveLength(5); // Second run @@ -715,12 +877,13 @@ describe("InMemoryAgentRunner", () => { // Connect after both runs to verify all events const allEvents = await firstValueFrom(runner.connect({ threadId }).pipe(toArray())); - expect(allEvents).toHaveLength(10); // Events from both runs - - const allFirstRunEvents = allEvents.filter(e => e.id?.includes("first")); - const allSecondRunEvents = allEvents.filter(e => e.id?.includes("second")); + const allAgentEvents = stripTerminalEvents(allEvents); + expect(allAgentEvents).toHaveLength(10); // Events from both runs + + const allFirstRunEvents = allAgentEvents.filter((e) => e.id?.includes("first")); + const allSecondRunEvents = allAgentEvents.filter((e) => e.id?.includes("second")); expect(allFirstRunEvents).toHaveLength(5); expect(allSecondRunEvents).toHaveLength(5); }); }); -}); \ No newline at end of file +}); diff --git a/packages/runtime/src/endpoint.ts b/packages/runtime/src/endpoint.ts index d0f456de..b283f5d3 100644 --- a/packages/runtime/src/endpoint.ts +++ b/packages/runtime/src/endpoint.ts @@ -5,11 +5,9 @@ import { handleRunAgent } from "./handlers/handle-run"; import { handleGetRuntimeInfo } from "./handlers/get-runtime-info"; import { handleTranscribe } from "./handlers/handle-transcribe"; import { logger } from "@copilotkitnext/shared"; -import { - callBeforeRequestMiddleware, - callAfterRequestMiddleware, -} from "./middleware"; +import { callBeforeRequestMiddleware, callAfterRequestMiddleware } from "./middleware"; import { handleConnectAgent } from "./handlers/handle-connect"; +import { handleStopAgent } from "./handlers/handle-stop"; interface CopilotEndpointParams { runtime: CopilotRuntime; @@ -23,10 +21,7 @@ type CopilotEndpointContext = { }; }; -export function createCopilotEndpoint({ - runtime, - basePath, -}: CopilotEndpointParams) { +export function createCopilotEndpoint({ runtime, basePath }: CopilotEndpointParams) { const app = new Hono(); return app @@ -37,7 +32,7 @@ export function createCopilotEndpoint({ origin: "*", allowMethods: ["GET", "HEAD", "PUT", "POST", "DELETE", "PATCH", "OPTIONS"], allowHeaders: ["*"], - }) + }), ) .use("*", async (c, next) => { const request = c.req.raw; @@ -53,10 +48,7 @@ export function createCopilotEndpoint({ c.set("modifiedRequest", maybeModifiedRequest); } } catch (error) { - logger.error( - { err: error, url: request.url, path }, - "Error running before request middleware" - ); + logger.error({ err: error, url: request.url, path }, "Error running before request middleware"); if (error instanceof Response) { return error; } @@ -77,10 +69,7 @@ export function createCopilotEndpoint({ response, path, }).catch((error) => { - logger.error( - { err: error, url: c.req.url, path }, - "Error running after request middleware" - ); + logger.error({ err: error, url: c.req.url, path }, "Error running after request middleware"); }); }) .post("/agent/:agentId/run", async (c) => { @@ -94,10 +83,7 @@ export function createCopilotEndpoint({ agentId, }); } catch (error) { - logger.error( - { err: error, url: request.url, path: c.req.path }, - "Error running request handler" - ); + logger.error({ err: error, url: request.url, path: c.req.path }, "Error running request handler"); throw error; } }) @@ -112,10 +98,25 @@ export function createCopilotEndpoint({ agentId, }); } catch (error) { - logger.error( - { err: error, url: request.url, path: c.req.path }, - "Error running request handler" - ); + logger.error({ err: error, url: request.url, path: c.req.path }, "Error running request handler"); + throw error; + } + }) + + .post("/agent/:agentId/stop/:threadId", async (c) => { + const agentId = c.req.param("agentId"); + const threadId = c.req.param("threadId"); + const request = c.get("modifiedRequest") || c.req.raw; + + try { + return await handleStopAgent({ + runtime, + request, + agentId, + threadId, + }); + } catch (error) { + logger.error({ err: error, url: request.url, path: c.req.path }, "Error running request handler"); throw error; } }) @@ -128,10 +129,7 @@ export function createCopilotEndpoint({ request, }); } catch (error) { - logger.error( - { err: error, url: request.url, path: c.req.path }, - "Error running request handler" - ); + logger.error({ err: error, url: request.url, path: c.req.path }, "Error running request handler"); throw error; } }) @@ -144,10 +142,7 @@ export function createCopilotEndpoint({ request, }); } catch (error) { - logger.error( - { err: error, url: request.url, path: c.req.path }, - "Error running request handler" - ); + logger.error({ err: error, url: request.url, path: c.req.path }, "Error running request handler"); throw error; } }) diff --git a/packages/runtime/src/handlers/handle-stop.ts b/packages/runtime/src/handlers/handle-stop.ts new file mode 100644 index 00000000..4a6bd507 --- /dev/null +++ b/packages/runtime/src/handlers/handle-stop.ts @@ -0,0 +1,76 @@ +import { CopilotRuntime } from "../runtime"; +import { EventType } from "@ag-ui/client"; + +interface StopAgentParameters { + request: Request; + runtime: CopilotRuntime; + agentId: string; + threadId: string; +} + +export async function handleStopAgent({ + runtime, + request, + agentId, + threadId, +}: StopAgentParameters) { + try { + const agents = await runtime.agents; + + if (!agents[agentId]) { + return new Response( + JSON.stringify({ + error: "Agent not found", + message: `Agent '${agentId}' does not exist`, + }), + { + status: 404, + headers: { "Content-Type": "application/json" }, + } + ); + } + + const stopped = await runtime.runner.stop({ threadId }); + + if (!stopped) { + return new Response( + JSON.stringify({ + stopped: false, + message: `No active run for thread '${threadId}'.`, + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + } + ); + } + + return new Response( + JSON.stringify({ + stopped: true, + interrupt: { + type: EventType.RUN_ERROR, + message: "Run stopped by user", + code: "STOPPED", + }, + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + } + ); + } catch (error) { + console.error("Error stopping agent run:", error); + + return new Response( + JSON.stringify({ + error: "Failed to stop agent", + message: error instanceof Error ? error.message : "Unknown error", + }), + { + status: 500, + headers: { "Content-Type": "application/json" }, + } + ); + } +} diff --git a/packages/runtime/src/runner/__tests__/finalize-events.test.ts b/packages/runtime/src/runner/__tests__/finalize-events.test.ts new file mode 100644 index 00000000..256772f2 --- /dev/null +++ b/packages/runtime/src/runner/__tests__/finalize-events.test.ts @@ -0,0 +1,99 @@ +import { describe, expect, it } from "vitest"; +import { + BaseEvent, + EventType, + ToolCallResultEvent, + RunErrorEvent, +} from "@ag-ui/client"; +import { finalizeRunEvents } from "../finalize-events"; + +const createTextStart = (messageId: string): BaseEvent => ({ + type: EventType.TEXT_MESSAGE_START, + messageId, +} as BaseEvent); + +const createToolStart = (toolCallId: string): BaseEvent => ({ + type: EventType.TOOL_CALL_START, + toolCallId, +} as BaseEvent); + +describe("finalizeRunEvents", () => { + it("closes streams with a RUN_FINISHED event when a stop was requested", () => { + const events: BaseEvent[] = [ + createTextStart("msg-1"), + createToolStart("tool-1"), + ]; + + const appended = finalizeRunEvents(events, { stopRequested: true }); + + expect(appended.map((event) => event.type)).toEqual([ + EventType.TEXT_MESSAGE_END, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + EventType.RUN_FINISHED, + ]); + + const resultEvent = appended.find( + (event): event is ToolCallResultEvent => event.type === EventType.TOOL_CALL_RESULT, + ); + expect(JSON.parse(resultEvent?.content ?? "")).toEqual( + expect.objectContaining({ + status: "stopped", + reason: "stop_requested", + }), + ); + + expect(events.at(-1)?.type).toBe(EventType.RUN_FINISHED); + }); + + it("emits a RUN_ERROR with meaningful payload when the stream ends abruptly", () => { + const events: BaseEvent[] = [ + createTextStart("msg-1"), + createToolStart("tool-1"), + ]; + + const appended = finalizeRunEvents(events); + + expect(appended.map((event) => event.type)).toEqual([ + EventType.TEXT_MESSAGE_END, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + EventType.RUN_ERROR, + ]); + + const resultEvent = appended.find( + (event): event is ToolCallResultEvent => event.type === EventType.TOOL_CALL_RESULT, + ); + expect(JSON.parse(resultEvent?.content ?? "")).toEqual( + expect.objectContaining({ + status: "error", + reason: "missing_terminal_event", + }), + ); + + const errorEvent = appended.find( + (event): event is RunErrorEvent => event.type === EventType.RUN_ERROR, + ); + expect(errorEvent?.code).toBe("INCOMPLETE_STREAM"); + expect(errorEvent?.message).toContain("terminal event"); + }); + + it("only appends structural fixes when a terminal event already exists", () => { + const events: BaseEvent[] = [ + createTextStart("msg-1"), + createToolStart("tool-1"), + { type: EventType.RUN_FINISHED } as BaseEvent, + ]; + + const appended = finalizeRunEvents(events); + + expect(appended.map((event) => event.type)).toEqual([ + EventType.TEXT_MESSAGE_END, + EventType.TOOL_CALL_END, + ]); + + expect(appended.some((event) => event.type === EventType.TOOL_CALL_RESULT)).toBe(false); + expect(appended.some((event) => event.type === EventType.RUN_ERROR)).toBe(false); + expect(appended.some((event) => event.type === EventType.RUN_FINISHED)).toBe(false); + }); +}); diff --git a/packages/runtime/src/runner/__tests__/in-memory-runner.test.ts b/packages/runtime/src/runner/__tests__/in-memory-runner.test.ts index d63d302f..f6355bd7 100644 --- a/packages/runtime/src/runner/__tests__/in-memory-runner.test.ts +++ b/packages/runtime/src/runner/__tests__/in-memory-runner.test.ts @@ -15,6 +15,11 @@ import { import { EMPTY, firstValueFrom } from "rxjs"; import { toArray } from "rxjs/operators"; +const stripTerminalEvents = (events: BaseEvent[]) => + events.filter( + (event) => event.type !== EventType.RUN_FINISHED && event.type !== EventType.RUN_ERROR, + ); + class TestAgent extends AbstractAgent { constructor( private readonly events: BaseEvent[] = [], @@ -96,9 +101,10 @@ describe("InMemoryAgentRunner", () => { .pipe(toArray()), ); - expect(events).toHaveLength(4); - expect(events[0].type).toBe(EventType.RUN_STARTED); - const compacted = events.slice(1); + const nonTerminalEvents = stripTerminalEvents(events); + expect(nonTerminalEvents).toHaveLength(4); + expect(nonTerminalEvents[0].type).toBe(EventType.RUN_STARTED); + const compacted = nonTerminalEvents.slice(1); expect(compacted[0].type).toBe(EventType.TEXT_MESSAGE_START); expect(compacted[1].type).toBe(EventType.TEXT_MESSAGE_CONTENT); expect((compacted[1] as TextMessageContentEvent).delta).toBe("Hello"); @@ -185,8 +191,9 @@ describe("InMemoryAgentRunner", () => { .pipe(toArray()), ); - expect(events).toHaveLength(1); - const runStarted = events[0] as RunStartedEvent; + const nonTerminalEvents = stripTerminalEvents(events); + expect(nonTerminalEvents).toHaveLength(1); + const runStarted = nonTerminalEvents[0] as RunStartedEvent; expect(runStarted.input).toBe(providedInput); }); }); @@ -225,9 +232,10 @@ describe("InMemoryAgentRunner", () => { runner.connect({ threadId }).pipe(toArray()), ); - expect(connectEvents).toHaveLength(4); - expect(connectEvents[0].type).toBe(EventType.RUN_STARTED); - expect(connectEvents.slice(1).map((event) => event.type)).toEqual([ + const nonTerminalEvents = stripTerminalEvents(connectEvents); + expect(nonTerminalEvents).toHaveLength(4); + expect(nonTerminalEvents[0].type).toBe(EventType.RUN_STARTED); + expect(nonTerminalEvents.slice(1).map((event) => event.type)).toEqual([ EventType.TEXT_MESSAGE_START, EventType.TEXT_MESSAGE_CONTENT, EventType.TEXT_MESSAGE_END, @@ -256,8 +264,9 @@ describe("InMemoryAgentRunner", () => { .pipe(toArray()), ); - expect(events).toHaveLength(2); - const [, toolResult] = events; + const nonTerminalEvents = stripTerminalEvents(events); + expect(nonTerminalEvents).toHaveLength(2); + const [, toolResult] = nonTerminalEvents; expect(toolResult.type).toBe(EventType.TOOL_CALL_RESULT); }); }); diff --git a/packages/runtime/src/runner/finalize-events.ts b/packages/runtime/src/runner/finalize-events.ts new file mode 100644 index 00000000..af3f7da9 --- /dev/null +++ b/packages/runtime/src/runner/finalize-events.ts @@ -0,0 +1,154 @@ +import { randomUUID } from "node:crypto"; +import { + BaseEvent, + EventType, + RunErrorEvent, +} from "@ag-ui/client"; + +interface FinalizeRunOptions { + stopRequested?: boolean; + interruptionMessage?: string; +} + +const defaultStopMessage = "Run stopped by user"; +const defaultAbruptEndMessage = "Run ended without emitting a terminal event"; + +export function finalizeRunEvents( + events: BaseEvent[], + options: FinalizeRunOptions = {}, +): BaseEvent[] { + const { stopRequested = false, interruptionMessage } = options; + + const resolvedStopMessage = interruptionMessage ?? defaultStopMessage; + const resolvedAbruptMessage = + interruptionMessage && interruptionMessage !== defaultStopMessage + ? interruptionMessage + : defaultAbruptEndMessage; + + const appended: BaseEvent[] = []; + + const openMessageIds = new Set(); + const openToolCalls = new Map< + string, + { + hasEnd: boolean; + hasResult: boolean; + } + >(); + + for (const event of events) { + switch (event.type) { + case EventType.TEXT_MESSAGE_START: { + const messageId = (event as { messageId?: string }).messageId; + if (typeof messageId === "string") { + openMessageIds.add(messageId); + } + break; + } + case EventType.TEXT_MESSAGE_END: { + const messageId = (event as { messageId?: string }).messageId; + if (typeof messageId === "string") { + openMessageIds.delete(messageId); + } + break; + } + case EventType.TOOL_CALL_START: { + const toolCallId = (event as { toolCallId?: string }).toolCallId; + if (typeof toolCallId === "string") { + openToolCalls.set(toolCallId, { + hasEnd: false, + hasResult: false, + }); + } + break; + } + case EventType.TOOL_CALL_END: { + const toolCallId = (event as { toolCallId?: string }).toolCallId; + const info = toolCallId ? openToolCalls.get(toolCallId) : undefined; + if (info) { + info.hasEnd = true; + } + break; + } + case EventType.TOOL_CALL_RESULT: { + const toolCallId = (event as { toolCallId?: string }).toolCallId; + const info = toolCallId ? openToolCalls.get(toolCallId) : undefined; + if (info) { + info.hasResult = true; + } + break; + } + default: + break; + } + } + + const hasRunFinished = events.some((event) => event.type === EventType.RUN_FINISHED); + const hasRunError = events.some((event) => event.type === EventType.RUN_ERROR); + const hasTerminalEvent = hasRunFinished || hasRunError; + const terminalEventMissing = !hasTerminalEvent; + + for (const messageId of openMessageIds) { + const endEvent = { + type: EventType.TEXT_MESSAGE_END, + messageId, + } as BaseEvent; + events.push(endEvent); + appended.push(endEvent); + } + + for (const [toolCallId, info] of openToolCalls) { + if (!info.hasEnd) { + const endEvent = { + type: EventType.TOOL_CALL_END, + toolCallId, + } as BaseEvent; + events.push(endEvent); + appended.push(endEvent); + } + + if (terminalEventMissing && !info.hasResult) { + const resultEvent = { + type: EventType.TOOL_CALL_RESULT, + toolCallId, + messageId: `${toolCallId ?? randomUUID()}-result`, + role: "tool", + content: JSON.stringify( + stopRequested + ? { + status: "stopped", + reason: "stop_requested", + message: resolvedStopMessage, + } + : { + status: "error", + reason: "missing_terminal_event", + message: resolvedAbruptMessage, + }, + ), + } as BaseEvent; + events.push(resultEvent); + appended.push(resultEvent); + } + } + + if (terminalEventMissing) { + if (stopRequested) { + const finishedEvent = { + type: EventType.RUN_FINISHED, + } as BaseEvent; + events.push(finishedEvent); + appended.push(finishedEvent); + } else { + const errorEvent: RunErrorEvent = { + type: EventType.RUN_ERROR, + message: resolvedAbruptMessage, + code: "INCOMPLETE_STREAM", + }; + events.push(errorEvent); + appended.push(errorEvent); + } + } + + return appended; +} diff --git a/packages/runtime/src/runner/in-memory.ts b/packages/runtime/src/runner/in-memory.ts index 6625bc1a..2c8f02b4 100644 --- a/packages/runtime/src/runner/in-memory.ts +++ b/packages/runtime/src/runner/in-memory.ts @@ -7,12 +7,14 @@ import { } from "./agent-runner"; import { Observable, ReplaySubject } from "rxjs"; import { + AbstractAgent, BaseEvent, EventType, MessagesSnapshotEvent, RunStartedEvent, compactEvents, } from "@ag-ui/client"; +import { finalizeRunEvents } from "./finalize-events"; interface HistoricRun { threadId: string; @@ -31,14 +33,23 @@ class InMemoryEventStore { /** True while a run is actively producing events. */ isRunning = false; - /** Lets stop() cancel the current producer. */ - abortController = new AbortController(); - /** Current run ID */ currentRunId: string | null = null; /** Historic completed runs */ historicRuns: HistoricRun[] = []; + + /** Currently running agent instance (if any). */ + agent: AbstractAgent | null = null; + + /** Subject returned from run() while the run is active. */ + runSubject: ReplaySubject | null = null; + + /** True once stop() has been requested but the run has not yet finalized. */ + stopRequested = false; + + /** Reference to the events emitted in the current run. */ + currentEvents: BaseEvent[] | null = null; } const GLOBAL_STORE = new Map(); @@ -57,10 +68,13 @@ export class InMemoryAgentRunner extends AgentRunner { } store.isRunning = true; store.currentRunId = request.input.runId; + store.agent = request.agent; + store.stopRequested = false; // Track seen message IDs and current run events for this run const seenMessageIds = new Set(); const currentRunEvents: BaseEvent[] = []; + store.currentEvents = currentRunEvents; // Get all previously seen message IDs from historic runs const historicMessageIds = new Set(); @@ -84,10 +98,10 @@ export class InMemoryAgentRunner extends AgentRunner { // Update the store's subject immediately store.subject = nextSubject; - store.abortController = new AbortController(); // Create a subject for run() return value const runSubject = new ReplaySubject(Infinity); + store.runSubject = runSubject; // Helper function to run the agent and handle errors const runAgent = async () => { @@ -142,6 +156,14 @@ export class InMemoryAgentRunner extends AgentRunner { }, }); + const appendedEvents = finalizeRunEvents(currentRunEvents, { + stopRequested: store.stopRequested, + }); + for (const event of appendedEvents) { + runSubject.next(event); + nextSubject.next(event); + } + // Store the completed run in memory with ONLY its events if (store.currentRunId) { // Compact the events before storing (like SQLite does) @@ -157,11 +179,23 @@ export class InMemoryAgentRunner extends AgentRunner { } // Complete the run - store.isRunning = false; + store.currentEvents = null; store.currentRunId = null; + store.agent = null; + store.runSubject = null; + store.stopRequested = false; + store.isRunning = false; runSubject.complete(); nextSubject.complete(); } catch { + const appendedEvents = finalizeRunEvents(currentRunEvents, { + stopRequested: store.stopRequested, + }); + for (const event of appendedEvents) { + runSubject.next(event); + nextSubject.next(event); + } + // Store the run even if it failed (partial events) if (store.currentRunId && currentRunEvents.length > 0) { // Compact the events before storing (like SQLite does) @@ -176,8 +210,12 @@ export class InMemoryAgentRunner extends AgentRunner { } // Complete the run - store.isRunning = false; + store.currentEvents = null; store.currentRunId = null; + store.agent = null; + store.runSubject = null; + store.stopRequested = false; + store.isRunning = false; runSubject.complete(); nextSubject.complete(); } @@ -230,7 +268,7 @@ export class InMemoryAgentRunner extends AgentRunner { } // Bridge active run to connection if exists - if (store.subject && store.isRunning) { + if (store.subject && (store.isRunning || store.stopRequested)) { store.subject.subscribe({ next: (event) => { // Skip message events that we've already emitted from historic @@ -259,8 +297,33 @@ export class InMemoryAgentRunner extends AgentRunner { return Promise.resolve(store?.isRunning ?? false); } - // eslint-disable-next-line @typescript-eslint/no-unused-vars - stop(_request: AgentRunnerStopRequest): Promise { - throw new Error("Method not implemented."); + stop(request: AgentRunnerStopRequest): Promise { + const store = GLOBAL_STORE.get(request.threadId); + if (!store || !store.isRunning) { + return Promise.resolve(false); + } + if (store.stopRequested) { + return Promise.resolve(false); + } + + store.stopRequested = true; + store.isRunning = false; + + const agent = store.agent; + if (!agent) { + store.stopRequested = false; + store.isRunning = false; + return Promise.resolve(false); + } + + try { + agent.abortRun(); + return Promise.resolve(true); + } catch (error) { + console.error("Failed to abort agent run", error); + store.stopRequested = false; + store.isRunning = true; + return Promise.resolve(false); + } } } diff --git a/packages/runtime/src/runner/index.ts b/packages/runtime/src/runner/index.ts index 1e5ddc7c..5163c961 100644 --- a/packages/runtime/src/runner/index.ts +++ b/packages/runtime/src/runner/index.ts @@ -1,2 +1,3 @@ export * from "./agent-runner"; export * from "./in-memory"; +export * from "./finalize-events"; diff --git a/packages/sqlite-runner/src/__tests__/sqlite-runner.test.ts b/packages/sqlite-runner/src/__tests__/sqlite-runner.test.ts index dd2f60d6..7f204e9f 100644 --- a/packages/sqlite-runner/src/__tests__/sqlite-runner.test.ts +++ b/packages/sqlite-runner/src/__tests__/sqlite-runner.test.ts @@ -6,6 +6,7 @@ import { EventType, Message, RunAgentInput, + RunFinishedEvent, RunStartedEvent, TextMessageContentEvent, TextMessageEndEvent, @@ -46,6 +47,19 @@ class MockAgent extends AbstractAgent { for (const event of this.events) { await callbacks.onEvent({ event }); } + + const hasTerminalEvent = this.events.some((event) => + event.type === EventType.RUN_FINISHED || event.type === EventType.RUN_ERROR, + ); + + if (!hasTerminalEvent) { + const runFinished: RunFinishedEvent = { + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }; + await callbacks.onEvent({ event: runFinished }); + } } protected run(): ReturnType { @@ -61,6 +75,101 @@ class MockAgent extends AbstractAgent { } } +class StoppableAgent extends AbstractAgent { + private shouldStop = false; + private eventDelay: number; + + constructor(eventDelay = 5) { + super(); + this.eventDelay = eventDelay; + } + + async runAgent( + input: RunAgentInput, + callbacks: RunCallbacks, + ): Promise { + this.shouldStop = false; + let counter = 0; + + const runStarted: RunStartedEvent = { + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }; + await callbacks.onEvent({ event: runStarted }); + await callbacks.onRunStartedEvent?.(); + + while (!this.shouldStop && counter < 10_000) { + await new Promise((resolve) => setTimeout(resolve, this.eventDelay)); + const event: BaseEvent = { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId: `sqlite-stop-${counter}`, + delta: `chunk-${counter}`, + } as TextMessageContentEvent; + await callbacks.onEvent({ event }); + counter += 1; + } + } + + abortRun(): void { + this.shouldStop = true; + } + + clone(): AbstractAgent { + return new StoppableAgent(this.eventDelay); + } +} + +class OpenEventsAgent extends AbstractAgent { + private shouldStop = false; + + async runAgent( + input: RunAgentInput, + callbacks: RunCallbacks, + ): Promise { + this.shouldStop = false; + const messageId = "open-message"; + const toolCallId = "open-tool"; + + await callbacks.onEvent({ + event: { + type: EventType.TEXT_MESSAGE_START, + messageId, + role: "assistant", + } as BaseEvent, + }); + + await callbacks.onEvent({ + event: { + type: EventType.TEXT_MESSAGE_CONTENT, + messageId, + delta: "Partial content", + } as BaseEvent, + }); + + await callbacks.onEvent({ + event: { + type: EventType.TOOL_CALL_START, + toolCallId, + toolCallName: "testTool", + parentMessageId: messageId, + } as BaseEvent, + }); + + while (!this.shouldStop) { + await new Promise((resolve) => setTimeout(resolve, 5)); + } + } + + abortRun(): void { + this.shouldStop = true; + } + + clone(): AbstractAgent { + return new OpenEventsAgent(); + } +} + describe("SqliteAgentRunner", () => { let tempDir: string; let dbPath: string; @@ -83,6 +192,7 @@ describe("SqliteAgentRunner", () => { { type: EventType.TEXT_MESSAGE_START, messageId: "msg-1", role: "assistant" } as TextMessageStartEvent, { type: EventType.TEXT_MESSAGE_CONTENT, messageId: "msg-1", delta: "Hello" } as TextMessageContentEvent, { type: EventType.TEXT_MESSAGE_END, messageId: "msg-1" } as TextMessageEndEvent, + { type: EventType.RUN_FINISHED, threadId, runId: "run-1" } as RunFinishedEvent, ]); const events = await firstValueFrom( @@ -100,6 +210,7 @@ describe("SqliteAgentRunner", () => { EventType.TEXT_MESSAGE_START, EventType.TEXT_MESSAGE_CONTENT, EventType.TEXT_MESSAGE_END, + EventType.RUN_FINISHED, ]); }); @@ -171,6 +282,11 @@ describe("SqliteAgentRunner", () => { runId: "run-keep", input: providedInput, } as RunStartedEvent, + { + type: EventType.RUN_FINISHED, + threadId, + runId: "run-keep", + } as RunFinishedEvent, ], false, ); @@ -190,7 +306,10 @@ describe("SqliteAgentRunner", () => { .pipe(toArray()), ); - expect(events).toHaveLength(1); + expect(events.map((event) => event.type)).toEqual([ + EventType.RUN_STARTED, + EventType.RUN_FINISHED, + ]); const runStarted = events[0] as RunStartedEvent; expect(runStarted.input).toBe(providedInput); }); @@ -201,6 +320,7 @@ describe("SqliteAgentRunner", () => { { type: EventType.TEXT_MESSAGE_START, messageId: "msg", role: "assistant" } as TextMessageStartEvent, { type: EventType.TEXT_MESSAGE_CONTENT, messageId: "msg", delta: "hi" } as TextMessageContentEvent, { type: EventType.TEXT_MESSAGE_END, messageId: "msg" } as TextMessageEndEvent, + { type: EventType.RUN_FINISHED, threadId, runId: "run-1" } as RunFinishedEvent, ]); await firstValueFrom( @@ -221,6 +341,62 @@ describe("SqliteAgentRunner", () => { EventType.TEXT_MESSAGE_START, EventType.TEXT_MESSAGE_CONTENT, EventType.TEXT_MESSAGE_END, + EventType.RUN_FINISHED, + ]); + }); + + it("returns false when stopping a thread that is not running", async () => { + await expect(runner.stop({ threadId: "sqlite-missing" })).resolves.toBe(false); + }); + + it("stops an active run and completes observables", async () => { + const threadId = "sqlite-stop"; + const agent = new StoppableAgent(2); + const input: RunAgentInput = { + threadId, + runId: "sqlite-stop-run", + messages: [], + state: {}, + }; + + const run$ = runner.run({ threadId, agent, input }); + const collected = firstValueFrom(run$.pipe(toArray())); + + await new Promise((resolve) => setTimeout(resolve, 20)); + expect(await runner.isRunning({ threadId })).toBe(true); + + const stopped = await runner.stop({ threadId }); + expect(stopped).toBe(true); + + const events = await collected; + expect(events.length).toBeGreaterThan(0); + expect(events[events.length - 1].type).toBe(EventType.RUN_FINISHED); + expect(await runner.isRunning({ threadId })).toBe(false); + }); + + it("closes open text and tool events when stopping", async () => { + const threadId = "sqlite-open-events"; + const agent = new OpenEventsAgent(); + const input: RunAgentInput = { + threadId, + runId: "sqlite-open-run", + messages: [], + state: {}, + }; + + const run$ = runner.run({ threadId, agent, input }); + const collected = firstValueFrom(run$.pipe(toArray())); + + await new Promise((resolve) => setTimeout(resolve, 20)); + await runner.stop({ threadId }); + + const events = await collected; + const endingTypes = events.slice(-4).map((event) => event.type); + expect(endingTypes).toEqual([ + EventType.TEXT_MESSAGE_END, + EventType.TOOL_CALL_END, + EventType.TOOL_CALL_RESULT, + EventType.RUN_FINISHED, ]); }); }); diff --git a/packages/sqlite-runner/src/sqlite-runner.ts b/packages/sqlite-runner/src/sqlite-runner.ts index 2ba165ad..90357ab6 100644 --- a/packages/sqlite-runner/src/sqlite-runner.ts +++ b/packages/sqlite-runner/src/sqlite-runner.ts @@ -1,5 +1,6 @@ import { AgentRunner, + finalizeRunEvents, type AgentRunnerConnectRequest, type AgentRunnerIsRunningRequest, type AgentRunnerRunRequest, @@ -7,6 +8,7 @@ import { } from "@copilotkitnext/runtime"; import { Observable, ReplaySubject } from "rxjs"; import { + AbstractAgent, BaseEvent, RunAgentInput, EventType, @@ -32,9 +34,16 @@ export interface SqliteAgentRunnerOptions { dbPath?: string; } -// Active connections for streaming events -// This is the only in-memory state we need - just for active streaming -const ACTIVE_CONNECTIONS = new Map>(); +interface ActiveConnectionContext { + subject: ReplaySubject; + agent?: AbstractAgent; + runSubject?: ReplaySubject; + currentEvents?: BaseEvent[]; + stopRequested?: boolean; +} + +// Active connections for streaming events and stop support +const ACTIVE_CONNECTIONS = new Map(); export class SqliteAgentRunner extends AgentRunner { private db: any; @@ -235,14 +244,21 @@ export class SqliteAgentRunner extends AgentRunner { // Get or create subject for this thread's connections const nextSubject = new ReplaySubject(Infinity); - const prevSubject = ACTIVE_CONNECTIONS.get(request.threadId); + const prevConnection = ACTIVE_CONNECTIONS.get(request.threadId); + const prevSubject = prevConnection?.subject; - // Update the active connection for this thread - ACTIVE_CONNECTIONS.set(request.threadId, nextSubject); - // Create a subject for run() return value const runSubject = new ReplaySubject(Infinity); + // Update the active connection for this thread + ACTIVE_CONNECTIONS.set(request.threadId, { + subject: nextSubject, + agent: request.agent, + runSubject, + currentEvents: currentRunEvents, + stopRequested: false, + }); + // Helper function to run the agent and handle errors const runAgent = async () => { // Get parent run ID for chaining @@ -295,6 +311,15 @@ export class SqliteAgentRunner extends AgentRunner { }, }); + const connection = ACTIVE_CONNECTIONS.get(request.threadId); + const appendedEvents = finalizeRunEvents(currentRunEvents, { + stopRequested: connection?.stopRequested ?? false, + }); + for (const event of appendedEvents) { + runSubject.next(event); + nextSubject.next(event); + } + // Store the run in database this.storeRun( request.threadId, @@ -306,11 +331,29 @@ export class SqliteAgentRunner extends AgentRunner { // Mark run as complete in database this.setRunState(request.threadId, false); - + + if (connection) { + connection.agent = undefined; + connection.runSubject = undefined; + connection.currentEvents = undefined; + connection.stopRequested = false; + } + // Complete the subjects runSubject.complete(); nextSubject.complete(); + + ACTIVE_CONNECTIONS.delete(request.threadId); } catch { + const connection = ACTIVE_CONNECTIONS.get(request.threadId); + const appendedEvents = finalizeRunEvents(currentRunEvents, { + stopRequested: connection?.stopRequested ?? false, + }); + for (const event of appendedEvents) { + runSubject.next(event); + nextSubject.next(event); + } + // Store the run even if it failed (partial events) if (currentRunEvents.length > 0) { this.storeRun( @@ -324,11 +367,20 @@ export class SqliteAgentRunner extends AgentRunner { // Mark run as complete in database this.setRunState(request.threadId, false); - + + if (connection) { + connection.agent = undefined; + connection.runSubject = undefined; + connection.currentEvents = undefined; + connection.stopRequested = false; + } + // Don't emit error to the subject, just complete it // This allows subscribers to get events emitted before the error runSubject.complete(); nextSubject.complete(); + + ACTIVE_CONNECTIONS.delete(request.threadId); } }; @@ -375,11 +427,11 @@ export class SqliteAgentRunner extends AgentRunner { } // Bridge active run to connection if exists - const activeSubject = ACTIVE_CONNECTIONS.get(request.threadId); + const activeConnection = ACTIVE_CONNECTIONS.get(request.threadId); const runState = this.getRunState(request.threadId); - - if (activeSubject && runState.isRunning) { - activeSubject.subscribe({ + + if (activeConnection && (runState.isRunning || activeConnection.stopRequested)) { + activeConnection.subject.subscribe({ next: (event) => { // Skip message events that we've already emitted from historic if ('messageId' in event && typeof event.messageId === 'string' && emittedMessageIds.has(event.messageId)) { @@ -403,9 +455,35 @@ export class SqliteAgentRunner extends AgentRunner { return Promise.resolve(runState.isRunning); } - // eslint-disable-next-line @typescript-eslint/no-unused-vars - stop(_request: AgentRunnerStopRequest): Promise { - throw new Error("Method not implemented."); + stop(request: AgentRunnerStopRequest): Promise { + const runState = this.getRunState(request.threadId); + if (!runState.isRunning) { + return Promise.resolve(false); + } + + const connection = ACTIVE_CONNECTIONS.get(request.threadId); + const agent = connection?.agent; + + if (!connection || !agent) { + return Promise.resolve(false); + } + + if (connection.stopRequested) { + return Promise.resolve(false); + } + + connection.stopRequested = true; + this.setRunState(request.threadId, false); + + try { + agent.abortRun(); + return Promise.resolve(true); + } catch (error) { + console.error("Failed to abort sqlite agent run", error); + connection.stopRequested = false; + this.setRunState(request.threadId, true); + return Promise.resolve(false); + } } /**