Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions packages/agent/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,8 @@ export interface BasicAgentConfiguration {
}

export class BasicAgent extends AbstractAgent {
private abortController?: AbortController;

constructor(private config: BasicAgentConfiguration) {
super();
}
Expand Down Expand Up @@ -620,6 +622,10 @@ export class BasicAgent extends AbstractAgent {
const mcpClients: Array<{ close: () => Promise<void> }> = [];

(async () => {
const abortController = new AbortController();
this.abortController = abortController;
let terminalEventEmitted = false;

try {
// Add AG-UI state update tools
streamTextParams.tools = {
Expand Down Expand Up @@ -681,7 +687,7 @@ export class BasicAgent extends AbstractAgent {
}

// Call streamText and process the stream
const response = streamText(streamTextParams);
const response = streamText({ ...streamTextParams, abortSignal: abortController.signal });

let messageId = randomUUID();

Expand Down Expand Up @@ -848,32 +854,60 @@ export class BasicAgent extends AbstractAgent {
runId: input.runId,
};
subscriber.next(finishedEvent);
terminalEventEmitted = true;

// Complete the observable
subscriber.complete();
break;

case "error":
case "error": {
if (abortController.signal.aborted) {
break;
}
const runErrorEvent: RunErrorEvent = {
type: EventType.RUN_ERROR,
message: part.error + "",
};
subscriber.next(runErrorEvent);
terminalEventEmitted = true;

// Handle error
subscriber.error(part.error);
break;
}
}
}
} catch (error) {
const runErrorEvent: RunErrorEvent = {
type: EventType.RUN_ERROR,
message: error + "",
};
subscriber.next(runErrorEvent);

subscriber.error(error);
if (!terminalEventEmitted) {
if (abortController.signal.aborted) {
// Let the runner finalize the stream on stop requests so it can
// inject consistent closing events and a RUN_FINISHED marker.
} else {
const finishedEvent: RunFinishedEvent = {
type: EventType.RUN_FINISHED,
threadId: input.threadId,
runId: input.runId,
};
subscriber.next(finishedEvent);
}

terminalEventEmitted = true;
subscriber.complete();
}
} catch (error) {
if (abortController.signal.aborted) {
subscriber.complete();
} else {
const runErrorEvent: RunErrorEvent = {
type: EventType.RUN_ERROR,
message: error + "",
};
subscriber.next(runErrorEvent);
terminalEventEmitted = true;
subscriber.error(error);
}
} finally {
this.abortController = undefined;
await Promise.all(mcpClients.map((client) => client.close()));
}
})();
Expand All @@ -891,4 +925,8 @@ export class BasicAgent extends AbstractAgent {
clone() {
return new BasicAgent(this.config);
}

abortRun(): void {
this.abortController?.abort();
}
}
33 changes: 27 additions & 6 deletions packages/core/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import {
} from "@ag-ui/client";
import { Observable } from "rxjs";

export interface ProxiedCopilotRuntimeAgentConfig
extends Omit<HttpAgentConfig, "url"> {
export interface ProxiedCopilotRuntimeAgentConfig extends Omit<HttpAgentConfig, "url"> {
runtimeUrl?: string;
}

Expand All @@ -24,11 +23,33 @@ export class ProxiedCopilotRuntimeAgent extends HttpAgent {
this.runtimeUrl = config.runtimeUrl;
}

abortRun(): void {
if (!this.runtimeUrl || !this.agentId || !this.threadId) {
return;
}

if (typeof fetch === "undefined") {
return;
}

const stopPath = `${this.runtimeUrl}/agent/${encodeURIComponent(this.agentId)}/stop/${encodeURIComponent(this.threadId)}`;
const origin = typeof window !== "undefined" && window.location ? window.location.origin : "http://localhost";
const base = new URL(this.runtimeUrl, origin);
const stopUrl = new URL(stopPath, base);

void fetch(stopUrl.toString(), {
method: "POST",
headers: {
"Content-Type": "application/json",
...this.headers,
},
}).catch((error) => {
console.error("ProxiedCopilotRuntimeAgent: stop request failed", error);
});
}

connect(input: RunAgentInput): Observable<BaseEvent> {
const httpEvents = runHttpRequest(
`${this.runtimeUrl}/agent/${this.agentId}/connect`,
this.requestInit(input)
);
const httpEvents = runHttpRequest(`${this.runtimeUrl}/agent/${this.agentId}/connect`, this.requestInit(input));
return transformHttpEventStream(httpEvents);
}
}
14 changes: 13 additions & 1 deletion packages/core/src/core/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ export interface CopilotKitCoreConfig {
}

export type { CopilotKitCoreAddAgentParams };
export type { CopilotKitCoreRunAgentParams, CopilotKitCoreConnectAgentParams, CopilotKitCoreGetToolParams };
export type {
CopilotKitCoreRunAgentParams,
CopilotKitCoreConnectAgentParams,
CopilotKitCoreGetToolParams,
};

export interface CopilotKitCoreStopAgentParams {
agent: AbstractAgent;
}

export type CopilotKitCoreGetSuggestionsResult = {
suggestions: Suggestion[];
Expand Down Expand Up @@ -395,6 +403,10 @@ export class CopilotKitCore {
return this.runHandler.connectAgent(params);
}

stopAgent(params: CopilotKitCoreStopAgentParams): void {
params.agent.abortRun();
}

async runAgent(params: CopilotKitCoreRunAgentParams): Promise<import("@ag-ui/client").RunAgentResult> {
return this.runHandler.runAgent(params);
}
Expand Down
37 changes: 33 additions & 4 deletions packages/react/src/components/chat/CopilotChat.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { useAgent } from "@/hooks/use-agent";
import { useSuggestions } from "@/hooks/use-suggestions";
import { CopilotChatView, CopilotChatViewProps } from "./CopilotChatView";
import CopilotChatInput, { CopilotChatInputProps } from "./CopilotChatInput";
import {
CopilotChatConfigurationProvider,
CopilotChatLabels,
Expand Down Expand Up @@ -104,6 +105,23 @@ export function CopilotChat({ agentId, threadId, labels, chatView, isModalDefaul
[agent, copilotkit],
);

const stopCurrentRun = useCallback(() => {
if (!agent) {
return;
}

try {
copilotkit.stopAgent({ agent });
} catch (error) {
console.error("CopilotChat: stopAgent failed", error);
try {
agent.abortRun();
} catch (abortError) {
console.error("CopilotChat: abortRun fallback failed", abortError);
}
}
}, [agent, copilotkit]);

const mergedProps = merge(
{
isRunning: agent?.isRunning ?? false,
Expand All @@ -121,12 +139,23 @@ export function CopilotChat({ agentId, threadId, labels, chatView, isModalDefaul
},
);

const providedStopHandler = providedInputProps?.onStop;
const hasMessages = (agent?.messages?.length ?? 0) > 0;
const shouldAllowStop = (agent?.isRunning ?? false) && hasMessages;
const effectiveStopHandler = shouldAllowStop ? providedStopHandler ?? stopCurrentRun : providedStopHandler;

const finalInputProps = {
...providedInputProps,
onSubmitMessage: onSubmitInput,
onStop: effectiveStopHandler,
isRunning: agent?.isRunning ?? false,
} as Partial<CopilotChatInputProps> & { onSubmitMessage: (value: string) => void };

finalInputProps.mode = agent?.isRunning ? "processing" : finalInputProps.mode ?? "input";

const finalProps = merge(mergedProps, {
messages: agent?.messages ?? [],
inputProps: {
onSubmitMessage: onSubmitInput,
...providedInputProps,
},
inputProps: finalInputProps,
}) as CopilotChatViewProps;

// Always create a provider with merged values
Expand Down
35 changes: 29 additions & 6 deletions packages/react/src/components/chat/CopilotChatInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import React, {
useMemo,
} from "react";
import { twMerge } from "tailwind-merge";
import { Plus, Mic, ArrowUp, X, Check } from "lucide-react";
import { Plus, Mic, ArrowUp, X, Check, Square } from "lucide-react";

import {
CopilotChatLabels,
Expand Down Expand Up @@ -64,6 +64,8 @@ type CopilotChatInputRestProps = {
toolsMenu?: (ToolsMenuItem | "-")[];
autoFocus?: boolean;
onSubmitMessage?: (value: string) => void;
onStop?: () => void;
isRunning?: boolean;
onStartTranscribe?: () => void;
onCancelTranscribe?: () => void;
onFinishTranscribe?: () => void;
Expand All @@ -90,6 +92,8 @@ const SLASH_MENU_ITEM_HEIGHT_PX = 40;
export function CopilotChatInput({
mode = "input",
onSubmitMessage,
onStop,
isRunning = false,
onStartTranscribe,
onCancelTranscribe,
onFinishTranscribe,
Expand Down Expand Up @@ -390,7 +394,11 @@ export function CopilotChatInput({

if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
send();
if (isProcessing) {
onStop?.();
} else {
send();
}
}
};

Expand Down Expand Up @@ -427,13 +435,26 @@ export function CopilotChatInput({
),
});

const isProcessing = mode !== "transcribe" && isRunning;
const canSend = resolvedValue.trim().length > 0 && !!onSubmitMessage;
const canStop = !!onStop;

const handleSendButtonClick = () => {
if (isProcessing) {
onStop?.();
return;
}
send();
};

const BoundAudioRecorder = renderSlot(audioRecorder, CopilotChatAudioRecorder, {
ref: audioRecorderRef,
});

const BoundSendButton = renderSlot(sendButton, CopilotChatInput.SendButton, {
onClick: send,
disabled: !resolvedValue.trim() || !onSubmitMessage,
onClick: handleSendButtonClick,
disabled: isProcessing ? !canStop : !canSend,
children: isProcessing && canStop ? <Square className="size-[18px] fill-current" /> : undefined,
});

const BoundStartTranscribeButton = renderSlot(startTranscribeButton, CopilotChatInput.StartTranscribeButton, {
Expand Down Expand Up @@ -464,6 +485,8 @@ export function CopilotChatInput({
finishTranscribeButton: BoundFinishTranscribeButton,
addMenuButton: BoundAddMenuButton,
onSubmitMessage,
onStop,
isRunning,
onStartTranscribe,
onCancelTranscribe,
onFinishTranscribe,
Expand Down Expand Up @@ -833,7 +856,7 @@ export function CopilotChatInput({

// eslint-disable-next-line @typescript-eslint/no-namespace
export namespace CopilotChatInput {
export const SendButton: React.FC<React.ButtonHTMLAttributes<HTMLButtonElement>> = ({ className, ...props }) => (
export const SendButton: React.FC<React.ButtonHTMLAttributes<HTMLButtonElement>> = ({ className, children, ...props }) => (
<div className="mr-[10px]">
<Button
type="button"
Expand All @@ -842,7 +865,7 @@ export namespace CopilotChatInput {
className={className}
{...props}
>
<ArrowUp className="size-[18px]" />
{children ?? <ArrowUp className="size-[18px]" />}
</Button>
</div>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,24 @@ describe("InMemoryAgentRunner – run started inputs", () => {
runner.run({ threadId, agent, input }).pipe(toArray()),
);

expect(runEvents).toHaveLength(1);
expect(runEvents[0].type).toBe(EventType.RUN_STARTED);
const runStarted = runEvents[0] as RunStartedEvent;
expect(runStarted.type).toBe(EventType.RUN_STARTED);
expect(runStarted.input?.messages).toEqual(messages);

const terminalTypes = runEvents.slice(1).map((event) => event.type);
expect(terminalTypes.every((type) => type === EventType.RUN_ERROR || type === EventType.RUN_FINISHED)).toBe(true);

const connectEvents = await firstValueFrom(
runner.connect({ threadId }).pipe(toArray()),
);

expect(connectEvents).toHaveLength(1);
expect(connectEvents[0].type).toBe(EventType.RUN_STARTED);
const connectRunStarted = connectEvents[0] as RunStartedEvent;
expect(connectRunStarted.input?.messages).toEqual(messages);
const connectTerminalTypes = connectEvents.slice(1).map((event) => event.type);
expect(
connectTerminalTypes.every((type) => type === EventType.RUN_ERROR || type === EventType.RUN_FINISHED),
).toBe(true);
});

it("only includes new messages on subsequent runs", async () => {
Expand Down Expand Up @@ -149,8 +155,13 @@ describe("InMemoryAgentRunner – run started inputs", () => {
.pipe(toArray()),
);

expect(secondRunEvents[0].type).toBe(EventType.RUN_STARTED);
const runStarted = secondRunEvents[0] as RunStartedEvent;
expect(runStarted.input?.messages).toEqual([newMessage]);
const secondTerminalTypes = secondRunEvents.slice(1).map((event) => event.type);
expect(
secondTerminalTypes.every((type) => type === EventType.RUN_ERROR || type === EventType.RUN_FINISHED),
).toBe(true);

const connectEvents = await firstValueFrom(
runner.connect({ threadId }).pipe(toArray()),
Expand Down
Loading
Loading