diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..6498ced3 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,33 @@ +name: test + +on: + push: + branches: + - main + pull_request: + branches: + - main + workflow_dispatch: + +jobs: + unit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: pnpm/action-setup@v3 + with: + version: 9 + + - uses: actions/setup-node@v4 + with: + node-version: 20 + cache: 'pnpm' + cache-dependency-path: pnpm-lock.yaml + + - name: Install dependencies + run: pnpm install --frozen-lockfile + + - name: Run tests + run: pnpm test diff --git a/packages/core/src/__tests__/core-full.test.ts b/packages/core/src/__tests__/core-full.test.ts index 1b084617..529ca39c 100644 --- a/packages/core/src/__tests__/core-full.test.ts +++ b/packages/core/src/__tests__/core-full.test.ts @@ -51,7 +51,13 @@ describe("CopilotKitCore.runAgent - Full Test Suite", () => { await copilotKitCore.runAgent({ agent: agent as any }); - expect(tool.handler).toHaveBeenCalledWith({ input: "test" }); + expect(tool.handler).toHaveBeenCalledTimes(1); + const [firstCallArgs] = tool.handler.mock.calls; + expect(firstCallArgs?.[0]).toEqual({ input: "test" }); + expect(firstCallArgs?.[1]).toMatchObject({ + function: { name: toolName }, + type: "function", + }); expect(agent.messages.some(m => m.role === "tool")).toBe(true); }); diff --git a/packages/core/src/__tests__/core-headers.test.ts b/packages/core/src/__tests__/core-headers.test.ts new file mode 100644 index 00000000..a3ab8ab2 --- /dev/null +++ b/packages/core/src/__tests__/core-headers.test.ts @@ -0,0 +1,257 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { CopilotKitCore } from "../core"; +import { HttpAgent } from "@ag-ui/client"; +import { waitForCondition } from "./test-utils"; + +describe("CopilotKitCore headers", () => { + const originalFetch = global.fetch; + + beforeEach(() => { + vi.restoreAllMocks(); + }); + + afterEach(() => { + if (originalFetch) { + global.fetch = originalFetch; + } else { + delete (global as typeof globalThis & { fetch?: typeof fetch }).fetch; + } + }); + + it("includes provided headers when fetching runtime info", async () => { + const fetchMock = vi.fn().mockResolvedValue({ + json: vi.fn().mockResolvedValue({ version: "1.0.0", agents: {} }), + }); + global.fetch = fetchMock as unknown as typeof fetch; + + const headers = { + Authorization: "Bearer test-token", + "X-Custom-Header": "custom-value", + }; + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const core = new CopilotKitCore({ + runtimeUrl: "https://runtime.example", + headers, + }); + + await waitForCondition(() => fetchMock.mock.calls.length >= 1); + + expect(fetchMock).toHaveBeenCalledWith( + "https://runtime.example/info", + expect.objectContaining({ + headers: expect.objectContaining(headers), + }) + ); + }); + + it("uses updated headers for subsequent runtime requests", async () => { + const fetchMock = vi.fn().mockResolvedValue({ + json: vi.fn().mockResolvedValue({ version: "1.0.0", agents: {} }), + }); + global.fetch = fetchMock as unknown as typeof fetch; + + const core = new CopilotKitCore({ + runtimeUrl: "https://runtime.example", + headers: { Authorization: "Bearer initial" }, + }); + + await waitForCondition(() => fetchMock.mock.calls.length >= 1); + + core.setHeaders({ Authorization: "Bearer updated", "X-Trace": "123" }); + core.setRuntimeUrl(undefined); + core.setRuntimeUrl("https://runtime.example"); + + await waitForCondition(() => fetchMock.mock.calls.length >= 2); + + const secondCall = fetchMock.mock.calls[1]; + expect(secondCall?.[1]?.headers).toMatchObject({ + Authorization: "Bearer updated", + "X-Trace": "123", + }); + }); + + it("passes configured headers to HttpAgent runs", async () => { + const recorded: Array> = []; + + class RecordingHttpAgent extends HttpAgent { + constructor() { + super({ url: "https://runtime.example" }); + } + + async connectAgent(...args: Parameters) { + recorded.push({ ...this.headers }); + return Promise.resolve({ newMessages: [] }) as ReturnType; + } + + async runAgent(...args: Parameters) { + recorded.push({ ...this.headers }); + return Promise.resolve({ newMessages: [] }) as ReturnType; + } + } + + const agent = new RecordingHttpAgent(); + + const core = new CopilotKitCore({ + runtimeUrl: undefined, + headers: { Authorization: "Bearer cfg", "X-Team": "angular" }, + agents: { default: agent }, + }); + + await agent.runAgent(); + await core.connectAgent({ agent, agentId: "default" }); + await core.runAgent({ agent, agentId: "default" }); + + expect(recorded).toHaveLength(3); + for (const headers of recorded) { + expect(headers).toMatchObject({ + Authorization: "Bearer cfg", + "X-Team": "angular", + }); + } + }); + + it("applies updated headers to existing HttpAgent instances", () => { + const agent = new HttpAgent({ url: "https://runtime.example" }); + + const core = new CopilotKitCore({ + runtimeUrl: undefined, + headers: { Authorization: "Bearer cfg" }, + agents: { default: agent }, + }); + + expect(agent.headers).toMatchObject({ + Authorization: "Bearer cfg", + }); + + core.setHeaders({ + Authorization: "Bearer updated", + "X-Trace": "123", + }); + + expect(agent.headers).toMatchObject({ + Authorization: "Bearer updated", + "X-Trace": "123", + }); + }); + + it("applies headers to agents provided via setAgents", () => { + const originalAgent = new HttpAgent({ url: "https://runtime.example/original" }); + const replacementAgent = new HttpAgent({ + url: "https://runtime.example/replacement", + }); + + const core = new CopilotKitCore({ + runtimeUrl: undefined, + headers: { Authorization: "Bearer cfg" }, + agents: { original: originalAgent }, + }); + + expect(originalAgent.headers).toMatchObject({ + Authorization: "Bearer cfg", + }); + + core.setAgents({ replacement: replacementAgent }); + + expect(replacementAgent.headers).toMatchObject({ + Authorization: "Bearer cfg", + }); + }); + + it("applies headers when agents are added dynamically", () => { + const core = new CopilotKitCore({ + runtimeUrl: undefined, + headers: { Authorization: "Bearer cfg" }, + }); + + const addedAgent = new HttpAgent({ url: "https://runtime.example/new" }); + + core.addAgent({ id: "added", agent: addedAgent }); + + expect(addedAgent.headers).toMatchObject({ + Authorization: "Bearer cfg", + }); + }); + + it("uses the latest headers when running HttpAgent instances", async () => { + const recorded: Array> = []; + + class RecordingHttpAgent extends HttpAgent { + constructor() { + super({ url: "https://runtime.example" }); + } + + async runAgent(...args: Parameters) { + recorded.push({ ...this.headers }); + return Promise.resolve({ newMessages: [] }) as ReturnType; + } + } + + const agent = new RecordingHttpAgent(); + + const core = new CopilotKitCore({ + runtimeUrl: undefined, + headers: { Authorization: "Bearer initial" }, + agents: { default: agent }, + }); + + await core.runAgent({ agent, agentId: "default" }); + + core.setHeaders({ Authorization: "Bearer updated", "X-Trace": "123" }); + + await core.runAgent({ agent, agentId: "default" }); + + expect(recorded).toHaveLength(2); + expect(recorded[0]).toMatchObject({ Authorization: "Bearer initial" }); + expect(recorded[1]).toMatchObject({ + Authorization: "Bearer updated", + "X-Trace": "123", + }); + }); + + it("applies headers to remote agents fetched from runtime info", async () => { + const fetchMock = vi.fn().mockResolvedValue({ + json: vi.fn().mockResolvedValue({ + version: "1.0.0", + agents: { + remote: { + name: "Remote Agent", + className: "RemoteClass", + description: "Remote description", + }, + }, + }), + }); + global.fetch = fetchMock as unknown as typeof fetch; + + const core = new CopilotKitCore({ + runtimeUrl: "https://runtime.example", + headers: { Authorization: "Bearer cfg", "X-Team": "angular" }, + }); + + await waitForCondition(() => core.getAgent("remote") !== undefined); + + const remoteAgent = core.getAgent("remote") as HttpAgent | undefined; + expect(remoteAgent).toBeDefined(); + expect(remoteAgent?.headers).toMatchObject({ + Authorization: "Bearer cfg", + "X-Team": "angular", + }); + + core.setHeaders({ Authorization: "Bearer updated" }); + + expect(remoteAgent?.headers).toMatchObject({ + Authorization: "Bearer updated", + }); + + expect(fetchMock).toHaveBeenCalledWith( + "https://runtime.example/info", + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer cfg", + "X-Team": "angular", + }), + }) + ); + }); +}); diff --git a/packages/core/src/__tests__/core-tool-minimal.test.ts b/packages/core/src/__tests__/core-tool-minimal.test.ts index e8b4bb2d..6e8e8c9d 100644 --- a/packages/core/src/__tests__/core-tool-minimal.test.ts +++ b/packages/core/src/__tests__/core-tool-minimal.test.ts @@ -26,7 +26,13 @@ describe("CopilotKitCore Tool Minimal", () => { await copilotKitCore.runAgent({ agent: agent as any }); - expect(tool.handler).toHaveBeenCalledWith({ input: "test" }); + expect(tool.handler).toHaveBeenCalledTimes(1); + const [firstCallArgs] = tool.handler.mock.calls; + expect(firstCallArgs?.[0]).toEqual({ input: "test" }); + expect(firstCallArgs?.[1]).toMatchObject({ + function: { name: toolName }, + type: "function", + }); expect(agent.messages.some(m => m.role === "tool")).toBe(true); }); diff --git a/packages/core/src/__tests__/core-tool-simple.test.ts b/packages/core/src/__tests__/core-tool-simple.test.ts index c4d5c1e1..874a594d 100644 --- a/packages/core/src/__tests__/core-tool-simple.test.ts +++ b/packages/core/src/__tests__/core-tool-simple.test.ts @@ -34,7 +34,13 @@ describe("CopilotKitCore Tool Simple", () => { await copilotKitCore.runAgent({ agent: agent as any }); console.log("Agent run complete"); - expect(tool.handler).toHaveBeenCalledWith({ input: "test" }); + expect(tool.handler).toHaveBeenCalledTimes(1); + const [firstCallArgs] = tool.handler.mock.calls; + expect(firstCallArgs?.[0]).toEqual({ input: "test" }); + expect(firstCallArgs?.[1]).toMatchObject({ + function: { name: toolName }, + type: "function", + }); expect(agent.messages.length).toBeGreaterThan(0); }); }); \ No newline at end of file diff --git a/packages/core/src/core.ts b/packages/core/src/core.ts index b81f979b..e9345eac 100644 --- a/packages/core/src/core.ts +++ b/packages/core/src/core.ts @@ -141,11 +141,24 @@ export class CopilotKitCore { this.headers = headers; this.properties = properties; this.localAgents = this.assignAgentIds(agents); + this.applyHeadersToAgents(this.localAgents); this._agents = this.localAgents; this._tools = tools; this.setRuntimeUrl(runtimeUrl); } + private applyHeadersToAgent(agent: AbstractAgent) { + if (agent instanceof HttpAgent) { + agent.headers = { ...this.headers }; + } + } + + private applyHeadersToAgents(agents: Record) { + Object.values(agents).forEach((agent) => { + this.applyHeadersToAgent(agent); + }); + } + private assignAgentIds(agents: Record) { Object.entries(agents).forEach(([id, agent]) => { if (agent && !agent.agentId) { @@ -314,6 +327,7 @@ export class CopilotKitCore { agentId: id, description: description, }); + this.applyHeadersToAgent(agent); return [id, agent]; }) ); @@ -387,6 +401,7 @@ export class CopilotKitCore { */ setHeaders(headers: Record) { this.headers = headers; + this.applyHeadersToAgents(this._agents); void this.notifySubscribers( (subscriber) => subscriber.onHeadersChanged?.({ @@ -412,6 +427,7 @@ export class CopilotKitCore { setAgents(agents: Record) { this.localAgents = this.assignAgentIds(agents); this._agents = { ...this.localAgents, ...this.remoteAgents }; + this.applyHeadersToAgents(this._agents); void this.notifySubscribers( (subscriber) => subscriber.onAgentsChanged?.({ @@ -427,6 +443,7 @@ export class CopilotKitCore { if (!agent.agentId) { agent.agentId = id; } + this.applyHeadersToAgent(agent); this._agents = { ...this.localAgents, ...this.remoteAgents }; void this.notifySubscribers( (subscriber) =>