Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: test

on:
push:
branches:
- main
pull_request:
branches:
- main
workflow_dispatch:

jobs:
unit:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- uses: pnpm/action-setup@v3
with:
version: 9

- uses: actions/setup-node@v4
with:
node-version: 20
cache: 'pnpm'
cache-dependency-path: pnpm-lock.yaml

- name: Install dependencies
run: pnpm install --frozen-lockfile

- name: Run tests
run: pnpm test
8 changes: 7 additions & 1 deletion packages/core/src/__tests__/core-full.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ describe("CopilotKitCore.runAgent - Full Test Suite", () => {

await copilotKitCore.runAgent({ agent: agent as any });

expect(tool.handler).toHaveBeenCalledWith({ input: "test" });
expect(tool.handler).toHaveBeenCalledTimes(1);
const [firstCallArgs] = tool.handler.mock.calls;
expect(firstCallArgs?.[0]).toEqual({ input: "test" });
expect(firstCallArgs?.[1]).toMatchObject({
function: { name: toolName },
type: "function",
});
expect(agent.messages.some(m => m.role === "tool")).toBe(true);
});

Expand Down
257 changes: 257 additions & 0 deletions packages/core/src/__tests__/core-headers.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { CopilotKitCore } from "../core";
import { HttpAgent } from "@ag-ui/client";
import { waitForCondition } from "./test-utils";

describe("CopilotKitCore headers", () => {
const originalFetch = global.fetch;

beforeEach(() => {
vi.restoreAllMocks();
});

afterEach(() => {
if (originalFetch) {
global.fetch = originalFetch;
} else {
delete (global as typeof globalThis & { fetch?: typeof fetch }).fetch;
}
});

it("includes provided headers when fetching runtime info", async () => {
const fetchMock = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({ version: "1.0.0", agents: {} }),
});
global.fetch = fetchMock as unknown as typeof fetch;

const headers = {
Authorization: "Bearer test-token",
"X-Custom-Header": "custom-value",
};

// eslint-disable-next-line @typescript-eslint/no-unused-vars
const core = new CopilotKitCore({
runtimeUrl: "https://runtime.example",
headers,
});

await waitForCondition(() => fetchMock.mock.calls.length >= 1);

expect(fetchMock).toHaveBeenCalledWith(
"https://runtime.example/info",
expect.objectContaining({
headers: expect.objectContaining(headers),
})
);
});

it("uses updated headers for subsequent runtime requests", async () => {
const fetchMock = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({ version: "1.0.0", agents: {} }),
});
global.fetch = fetchMock as unknown as typeof fetch;

const core = new CopilotKitCore({
runtimeUrl: "https://runtime.example",
headers: { Authorization: "Bearer initial" },
});

await waitForCondition(() => fetchMock.mock.calls.length >= 1);

core.setHeaders({ Authorization: "Bearer updated", "X-Trace": "123" });
core.setRuntimeUrl(undefined);
core.setRuntimeUrl("https://runtime.example");

await waitForCondition(() => fetchMock.mock.calls.length >= 2);

const secondCall = fetchMock.mock.calls[1];
expect(secondCall?.[1]?.headers).toMatchObject({
Authorization: "Bearer updated",
"X-Trace": "123",
});
});

it("passes configured headers to HttpAgent runs", async () => {
const recorded: Array<Record<string, string>> = [];

class RecordingHttpAgent extends HttpAgent {
constructor() {
super({ url: "https://runtime.example" });
}

async connectAgent(...args: Parameters<HttpAgent["connectAgent"]>) {
recorded.push({ ...this.headers });
return Promise.resolve({ newMessages: [] }) as ReturnType<HttpAgent["connectAgent"]>;
}

async runAgent(...args: Parameters<HttpAgent["runAgent"]>) {
recorded.push({ ...this.headers });
return Promise.resolve({ newMessages: [] }) as ReturnType<HttpAgent["runAgent"]>;
}
}

const agent = new RecordingHttpAgent();

const core = new CopilotKitCore({
runtimeUrl: undefined,
headers: { Authorization: "Bearer cfg", "X-Team": "angular" },
agents: { default: agent },
});

await agent.runAgent();
await core.connectAgent({ agent, agentId: "default" });
await core.runAgent({ agent, agentId: "default" });

expect(recorded).toHaveLength(3);
for (const headers of recorded) {
expect(headers).toMatchObject({
Authorization: "Bearer cfg",
"X-Team": "angular",
});
}
});

it("applies updated headers to existing HttpAgent instances", () => {
const agent = new HttpAgent({ url: "https://runtime.example" });

const core = new CopilotKitCore({
runtimeUrl: undefined,
headers: { Authorization: "Bearer cfg" },
agents: { default: agent },
});

expect(agent.headers).toMatchObject({
Authorization: "Bearer cfg",
});

core.setHeaders({
Authorization: "Bearer updated",
"X-Trace": "123",
});

expect(agent.headers).toMatchObject({
Authorization: "Bearer updated",
"X-Trace": "123",
});
});

it("applies headers to agents provided via setAgents", () => {
const originalAgent = new HttpAgent({ url: "https://runtime.example/original" });
const replacementAgent = new HttpAgent({
url: "https://runtime.example/replacement",
});

const core = new CopilotKitCore({
runtimeUrl: undefined,
headers: { Authorization: "Bearer cfg" },
agents: { original: originalAgent },
});

expect(originalAgent.headers).toMatchObject({
Authorization: "Bearer cfg",
});

core.setAgents({ replacement: replacementAgent });

expect(replacementAgent.headers).toMatchObject({
Authorization: "Bearer cfg",
});
});

it("applies headers when agents are added dynamically", () => {
const core = new CopilotKitCore({
runtimeUrl: undefined,
headers: { Authorization: "Bearer cfg" },
});

const addedAgent = new HttpAgent({ url: "https://runtime.example/new" });

core.addAgent({ id: "added", agent: addedAgent });

expect(addedAgent.headers).toMatchObject({
Authorization: "Bearer cfg",
});
});

it("uses the latest headers when running HttpAgent instances", async () => {
const recorded: Array<Record<string, string>> = [];

class RecordingHttpAgent extends HttpAgent {
constructor() {
super({ url: "https://runtime.example" });
}

async runAgent(...args: Parameters<HttpAgent["runAgent"]>) {
recorded.push({ ...this.headers });
return Promise.resolve({ newMessages: [] }) as ReturnType<HttpAgent["runAgent"]>;
}
}

const agent = new RecordingHttpAgent();

const core = new CopilotKitCore({
runtimeUrl: undefined,
headers: { Authorization: "Bearer initial" },
agents: { default: agent },
});

await core.runAgent({ agent, agentId: "default" });

core.setHeaders({ Authorization: "Bearer updated", "X-Trace": "123" });

await core.runAgent({ agent, agentId: "default" });

expect(recorded).toHaveLength(2);
expect(recorded[0]).toMatchObject({ Authorization: "Bearer initial" });
expect(recorded[1]).toMatchObject({
Authorization: "Bearer updated",
"X-Trace": "123",
});
});

it("applies headers to remote agents fetched from runtime info", async () => {
const fetchMock = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({
version: "1.0.0",
agents: {
remote: {
name: "Remote Agent",
className: "RemoteClass",
description: "Remote description",
},
},
}),
});
global.fetch = fetchMock as unknown as typeof fetch;

const core = new CopilotKitCore({
runtimeUrl: "https://runtime.example",
headers: { Authorization: "Bearer cfg", "X-Team": "angular" },
});

await waitForCondition(() => core.getAgent("remote") !== undefined);

const remoteAgent = core.getAgent("remote") as HttpAgent | undefined;
expect(remoteAgent).toBeDefined();
expect(remoteAgent?.headers).toMatchObject({
Authorization: "Bearer cfg",
"X-Team": "angular",
});

core.setHeaders({ Authorization: "Bearer updated" });

expect(remoteAgent?.headers).toMatchObject({
Authorization: "Bearer updated",
});

expect(fetchMock).toHaveBeenCalledWith(
"https://runtime.example/info",
expect.objectContaining({
headers: expect.objectContaining({
Authorization: "Bearer cfg",
"X-Team": "angular",
}),
})
);
});
});
8 changes: 7 additions & 1 deletion packages/core/src/__tests__/core-tool-minimal.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ describe("CopilotKitCore Tool Minimal", () => {

await copilotKitCore.runAgent({ agent: agent as any });

expect(tool.handler).toHaveBeenCalledWith({ input: "test" });
expect(tool.handler).toHaveBeenCalledTimes(1);
const [firstCallArgs] = tool.handler.mock.calls;
expect(firstCallArgs?.[0]).toEqual({ input: "test" });
expect(firstCallArgs?.[1]).toMatchObject({
function: { name: toolName },
type: "function",
});
expect(agent.messages.some(m => m.role === "tool")).toBe(true);
});

Expand Down
8 changes: 7 additions & 1 deletion packages/core/src/__tests__/core-tool-simple.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ describe("CopilotKitCore Tool Simple", () => {
await copilotKitCore.runAgent({ agent: agent as any });
console.log("Agent run complete");

expect(tool.handler).toHaveBeenCalledWith({ input: "test" });
expect(tool.handler).toHaveBeenCalledTimes(1);
const [firstCallArgs] = tool.handler.mock.calls;
expect(firstCallArgs?.[0]).toEqual({ input: "test" });
expect(firstCallArgs?.[1]).toMatchObject({
function: { name: toolName },
type: "function",
});
expect(agent.messages.length).toBeGreaterThan(0);
});
});
17 changes: 17 additions & 0 deletions packages/core/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,24 @@ export class CopilotKitCore {
this.headers = headers;
this.properties = properties;
this.localAgents = this.assignAgentIds(agents);
this.applyHeadersToAgents(this.localAgents);
this._agents = this.localAgents;
this._tools = tools;
this.setRuntimeUrl(runtimeUrl);
}

private applyHeadersToAgent(agent: AbstractAgent) {
if (agent instanceof HttpAgent) {
agent.headers = { ...this.headers };
}
}

private applyHeadersToAgents(agents: Record<string, AbstractAgent>) {
Object.values(agents).forEach((agent) => {
this.applyHeadersToAgent(agent);
});
}

private assignAgentIds(agents: Record<string, AbstractAgent>) {
Object.entries(agents).forEach(([id, agent]) => {
if (agent && !agent.agentId) {
Expand Down Expand Up @@ -314,6 +327,7 @@ export class CopilotKitCore {
agentId: id,
description: description,
});
this.applyHeadersToAgent(agent);
return [id, agent];
})
);
Expand Down Expand Up @@ -387,6 +401,7 @@ export class CopilotKitCore {
*/
setHeaders(headers: Record<string, string>) {
this.headers = headers;
this.applyHeadersToAgents(this._agents);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this.applyHeadersToAgents(this._agents) is enough

void this.notifySubscribers(
(subscriber) =>
subscriber.onHeadersChanged?.({
Expand All @@ -412,6 +427,7 @@ export class CopilotKitCore {
setAgents(agents: Record<string, AbstractAgent>) {
this.localAgents = this.assignAgentIds(agents);
this._agents = { ...this.localAgents, ...this.remoteAgents };
this.applyHeadersToAgents(this._agents);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applying to this._agents should be enough (it contains all agents)

void this.notifySubscribers(
(subscriber) =>
subscriber.onAgentsChanged?.({
Expand All @@ -427,6 +443,7 @@ export class CopilotKitCore {
if (!agent.agentId) {
agent.agentId = id;
}
this.applyHeadersToAgent(agent);
this._agents = { ...this.localAgents, ...this.remoteAgents };
void this.notifySubscribers(
(subscriber) =>
Expand Down
Loading