diff --git a/examples/x402-mcp/src/index.ts b/examples/x402-mcp/src/index.ts index 4785d88d..4ec889ef 100644 --- a/examples/x402-mcp/src/index.ts +++ b/examples/x402-mcp/src/index.ts @@ -35,7 +35,11 @@ export class PayAgent extends Agent { ); console.log("Agent will pay from this address:", account.address); - const { id } = await this.mcp.connect("http://localhost:8787/mcp"); + const { id } = await this.addMcpServer( + "x402", + "http://localhost:8787/mcp", + "http://localhost:3000" + ); // Build the x402 MCP client this.x402Client = withX402Client(this.mcp.mcpConnections[id].client, { diff --git a/package-lock.json b/package-lock.json index 16a192e4..3cea3001 100644 --- a/package-lock.json +++ b/package-lock.json @@ -768,6 +768,7 @@ "version": "2.0.64", "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.64.tgz", "integrity": "sha512-+1mqxn42uB32DPZ6kurSyGAmL3MgCaDpkYU7zNDWI4NLy3Zg97RxTsI1jBCGIqkEVvRZKJlIMYtb89OvMnq3AQ==", + "dev": true, "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "2.0.0", @@ -784,6 +785,7 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", "integrity": "sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==", + "dev": true, "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -796,6 +798,7 @@ "version": "3.0.16", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.16.tgz", "integrity": "sha512-lsWQY9aDXHitw7C1QRYIbVGmgwyT98TF3MfM8alNIXKpdJdi+W782Rzd9f1RyOfgRmZ08gJ2EYNDhWNK7RqpEA==", + "dev": true, "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "2.0.0", @@ -24168,7 +24171,6 @@ "version": "0.2.23", "license": "MIT", "dependencies": { - "@ai-sdk/openai": "2.0.64", "@cfworker/json-schema": "^4.1.1", "@modelcontextprotocol/sdk": "^1.21.0", "ai": "5.0.89", diff --git a/packages/agents/package.json b/packages/agents/package.json index 3e9fd542..8325e2e4 100644 --- a/packages/agents/package.json +++ b/packages/agents/package.json @@ -24,7 +24,6 @@ "url": "https://github.com/cloudflare/agents/issues" }, "dependencies": { - "@ai-sdk/openai": "2.0.64", "@cfworker/json-schema": "^4.1.1", "@modelcontextprotocol/sdk": "^1.21.0", "ai": "5.0.89", diff --git a/packages/agents/src/ai-chat-agent.ts b/packages/agents/src/ai-chat-agent.ts index 918ba0b9..63534fa7 100644 --- a/packages/agents/src/ai-chat-agent.ts +++ b/packages/agents/src/ai-chat-agent.ts @@ -190,10 +190,12 @@ export class AIChatAgent extends Agent< override async onRequest(request: Request): Promise { return this._tryCatchChat(() => { const url = new URL(request.url); + if (url.pathname.endsWith("/get-messages")) { const messages = this._loadMessagesFromDb(); return Response.json(messages); } + return super.onRequest(request); }); } diff --git a/packages/agents/src/index.ts b/packages/agents/src/index.ts index 6a5b2caa..54e82318 100644 --- a/packages/agents/src/index.ts +++ b/packages/agents/src/index.ts @@ -22,11 +22,14 @@ import { routePartykitRequest } from "partyserver"; import { camelCaseToKebabCase } from "./client"; -import { MCPClientManager, type MCPClientOAuthResult } from "./mcp/client"; -import { MCPClientConnection } from "./mcp/client-connection"; +import { + MCPClientManager, + type MCPClientOAuthResult +} from "./mcp/client-manager"; import type { MCPConnectionState } from "./mcp/client-connection"; import { DurableObjectOAuthClientProvider } from "./mcp/do-oauth-client-provider"; import type { TransportType } from "./mcp/types"; +import { AgentMCPStorageAdapter } from "./mcp/client-storage"; import { genericObservability, type Observability } from "./observability"; import { DisposableStore } from "./core/events"; import { MessageType } from "./ai-types"; @@ -230,20 +233,6 @@ export type MCPServer = { instructions: string | null; capabilities: ServerCapabilities | null; }; - -/** - * MCP Server data stored in DO SQL for resuming MCP Server connections - */ -type MCPServerRow = { - id: string; - name: string; - server_url: string; - client_id: string | null; - auth_url: string | null; - callback_url: string; - server_options: string; -}; - const STATE_ROW_ID = "cf_state_row_id"; const STATE_WAS_CHANGED = "cf_state_was_changed"; @@ -320,15 +309,11 @@ export class Agent< > extends Server { private _state = DEFAULT_STATE as State; private _disposables = new DisposableStore(); - private _mcpStateRestored = false; private _ParentClass: typeof Agent = Object.getPrototypeOf(this).constructor; - readonly mcp: MCPClientManager = new MCPClientManager( - this._ParentClass.name, - "0.0.1" - ); + readonly mcp!: MCPClientManager; /** * Initial state for the Agent @@ -421,6 +406,13 @@ export class Agent< constructor(ctx: AgentContext, env: Env) { super(ctx, env); + this.mcp = new MCPClientManager(this._ParentClass.name, "0.0.1", { + storage: new AgentMCPStorageAdapter( + this.sql.bind(this), + this.ctx.storage.kv + ) + }); + if (!wrappedClasses.has(this.constructor)) { // Auto-wrap custom methods with agent context this._autoWrapCustomMethods(); @@ -478,36 +470,28 @@ export class Agent< }); }); - this.sql` - CREATE TABLE IF NOT EXISTS cf_agents_mcp_servers ( - id TEXT PRIMARY KEY NOT NULL, - name TEXT NOT NULL, - server_url TEXT NOT NULL, - callback_url TEXT NOT NULL, - client_id TEXT, - auth_url TEXT, - server_options TEXT - ) - `; - const _onRequest = this.onRequest.bind(this); this.onRequest = (request: Request) => { return agentContext.run( { agent: this, connection: undefined, request, email: undefined }, async () => { - await this._ensureMcpStateRestored(); + await this.mcp.ensureJsonSchema(); + + const isCallback = await this.mcp.isCallbackRequest(request); - if (this.mcp.isCallbackRequest(request)) { + if (isCallback) { const result = await this.mcp.handleCallbackRequest(request); + this.broadcastMcpServers(); if (result.authSuccess) { - this.clearMcpServerAuthUrl(result.serverId); - this.mcp .establishConnection(result.serverId) .catch((error) => { - console.error("Background connection failed:", error); + console.error( + "[Agent onRequest] Background connection failed:", + error + ); }) .finally(() => { this.broadcastMcpServers(); @@ -527,6 +511,7 @@ export class Agent< return agentContext.run( { agent: this, connection, request: undefined, email: undefined }, async () => { + await this.mcp.ensureJsonSchema(); if (typeof message !== "string") { return this._tryCatch(() => _onMessage(connection, message)); } @@ -663,7 +648,7 @@ export class Agent< }, async () => { await this._tryCatch(async () => { - await this._ensureMcpStateRestored(); + await this.mcp.restoreConnectionsFromStorage(this.name); this.broadcastMcpServers(); return _onStart(props); }); @@ -1348,14 +1333,16 @@ export class Agent< // drop all tables this.sql`DROP TABLE IF EXISTS cf_agents_state`; this.sql`DROP TABLE IF EXISTS cf_agents_schedules`; - this.sql`DROP TABLE IF EXISTS cf_agents_mcp_servers`; this.sql`DROP TABLE IF EXISTS cf_agents_queues`; // delete all alarms await this.ctx.storage.deleteAlarm(); await this.ctx.storage.deleteAll(); this._disposables.dispose(); - await this.mcp.dispose?.(); + + // clean up MCP client manager + await this.mcp.dispose(); + this.ctx.abort("destroyed"); // enforce that the agent is evicted this.observability?.emit( @@ -1378,85 +1365,6 @@ export class Agent< return callableMetadata.has(this[method as keyof this] as Function); } - private async _ensureMcpStateRestored() { - if (this._mcpStateRestored) { - return; - } - - this._mcpStateRestored = true; - - const servers = this.sql` - SELECT id, name, server_url, client_id, auth_url, callback_url, server_options - FROM cf_agents_mcp_servers - `; - - if (!servers || !Array.isArray(servers) || servers.length === 0) { - return; - } - - for (const server of servers) { - if (server.callback_url) { - this.mcp.registerCallbackUrl(`${server.callback_url}/${server.id}`); - } - } - - for (const server of servers) { - const needsOAuth = !!server.auth_url; - - if (needsOAuth) { - const authProvider = new DurableObjectOAuthClientProvider( - this.ctx.storage, - this.name, - server.callback_url - ); - authProvider.serverId = server.id; - if (server.client_id) { - authProvider.clientId = server.client_id; - } - - const parsedOptions = server.server_options - ? JSON.parse(server.server_options) - : undefined; - - const conn = new MCPClientConnection( - new URL(server.server_url), - { - name: this.name, - version: "1.0.0" - }, - { - client: parsedOptions?.client ?? {}, - transport: { - ...(parsedOptions?.transport ?? {}), - type: parsedOptions?.transport?.type ?? ("auto" as TransportType), - authProvider - } - } - ); - - conn.connectionState = "authenticating"; - this.mcp.mcpConnections[server.id] = conn; - } else { - const parsedOptions = server.server_options - ? JSON.parse(server.server_options) - : undefined; - - this._connectToMcpServerInternal( - server.name, - server.server_url, - server.callback_url, - parsedOptions, - { - id: server.id, - oauthClientId: server.client_id ?? undefined - } - ).catch((error) => { - console.error(`Error restoring ${server.id}:`, error); - }); - } - } - } - /** * Connect to a new MCP Server * @@ -1497,72 +1405,17 @@ export class Agent< const callbackUrl = `${resolvedCallbackHost}/${agentsPrefix}/${camelCaseToKebabCase(this._ParentClass.name)}/${this.name}/callback`; - const result = await this._connectToMcpServerInternal( - serverName, - url, - callbackUrl, - options - ); - - this.sql` - INSERT - OR REPLACE INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES ( - ${result.id}, - ${serverName}, - ${url}, - ${result.clientId ?? null}, - ${result.authUrl ?? null}, - ${callbackUrl}, - ${options ? JSON.stringify(options) : null} - ); - `; - - this.broadcastMcpServers(); + // Late initialization of jsonSchemaFn (needed for getAITools) + await this.mcp.ensureJsonSchema(); - return result; - } + const id = nanoid(8); - private async _connectToMcpServerInternal( - _serverName: string, - url: string, - callbackUrl: string, - // it's important that any options here are serializable because we put them into our sqlite DB for reconnection purposes - options?: { - client?: ConstructorParameters[1]; - /** - * We don't expose the normal set of transport options because: - * 1) we can't serialize things like the auth provider or a fetch function into the DB for reconnection purposes - * 2) We probably want these options to be agnostic to the transport type (SSE vs Streamable) - * - * This has the limitation that you can't override fetch, but I think headers should handle nearly all cases needed (i.e. non-standard bearer auth). - */ - transport?: { - headers?: HeadersInit; - type?: TransportType; - }; - }, - reconnect?: { - id: string; - oauthClientId?: string; - } - ): Promise<{ - id: string; - authUrl: string | undefined; - clientId: string | undefined; - }> { const authProvider = new DurableObjectOAuthClientProvider( - this.ctx.storage, + this.mcp.storage, this.name, callbackUrl ); - - if (reconnect) { - authProvider.serverId = reconnect.id; - if (reconnect.oauthClientId) { - authProvider.clientId = reconnect.oauthClientId; - } - } + authProvider.serverId = id; // Use the transport type specified in options, or default to "auto" const transportType: TransportType = options?.transport?.type ?? "auto"; @@ -1585,9 +1438,12 @@ export class Agent< }; } - const { id, authUrl, clientId } = await this.mcp.connect(url, { + // Register server (also saves to storage) + this.mcp.registerServer(id, { + url, + name: serverName, + callbackUrl, client: options?.client, - reconnect, transport: { ...headerTransportOpts, authProvider, @@ -1595,35 +1451,23 @@ export class Agent< } }); + // Connect to server (updates storage with auth URL if OAuth) + const result = await this.mcp.connectToServer(id); + + this.broadcastMcpServers(); + return { - authUrl, - clientId, - id + id, + authUrl: result.authUrl }; } async removeMcpServer(id: string) { this.mcp.closeConnection(id); - this.mcp.unregisterCallbackUrl(id); - this.sql` - DELETE FROM cf_agents_mcp_servers WHERE id = ${id}; - `; + this.mcp.removeServer(id); this.broadcastMcpServers(); } - /** - * Clear the auth_url for an MCP server after successful OAuth authentication - * This prevents the agent from continuously asking for OAuth on reconnect - * @param id The server ID to clear auth_url for - */ - private clearMcpServerAuthUrl(id: string) { - this.sql` - UPDATE cf_agents_mcp_servers - SET auth_url = NULL - WHERE id = ${id} - `; - } - getMcpServers(): MCPServersState { const mcpState: MCPServersState = { prompts: this.mcp.listPrompts(), @@ -1632,21 +1476,26 @@ export class Agent< tools: this.mcp.listTools() }; - const servers = this.sql` - SELECT id, name, server_url, client_id, auth_url, callback_url, server_options FROM cf_agents_mcp_servers; - `; + const servers = this.mcp.listServers(); if (servers && Array.isArray(servers) && servers.length > 0) { for (const server of servers) { const serverConn = this.mcp.mcpConnections[server.id]; + + // Determine the default state when no connection exists + let defaultState: "authenticating" | "not-connected" = "not-connected"; + if (!serverConn && server.auth_url) { + // If there's an auth_url but no connection, it's waiting for OAuth + defaultState = "authenticating"; + } + mcpState.servers[server.id] = { auth_url: server.auth_url, capabilities: serverConn?.serverCapabilities ?? null, instructions: serverConn?.instructions ?? null, name: server.name, server_url: server.server_url, - // mark as "authenticating" because the server isn't automatically connected, so it's pending authenticating - state: serverConn?.connectionState ?? "authenticating" + state: serverConn?.connectionState ?? defaultState }; } } diff --git a/packages/agents/src/mcp/client-manager.ts b/packages/agents/src/mcp/client-manager.ts new file mode 100644 index 00000000..aef071f8 --- /dev/null +++ b/packages/agents/src/mcp/client-manager.ts @@ -0,0 +1,854 @@ +import type { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import type { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { + CallToolRequest, + CallToolResultSchema, + CompatibilityCallToolResultSchema, + GetPromptRequest, + Prompt, + ReadResourceRequest, + Resource, + ResourceTemplate, + Tool +} from "@modelcontextprotocol/sdk/types.js"; +import type { ToolSet } from "ai"; +import type { JSONSchema7 } from "json-schema"; +import { nanoid } from "nanoid"; +import { Emitter, type Event, DisposableStore } from "../core/events"; +import type { MCPObservabilityEvent } from "../observability/mcp"; +import { + MCPClientConnection, + type MCPTransportOptions +} from "./client-connection"; +import { toErrorMessage } from "./errors"; +import type { TransportType } from "./types"; +import type { MCPStorageAdapter, MCPServerRow } from "./client-storage"; +import type { AgentsOAuthProvider } from "./do-oauth-client-provider"; +import { DurableObjectOAuthClientProvider } from "./do-oauth-client-provider"; + +/** + * Options that can be stored in the server_options column + * This is what gets JSON.stringify'd and stored in the database + */ +export type MCPServerOptions = { + client?: ConstructorParameters[1]; + transport?: { + headers?: HeadersInit; + type?: TransportType; + }; +}; + +/** + * Options for registering an MCP server + */ +export type RegisterServerOptions = { + url: string; + name: string; + callbackUrl: string; + client?: ConstructorParameters[1]; + transport?: MCPTransportOptions; + authUrl?: string; + clientId?: string; +}; + +/** + * Options for connecting to an MCP server + */ +export type ConnectServerOptions = { + oauthCode?: string; +}; + +export type MCPClientOAuthCallbackConfig = { + successRedirect?: string; + errorRedirect?: string; + customHandler?: (result: MCPClientOAuthResult) => Response; +}; + +export type MCPClientOAuthResult = { + serverId: string; + authSuccess: boolean; + authError?: string; +}; + +export type MCPClientManagerOptions = { + storage: MCPStorageAdapter; +}; + +/** + * Utility class that aggregates multiple MCP clients into one + */ +export class MCPClientManager { + public mcpConnections: Record = {}; + private _didWarnAboutUnstableGetAITools = false; + private _oauthCallbackConfig?: MCPClientOAuthCallbackConfig; + private _connectionDisposables = new Map(); + private _storage: MCPStorageAdapter; + private _isRestored = false; + + // In-memory cache of callback URLs to avoid DB queries on every request + private _callbackUrlCache: Set | null = null; + + private readonly _onObservabilityEvent = new Emitter(); + public readonly onObservabilityEvent: Event = + this._onObservabilityEvent.event; + + private readonly _onConnected = new Emitter(); + public readonly onConnected: Event = this._onConnected.event; + + /** + * @param _name Name of the MCP client + * @param _version Version of the MCP Client + * @param options Storage adapter for persisting MCP server state + */ + constructor( + private _name: string, + private _version: string, + options: MCPClientManagerOptions + ) { + this._storage = options.storage; + + // Create the storage instance + this._storage.create(); + } + + jsonSchema: typeof import("ai").jsonSchema | undefined; + + /** + * Get the storage adapter instance + * @internal + */ + get storage(): MCPStorageAdapter { + return this._storage; + } + + /** + * Create an auth provider for a server + * @internal + */ + private createAuthProvider( + serverId: string, + callbackUrl: string, + clientName: string, + clientId?: string + ): AgentsOAuthProvider { + const authProvider = new DurableObjectOAuthClientProvider( + this._storage, + clientName, + callbackUrl + ); + authProvider.serverId = serverId; + if (clientId) { + authProvider.clientId = clientId; + } + return authProvider; + } + + /** + * Restore MCP server connections from storage + * This method is called on Agent initialization to restore previously connected servers + * + * @param clientName Name to use for OAuth client (typically the agent instance name) + */ + async restoreConnectionsFromStorage(clientName: string): Promise { + if (this._isRestored) { + return; + } + + const servers = this._storage.listServers(); + + if (!servers || servers.length === 0) { + this._isRestored = true; + return; + } + + for (const server of servers) { + const existingConn = this.mcpConnections[server.id]; + + // Skip if connection already exists and is in a good state + if (existingConn) { + if (existingConn.connectionState === "ready") { + console.warn( + `[MCPClientManager] Server ${server.id} already has a ready connection. Skipping recreation.` + ); + continue; + } + + // Don't interrupt in-flight OAuth or connections + if ( + existingConn.connectionState === "authenticating" || + existingConn.connectionState === "connecting" || + existingConn.connectionState === "discovering" + ) { + // Let the existing flow complete + continue; + } + + // If failed, we'll recreate below + } + + const needsOAuth = !!server.auth_url; + const parsedOptions: MCPServerOptions | null = server.server_options + ? JSON.parse(server.server_options) + : null; + + const authProvider = this.createAuthProvider( + server.id, + server.callback_url, + clientName, + server.client_id ?? undefined + ); + + this.registerServer(server.id, { + url: server.server_url, + name: server.name, + callbackUrl: server.callback_url, + client: parsedOptions?.client ?? {}, + transport: { + ...(parsedOptions?.transport ?? {}), + type: parsedOptions?.transport?.type ?? ("auto" as TransportType), + authProvider + }, + authUrl: server.auth_url ?? undefined, + clientId: server.client_id ?? undefined + }); + + if (needsOAuth) { + // OAuth server - just set state to authenticating (wait for OAuth flow) + if (this.mcpConnections[server.id]) { + this.mcpConnections[server.id].connectionState = "authenticating"; + } + } else { + // Non-OAuth server - connect immediately + await this.connectToServer(server.id).catch((error) => { + console.error(`Error restoring ${server.id}:`, error); + }); + } + } + + this._isRestored = true; + } + + /** + * Register an MCP server connection without connecting + * Creates the connection object, sets up observability, and saves to storage + * + * @param id Server ID + * @param options Registration options including URL, name, callback URL, and connection config + * @returns Server ID + */ + registerServer(id: string, options: RegisterServerOptions): string { + // Skip if connection already exists + if (this.mcpConnections[id]) { + return id; + } + + const normalizedTransport = { + ...options.transport, + type: options.transport?.type ?? ("auto" as TransportType) + }; + + this.mcpConnections[id] = new MCPClientConnection( + new URL(options.url), + { + name: this._name, + version: this._version + }, + { + client: options.client ?? {}, + transport: normalizedTransport + } + ); + + // Pipe connection-level observability events to the manager-level emitter + const store = new DisposableStore(); + const existing = this._connectionDisposables.get(id); + if (existing) existing.dispose(); + this._connectionDisposables.set(id, store); + store.add( + this.mcpConnections[id].onObservabilityEvent((event) => { + this._onObservabilityEvent.fire(event); + }) + ); + + // Save to storage + this._storage.saveServer({ + id, + name: options.name, + server_url: options.url, + callback_url: options.callbackUrl, + client_id: options.clientId ?? null, + auth_url: options.authUrl ?? null, + server_options: JSON.stringify({ + client: options.client, + transport: options.transport + }) + }); + + return id; + } + + /** + * Connect to an already registered MCP server + * Updates storage with auth URL and client ID after connection + * + * @param id Server ID + * @param options Connection options (e.g., OAuth code for completing OAuth flow) + * @returns Auth URL if OAuth is required, undefined otherwise + */ + async connectToServer( + id: string, + options?: ConnectServerOptions + ): Promise<{ + authUrl?: string; + clientId?: string; + }> { + const conn = this.mcpConnections[id]; + if (!conn) { + throw new Error( + `Server ${id} is not registered. Call registerServer() first.` + ); + } + + // Handle OAuth completion if we have a code + if (options?.oauthCode) { + try { + await conn.completeAuthorization(options.oauthCode); + await conn.establishConnection(); + } catch (error) { + this._onObservabilityEvent.fire({ + type: "mcp:client:connect", + displayMessage: `Failed to complete OAuth reconnection for ${id}`, + payload: { + url: conn.url.toString(), + transport: conn.options.transport.type ?? "auto", + state: conn.connectionState, + error: toErrorMessage(error) + }, + timestamp: Date.now(), + id + }); + throw error; + } + return {}; + } + + // Initialize connection + await conn.init(); + + // If connection is in authenticating state, return auth URL for OAuth flow + const authUrl = conn.options.transport.authProvider?.authUrl; + + if ( + conn.connectionState === "authenticating" && + authUrl && + conn.options.transport.authProvider?.redirectUrl + ) { + const clientId = conn.options.transport.authProvider?.clientId; + + // Update storage with auth URL and client ID + const serverRow = this._storage.listServers().find((s) => s.id === id); + if (serverRow) { + this._storage.saveServer({ + ...serverRow, + auth_url: authUrl, + client_id: clientId ?? null + }); + } + + return { + authUrl, + clientId + }; + } + + // Fire connected event for non-OAuth connections that reached ready state + if (conn.connectionState === "ready") { + this._onConnected.fire(id); + } + + return {}; + } + + /** + * Refresh the in-memory callback URL cache from storage + */ + private async _refreshCallbackUrlCache(): Promise { + const servers = this._storage.listServers(); + this._callbackUrlCache = new Set( + servers.filter((s) => s.callback_url).map((s) => s.callback_url) + ); + } + + /** + * Invalidate the callback URL cache so it will be refreshed on next check + */ + private _invalidateCallbackUrlCache(): void { + this._callbackUrlCache = null; + } + + async isCallbackRequest(req: Request): Promise { + if (req.method !== "GET") { + return false; + } + + // Quick heuristic check: most callback URLs contain "/callback" + // This avoids DB queries for obviously non-callback requests + if (!req.url.includes("/callback")) { + return false; + } + + // Lazily populate cache on first check + if (this._callbackUrlCache === null) { + await this._refreshCallbackUrlCache(); + } + + // Check cache first for quick lookup + for (const callbackUrl of this._callbackUrlCache!) { + if (req.url.startsWith(callbackUrl)) { + return true; + } + } + + return false; + } + + async handleCallbackRequest(req: Request) { + const url = new URL(req.url); + + // Find the matching server from database + const servers = this._storage.listServers(); + const matchingServer = servers.find((server: MCPServerRow) => { + return server.callback_url && req.url.startsWith(server.callback_url); + }); + + if (!matchingServer) { + throw new Error( + `No callback URI match found for the request url: ${req.url}. Was the request matched with \`isCallbackRequest()\`?` + ); + } + + const serverId = matchingServer.id; + const code = url.searchParams.get("code"); + const state = url.searchParams.get("state"); + const error = url.searchParams.get("error"); + const errorDescription = url.searchParams.get("error_description"); + + // Handle OAuth error responses from the provider + if (error) { + return { + serverId, + authSuccess: false, + authError: errorDescription || error + }; + } + + if (!code) { + throw new Error("Unauthorized: no code provided"); + } + if (!state) { + throw new Error("Unauthorized: no state provided"); + } + + if (this.mcpConnections[serverId] === undefined) { + throw new Error(`Could not find serverId: ${serverId}`); + } + + // If connection is already ready, this is likely a duplicate callback + if (this.mcpConnections[serverId].connectionState === "ready") { + // Already authenticated and ready, treat as success + return { + serverId, + authSuccess: true + }; + } + + if (this.mcpConnections[serverId].connectionState !== "authenticating") { + throw new Error( + `Failed to authenticate: the client is in "${this.mcpConnections[serverId].connectionState}" state, expected "authenticating"` + ); + } + + const conn = this.mcpConnections[serverId]; + if (!conn.options.transport.authProvider) { + throw new Error( + "Trying to finalize authentication for a server connection without an authProvider" + ); + } + + // Get clientId from auth provider (stored during redirectToAuthorization) or fallback to state for backward compatibility + const clientId = conn.options.transport.authProvider.clientId || state; + + // Set the OAuth credentials + conn.options.transport.authProvider.clientId = clientId; + conn.options.transport.authProvider.serverId = serverId; + + try { + await conn.completeAuthorization(code); + this._storage.clearOAuthCredentials(serverId); + this._invalidateCallbackUrlCache(); + + return { + serverId, + authSuccess: true + }; + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + + return { + serverId, + authSuccess: false, + authError: errorMessage + }; + } + } + + /** + * Establish connection in the background after OAuth completion + * This method is called asynchronously and doesn't block the OAuth callback response + * @param serverId The server ID to establish connection for + */ + async establishConnection(serverId: string): Promise { + const conn = this.mcpConnections[serverId]; + if (!conn) { + this._onObservabilityEvent.fire({ + type: "mcp:client:preconnect", + displayMessage: `Connection not found for serverId: ${serverId}`, + payload: { serverId }, + timestamp: Date.now(), + id: nanoid() + }); + return; + } + + try { + await conn.establishConnection(); + this._onConnected.fire(serverId); + } catch (error) { + const url = conn.url.toString(); + this._onObservabilityEvent.fire({ + type: "mcp:client:connect", + displayMessage: `Failed to establish connection to server ${serverId} with url ${url}`, + payload: { + url, + transport: conn.options.transport.type ?? "auto", + state: conn.connectionState, + error: toErrorMessage(error) + }, + timestamp: Date.now(), + id: nanoid() + }); + } + } + + /** + * Configure OAuth callback handling + * @param config OAuth callback configuration + */ + configureOAuthCallback(config: MCPClientOAuthCallbackConfig): void { + this._oauthCallbackConfig = config; + } + + /** + * Get the current OAuth callback configuration + * @returns The current OAuth callback configuration + */ + getOAuthCallbackConfig(): MCPClientOAuthCallbackConfig | undefined { + return this._oauthCallbackConfig; + } + + /** + * @returns namespaced list of tools + */ + listTools(): NamespacedData["tools"] { + return getNamespacedData(this.mcpConnections, "tools"); + } + + async ensureJsonSchema() { + if (!this.jsonSchema) { + const { jsonSchema } = await import("ai"); + this.jsonSchema = jsonSchema; + } + } + + /** + * Check if all MCP connections are in a stable state (ready or authenticating) + * Useful to call before getAITools() to avoid race conditions + * + * @returns Object with ready status and list of connections not ready + */ + areConnectionsReady(): { ready: boolean; pendingConnections: string[] } { + const pendingConnections: string[] = []; + + for (const [id, conn] of Object.entries(this.mcpConnections)) { + if ( + conn.connectionState !== "ready" && + conn.connectionState !== "authenticating" + ) { + pendingConnections.push(id); + } + } + + return { + ready: pendingConnections.length === 0, + pendingConnections + }; + } + + /** + * @returns a set of tools that you can use with the AI SDK + */ + getAITools(): ToolSet { + if (!this.jsonSchema) { + throw new Error("jsonSchema not initialized."); + } + + // Warn if tools are being read from non-ready connections + for (const [id, conn] of Object.entries(this.mcpConnections)) { + if ( + conn.connectionState !== "ready" && + conn.connectionState !== "authenticating" + ) { + console.warn( + `[getAITools] WARNING: Reading tools from connection ${id} in state "${conn.connectionState}". Tools may not be loaded yet.` + ); + } + } + + return Object.fromEntries( + getNamespacedData(this.mcpConnections, "tools").map((tool) => { + return [ + `tool_${tool.serverId.replace(/-/g, "")}_${tool.name}`, + { + description: tool.description, + execute: async (args) => { + const result = await this.callTool({ + arguments: args, + name: tool.name, + serverId: tool.serverId + }); + if (result.isError) { + // @ts-expect-error TODO we should fix this + throw new Error(result.content[0].text); + } + return result; + }, + inputSchema: this.jsonSchema!(tool.inputSchema as JSONSchema7), + outputSchema: tool.outputSchema + ? this.jsonSchema!(tool.outputSchema as JSONSchema7) + : undefined + } + ]; + }) + ); + } + + /** + * @deprecated this has been renamed to getAITools(), and unstable_getAITools will be removed in the next major version + * @returns a set of tools that you can use with the AI SDK + */ + unstable_getAITools(): ToolSet { + if (!this._didWarnAboutUnstableGetAITools) { + this._didWarnAboutUnstableGetAITools = true; + console.warn( + "unstable_getAITools is deprecated, use getAITools instead. unstable_getAITools will be removed in the next major version." + ); + } + return this.getAITools(); + } + + /** + * Closes all connections to MCP servers + */ + async closeAllConnections() { + const ids = Object.keys(this.mcpConnections); + await Promise.all( + ids.map(async (id) => { + await this.mcpConnections[id].client.close(); + }) + ); + // Dispose all per-connection subscriptions + for (const id of ids) { + const store = this._connectionDisposables.get(id); + if (store) store.dispose(); + this._connectionDisposables.delete(id); + delete this.mcpConnections[id]; + } + } + + /** + * Closes a connection to an MCP server + * @param id The id of the connection to close + */ + async closeConnection(id: string) { + if (!this.mcpConnections[id]) { + throw new Error(`Connection with id "${id}" does not exist.`); + } + await this.mcpConnections[id].client.close(); + delete this.mcpConnections[id]; + + const store = this._connectionDisposables.get(id); + if (store) store.dispose(); + this._connectionDisposables.delete(id); + } + + /** + * Save an MCP server configuration to storage + */ + saveServer(server: { + id: string; + name: string; + server_url: string; + client_id?: string | null; + auth_url?: string | null; + callback_url: string; + server_options?: string | null; + }): void { + if (this._storage) { + this._storage.saveServer({ + id: server.id, + name: server.name, + server_url: server.server_url, + client_id: server.client_id ?? null, + auth_url: server.auth_url ?? null, + callback_url: server.callback_url, + server_options: server.server_options ?? null + }); + // Invalidate cache since callback URLs may have changed + this._invalidateCallbackUrlCache(); + } + } + + /** + * Remove an MCP server from storage + */ + removeServer(serverId: string): void { + if (this._storage) { + this._storage.removeServer(serverId); + // Invalidate cache since callback URLs may have changed + this._invalidateCallbackUrlCache(); + } + } + + /** + * List all MCP servers from storage + */ + listServers() { + if (this._storage) { + return this._storage.listServers(); + } + return []; + } + + /** + * Dispose the manager and all resources. + */ + async dispose(): Promise { + try { + await this.closeAllConnections(); + } finally { + // Dispose manager-level emitters + this._onConnected.dispose(); + this._onObservabilityEvent.dispose(); + + // Drop the storage table + this._storage.destroy(); + } + } + + /** + * @returns namespaced list of prompts + */ + listPrompts(): NamespacedData["prompts"] { + return getNamespacedData(this.mcpConnections, "prompts"); + } + + /** + * @returns namespaced list of tools + */ + listResources(): NamespacedData["resources"] { + return getNamespacedData(this.mcpConnections, "resources"); + } + + /** + * @returns namespaced list of resource templates + */ + listResourceTemplates(): NamespacedData["resourceTemplates"] { + return getNamespacedData(this.mcpConnections, "resourceTemplates"); + } + + /** + * Namespaced version of callTool + */ + async callTool( + params: CallToolRequest["params"] & { serverId: string }, + resultSchema?: + | typeof CallToolResultSchema + | typeof CompatibilityCallToolResultSchema, + options?: RequestOptions + ) { + const unqualifiedName = params.name.replace(`${params.serverId}.`, ""); + return this.mcpConnections[params.serverId].client.callTool( + { + ...params, + name: unqualifiedName + }, + resultSchema, + options + ); + } + + /** + * Namespaced version of readResource + */ + readResource( + params: ReadResourceRequest["params"] & { serverId: string }, + options: RequestOptions + ) { + return this.mcpConnections[params.serverId].client.readResource( + params, + options + ); + } + + /** + * Namespaced version of getPrompt + */ + getPrompt( + params: GetPromptRequest["params"] & { serverId: string }, + options: RequestOptions + ) { + return this.mcpConnections[params.serverId].client.getPrompt( + params, + options + ); + } +} + +type NamespacedData = { + tools: (Tool & { serverId: string })[]; + prompts: (Prompt & { serverId: string })[]; + resources: (Resource & { serverId: string })[]; + resourceTemplates: (ResourceTemplate & { serverId: string })[]; +}; + +export function getNamespacedData( + mcpClients: Record, + type: T +): NamespacedData[T] { + const sets = Object.entries(mcpClients).map(([name, conn]) => { + return { data: conn[type], name }; + }); + + const namespacedData = sets.flatMap(({ name: serverId, data }) => { + return data.map((item) => { + return { + ...item, + // we add a serverId so we can easily pull it out and send the tool call to the right server + serverId + }; + }); + }); + + return namespacedData as NamespacedData[T]; // Type assertion needed due to TS limitations with conditional return types +} diff --git a/packages/agents/src/mcp/client-storage.ts b/packages/agents/src/mcp/client-storage.ts new file mode 100644 index 00000000..54e46880 --- /dev/null +++ b/packages/agents/src/mcp/client-storage.ts @@ -0,0 +1,168 @@ +/** + * Represents a row in the cf_agents_mcp_servers table + */ +export type MCPServerRow = { + id: string; + name: string; + server_url: string; + client_id: string | null; + auth_url: string | null; + callback_url: string; + server_options: string | null; +}; + +/** + * Storage adapter interface for MCP client manager + * Abstracts SQL operations to decouple from specific storage implementations + */ +export interface MCPStorageAdapter { + /** + * Create the cf_agents_mcp_servers table if it doesn't exist + */ + create(): void; + + /** + * Drop the cf_agents_mcp_servers table + */ + destroy(): void; + + /** + * Save or update an MCP server configuration + */ + saveServer(server: MCPServerRow): void; + + /** + * Remove an MCP server from storage + */ + removeServer(serverId: string): void; + + /** + * List all MCP servers from storage + */ + listServers(): MCPServerRow[]; + + /** + * Get an MCP server by its callback URL + * Used during OAuth callback to identify which server is being authenticated + */ + getServerByCallbackUrl(callbackUrl: string): MCPServerRow | null; + + /** + * Clear both auth_url and callback_url after successful OAuth authentication + * This prevents the agent from continuously asking for OAuth on reconnect + * and prevents malicious second callbacks from being processed + */ + clearOAuthCredentials(serverId: string): void; + + /** + * Get a value from key-value storage (for OAuth data like tokens, client info, etc.) + */ + get(key: string): T | undefined; + + /** + * Put a value into key-value storage (for OAuth data like tokens, client info, etc.) + */ + put(key: string, value: unknown): void; +} + +/** + * SQL-based storage adapter that wraps SQL operations + * Used by Agent class to provide SQL access to MCPClientManager + */ +export class AgentMCPStorageAdapter implements MCPStorageAdapter { + constructor( + private sql: >( + strings: TemplateStringsArray, + ...values: (string | number | boolean | null)[] + ) => T[], + private kv: SyncKvStorage + ) {} + + create() { + this.sql` + CREATE TABLE IF NOT EXISTS cf_agents_mcp_servers ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + server_url TEXT NOT NULL, + callback_url TEXT NOT NULL, + client_id TEXT, + auth_url TEXT, + server_options TEXT + ) + `; + } + + destroy() { + this.sql`DROP TABLE IF EXISTS cf_agents_mcp_servers`; + } + + saveServer(server: MCPServerRow) { + this.sql` + INSERT OR REPLACE INTO cf_agents_mcp_servers ( + id, + name, + server_url, + client_id, + auth_url, + callback_url, + server_options + ) + VALUES ( + ${server.id}, + ${server.name}, + ${server.server_url}, + ${server.client_id ?? null}, + ${server.auth_url ?? null}, + ${server.callback_url}, + ${server.server_options ?? null} + ) + `; + } + + removeServer(serverId: string) { + this.sql` + DELETE FROM cf_agents_mcp_servers WHERE id = ${serverId} + `; + } + + listServers() { + const servers = this.sql` + SELECT id, name, server_url, client_id, auth_url, callback_url, server_options + FROM cf_agents_mcp_servers + `; + return servers; + } + + getServerByCallbackUrl(callbackUrl: string) { + const results = this.sql` + SELECT id, name, server_url, client_id, auth_url, callback_url, server_options + FROM cf_agents_mcp_servers + WHERE callback_url = ${callbackUrl} + LIMIT 1 + `; + return results.length > 0 ? results[0] : null; + } + + clearOAuthCredentials(serverId: string) { + this.sql` + UPDATE cf_agents_mcp_servers + SET callback_url = '', auth_url = NULL + WHERE id = ${serverId} + `; + } + + get(key: string) { + return this.kv.get(key); + } + + put( + key: string, + value: + | string + | ArrayBuffer + | ArrayBufferView + | ReadableStream + ) { + this.kv.put(key, value); + } +} diff --git a/packages/agents/src/mcp/client.ts b/packages/agents/src/mcp/client.ts deleted file mode 100644 index 99d2059f..00000000 --- a/packages/agents/src/mcp/client.ts +++ /dev/null @@ -1,555 +0,0 @@ -import type { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import type { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js"; -import type { - CallToolRequest, - CallToolResultSchema, - CompatibilityCallToolResultSchema, - GetPromptRequest, - Prompt, - ReadResourceRequest, - Resource, - ResourceTemplate, - Tool -} from "@modelcontextprotocol/sdk/types.js"; -import type { ToolSet } from "ai"; -import type { JSONSchema7 } from "json-schema"; -import { nanoid } from "nanoid"; -import { Emitter, type Event, DisposableStore } from "../core/events"; -import type { MCPObservabilityEvent } from "../observability/mcp"; -import { - MCPClientConnection, - type MCPTransportOptions -} from "./client-connection"; -import { toErrorMessage } from "./errors"; -import type { TransportType } from "./types"; - -export type MCPClientOAuthCallbackConfig = { - successRedirect?: string; - errorRedirect?: string; - customHandler?: (result: MCPClientOAuthResult) => Response; -}; - -export type MCPClientOAuthResult = { - serverId: string; - authSuccess: boolean; - authError?: string; -}; - -/** - * Utility class that aggregates multiple MCP clients into one - */ -export class MCPClientManager { - public mcpConnections: Record = {}; - private _callbackUrls: string[] = []; - private _didWarnAboutUnstableGetAITools = false; - private _oauthCallbackConfig?: MCPClientOAuthCallbackConfig; - private _connectionDisposables = new Map(); - - private readonly _onObservabilityEvent = new Emitter(); - public readonly onObservabilityEvent: Event = - this._onObservabilityEvent.event; - - private readonly _onConnected = new Emitter(); - public readonly onConnected: Event = this._onConnected.event; - - /** - * @param _name Name of the MCP client - * @param _version Version of the MCP Client - * @param auth Auth paramters if being used to create a DurableObjectOAuthClientProvider - */ - constructor( - private _name: string, - private _version: string - ) {} - - jsonSchema: typeof import("ai").jsonSchema | undefined; - - /** - * Connect to and register an MCP server - * - * @param transportConfig Transport config - * @param clientConfig Client config - * @param capabilities Client capabilities (i.e. if the client supports roots/sampling) - */ - async connect( - url: string, - options: { - // Allows you to reconnect to a server (in the case of an auth reconnect) - reconnect?: { - // server id - id: string; - oauthClientId?: string; - oauthCode?: string; - }; - // we're overriding authProvider here because we want to be able to access the auth URL - transport?: MCPTransportOptions; - client?: ConstructorParameters[1]; - } = {} - ): Promise<{ - id: string; - authUrl?: string; - clientId?: string; - }> { - /* Late initialization of jsonSchemaFn */ - /** - * We need to delay loading ai sdk, because putting it in module scope is - * causing issues with startup time. - * The only place it's used is in getAITools, which only matters after - * .connect() is called on at least one server. - * So it's safe to delay loading it until .connect() is called. - */ - if (!this.jsonSchema) { - const { jsonSchema } = await import("ai"); - this.jsonSchema = jsonSchema; - } - - const id = options.reconnect?.id ?? nanoid(8); - - if (options.transport?.authProvider) { - options.transport.authProvider.serverId = id; - // reconnect with auth - if (options.reconnect?.oauthClientId) { - options.transport.authProvider.clientId = - options.reconnect?.oauthClientId; - } - } - - // During OAuth reconnect, reuse existing connection to preserve state - if (!options.reconnect?.oauthCode || !this.mcpConnections[id]) { - const normalizedTransport = { - ...options.transport, - type: options.transport?.type ?? ("auto" as TransportType) - }; - - this.mcpConnections[id] = new MCPClientConnection( - new URL(url), - { - name: this._name, - version: this._version - }, - { - client: options.client ?? {}, - transport: normalizedTransport - } - ); - - // Pipe connection-level observability events to the manager-level emitter - // and track the subscription for cleanup. - const store = new DisposableStore(); - // If we somehow already had disposables for this id, clear them first - const existing = this._connectionDisposables.get(id); - if (existing) existing.dispose(); - this._connectionDisposables.set(id, store); - store.add( - this.mcpConnections[id].onObservabilityEvent((event) => { - this._onObservabilityEvent.fire(event); - }) - ); - } - - // Initialize connection first - await this.mcpConnections[id].init(); - - // Handle OAuth completion if we have a reconnect code - if (options.reconnect?.oauthCode) { - try { - await this.mcpConnections[id].completeAuthorization( - options.reconnect.oauthCode - ); - await this.mcpConnections[id].establishConnection(); - } catch (error) { - this._onObservabilityEvent.fire({ - type: "mcp:client:connect", - displayMessage: `Failed to complete OAuth reconnection for ${id} for ${url}`, - payload: { - url: url, - transport: options.transport?.type ?? "auto", - state: this.mcpConnections[id].connectionState, - error: toErrorMessage(error) - }, - timestamp: Date.now(), - id - }); - // Re-throw to signal failure to the caller - throw error; - } - } - - // If connection is in authenticating state, return auth URL for OAuth flow - const authUrl = options.transport?.authProvider?.authUrl; - if ( - this.mcpConnections[id].connectionState === "authenticating" && - authUrl && - options.transport?.authProvider?.redirectUrl - ) { - this._callbackUrls.push( - options.transport.authProvider.redirectUrl.toString() - ); - return { - authUrl, - clientId: options.transport?.authProvider?.clientId, - id - }; - } - - return { - id - }; - } - - isCallbackRequest(req: Request): boolean { - return ( - req.method === "GET" && - !!this._callbackUrls.find((url) => { - return req.url.startsWith(url); - }) - ); - } - - async handleCallbackRequest(req: Request) { - const url = new URL(req.url); - const urlMatch = this._callbackUrls.find((url) => { - return req.url.startsWith(url); - }); - if (!urlMatch) { - throw new Error( - `No callback URI match found for the request url: ${req.url}. Was the request matched with \`isCallbackRequest()\`?` - ); - } - const code = url.searchParams.get("code"); - const state = url.searchParams.get("state"); - const error = url.searchParams.get("error"); - const errorDescription = url.searchParams.get("error_description"); - const urlParams = urlMatch.split("/"); - const serverId = urlParams[urlParams.length - 1]; - - // Handle OAuth error responses from the provider - if (error) { - return { - serverId, - authSuccess: false, - authError: errorDescription || error - }; - } - - if (!code) { - throw new Error("Unauthorized: no code provided"); - } - if (!state) { - throw new Error("Unauthorized: no state provided"); - } - - if (this.mcpConnections[serverId] === undefined) { - throw new Error(`Could not find serverId: ${serverId}`); - } - - if (this.mcpConnections[serverId].connectionState !== "authenticating") { - throw new Error( - "Failed to authenticate: the client isn't in the `authenticating` state" - ); - } - - const conn = this.mcpConnections[serverId]; - if (!conn.options.transport.authProvider) { - throw new Error( - "Trying to finalize authentication for a server connection without an authProvider" - ); - } - - // Get clientId from auth provider (stored during redirectToAuthorization) or fallback to state for backward compatibility - const clientId = conn.options.transport.authProvider.clientId || state; - - // Set the OAuth credentials - conn.options.transport.authProvider.clientId = clientId; - conn.options.transport.authProvider.serverId = serverId; - - try { - await conn.completeAuthorization(code); - return { - serverId, - authSuccess: true - }; - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - - return { - serverId, - authSuccess: false, - authError: errorMessage - }; - } - } - - /** - * Establish connection in the background after OAuth completion - * This method is called asynchronously and doesn't block the OAuth callback response - * @param serverId The server ID to establish connection for - */ - async establishConnection(serverId: string): Promise { - const conn = this.mcpConnections[serverId]; - if (!conn) { - this._onObservabilityEvent.fire({ - type: "mcp:client:preconnect", - displayMessage: `Connection not found for serverId: ${serverId}`, - payload: { serverId }, - timestamp: Date.now(), - id: nanoid() - }); - return; - } - - try { - await conn.establishConnection(); - this._onConnected.fire(serverId); - } catch (error) { - const url = conn.url.toString(); - this._onObservabilityEvent.fire({ - type: "mcp:client:connect", - displayMessage: `Failed to establish connection to server ${serverId} with url ${url}`, - payload: { - url, - transport: conn.options.transport.type ?? "auto", - state: conn.connectionState, - error: toErrorMessage(error) - }, - timestamp: Date.now(), - id: nanoid() - }); - } - } - - /** - * Register a callback URL for OAuth handling - * @param url The callback URL to register - */ - registerCallbackUrl(url: string): void { - if (!this._callbackUrls.includes(url)) { - this._callbackUrls.push(url); - } - } - - /** - * Unregister a callback URL - * @param serverId The server ID whose callback URL should be removed - */ - unregisterCallbackUrl(serverId: string): void { - // Remove callback URLs that end with this serverId - this._callbackUrls = this._callbackUrls.filter( - (url) => !url.endsWith(`/${serverId}`) - ); - } - - /** - * Configure OAuth callback handling - * @param config OAuth callback configuration - */ - configureOAuthCallback(config: MCPClientOAuthCallbackConfig): void { - this._oauthCallbackConfig = config; - } - - /** - * Get the current OAuth callback configuration - * @returns The current OAuth callback configuration - */ - getOAuthCallbackConfig(): MCPClientOAuthCallbackConfig | undefined { - return this._oauthCallbackConfig; - } - - /** - * @returns namespaced list of tools - */ - listTools(): NamespacedData["tools"] { - return getNamespacedData(this.mcpConnections, "tools"); - } - - /** - * @returns a set of tools that you can use with the AI SDK - */ - getAITools(): ToolSet { - return Object.fromEntries( - getNamespacedData(this.mcpConnections, "tools").map((tool) => { - return [ - `tool_${tool.serverId.replace(/-/g, "")}_${tool.name}`, - { - description: tool.description, - execute: async (args) => { - const result = await this.callTool({ - arguments: args, - name: tool.name, - serverId: tool.serverId - }); - if (result.isError) { - // @ts-expect-error TODO we should fix this - throw new Error(result.content[0].text); - } - return result; - }, - inputSchema: this.jsonSchema!(tool.inputSchema as JSONSchema7), - outputSchema: tool.outputSchema - ? this.jsonSchema!(tool.outputSchema as JSONSchema7) - : undefined - } - ]; - }) - ); - } - - /** - * @deprecated this has been renamed to getAITools(), and unstable_getAITools will be removed in the next major version - * @returns a set of tools that you can use with the AI SDK - */ - unstable_getAITools(): ToolSet { - if (!this._didWarnAboutUnstableGetAITools) { - this._didWarnAboutUnstableGetAITools = true; - console.warn( - "unstable_getAITools is deprecated, use getAITools instead. unstable_getAITools will be removed in the next major version." - ); - } - return this.getAITools(); - } - - /** - * Closes all connections to MCP servers - */ - async closeAllConnections() { - const ids = Object.keys(this.mcpConnections); - await Promise.all( - ids.map(async (id) => { - await this.mcpConnections[id].client.close(); - }) - ); - // Dispose all per-connection subscriptions - for (const id of ids) { - const store = this._connectionDisposables.get(id); - if (store) store.dispose(); - this._connectionDisposables.delete(id); - delete this.mcpConnections[id]; - } - } - - /** - * Closes a connection to an MCP server - * @param id The id of the connection to close - */ - async closeConnection(id: string) { - if (!this.mcpConnections[id]) { - throw new Error(`Connection with id "${id}" does not exist.`); - } - await this.mcpConnections[id].client.close(); - delete this.mcpConnections[id]; - - const store = this._connectionDisposables.get(id); - if (store) store.dispose(); - this._connectionDisposables.delete(id); - } - - /** - * Dispose the manager and all resources. - */ - async dispose(): Promise { - try { - await this.closeAllConnections(); - } finally { - // Dispose manager-level emitters - this._onConnected.dispose(); - this._onObservabilityEvent.dispose(); - } - } - - /** - * @returns namespaced list of prompts - */ - listPrompts(): NamespacedData["prompts"] { - return getNamespacedData(this.mcpConnections, "prompts"); - } - - /** - * @returns namespaced list of tools - */ - listResources(): NamespacedData["resources"] { - return getNamespacedData(this.mcpConnections, "resources"); - } - - /** - * @returns namespaced list of resource templates - */ - listResourceTemplates(): NamespacedData["resourceTemplates"] { - return getNamespacedData(this.mcpConnections, "resourceTemplates"); - } - - /** - * Namespaced version of callTool - */ - async callTool( - params: CallToolRequest["params"] & { serverId: string }, - resultSchema?: - | typeof CallToolResultSchema - | typeof CompatibilityCallToolResultSchema, - options?: RequestOptions - ) { - const unqualifiedName = params.name.replace(`${params.serverId}.`, ""); - return this.mcpConnections[params.serverId].client.callTool( - { - ...params, - name: unqualifiedName - }, - resultSchema, - options - ); - } - - /** - * Namespaced version of readResource - */ - readResource( - params: ReadResourceRequest["params"] & { serverId: string }, - options: RequestOptions - ) { - return this.mcpConnections[params.serverId].client.readResource( - params, - options - ); - } - - /** - * Namespaced version of getPrompt - */ - getPrompt( - params: GetPromptRequest["params"] & { serverId: string }, - options: RequestOptions - ) { - return this.mcpConnections[params.serverId].client.getPrompt( - params, - options - ); - } -} - -type NamespacedData = { - tools: (Tool & { serverId: string })[]; - prompts: (Prompt & { serverId: string })[]; - resources: (Resource & { serverId: string })[]; - resourceTemplates: (ResourceTemplate & { serverId: string })[]; -}; - -export function getNamespacedData( - mcpClients: Record, - type: T -): NamespacedData[T] { - const sets = Object.entries(mcpClients).map(([name, conn]) => { - return { data: conn[type], name }; - }); - - const namespacedData = sets.flatMap(({ name: serverId, data }) => { - return data.map((item) => { - return { - ...item, - // we add a serverId so we can easily pull it out and send the tool call to the right server - serverId - }; - }); - }); - - return namespacedData as NamespacedData[T]; // Type assertion needed due to TS limitations with conditional return types -} diff --git a/packages/agents/src/mcp/do-oauth-client-provider.ts b/packages/agents/src/mcp/do-oauth-client-provider.ts index 92b679b2..6e366a07 100644 --- a/packages/agents/src/mcp/do-oauth-client-provider.ts +++ b/packages/agents/src/mcp/do-oauth-client-provider.ts @@ -6,6 +6,7 @@ import type { OAuthTokens } from "@modelcontextprotocol/sdk/shared/auth.js"; import { nanoid } from "nanoid"; +import type { MCPStorageAdapter } from "./client-storage"; // A slight extension to the standard OAuthClientProvider interface because `redirectToAuthorization` doesn't give us the interface we need // This allows us to track authentication for a specific server and associated dynamic client registration @@ -21,7 +22,7 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider { private _clientId_: string | undefined; constructor( - public storage: DurableObjectStorage, + public storage: MCPStorageAdapter, public clientName: string, public baseRedirectUrl: string ) {} @@ -75,21 +76,17 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider { return `${this.keyPrefix(clientId)}/client_info/`; } - async clientInformation(): Promise { + clientInformation() { if (!this._clientId_) { return undefined; } - return ( - (await this.storage.get( - this.clientInfoKey(this.clientId) - )) ?? undefined + return this.storage.get( + this.clientInfoKey(this.clientId) ); } - async saveClientInformation( - clientInformation: OAuthClientInformationFull - ): Promise { - await this.storage.put( + saveClientInformation(clientInformation: OAuthClientInformationFull) { + this.storage.put( this.clientInfoKey(clientInformation.client_id), clientInformation ); @@ -100,18 +97,15 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider { return `${this.keyPrefix(clientId)}/token`; } - async tokens(): Promise { + tokens() { if (!this._clientId_) { return undefined; } - return ( - (await this.storage.get(this.tokenKey(this.clientId))) ?? - undefined - ); + return this.storage.get(this.tokenKey(this.clientId)); } - async saveTokens(tokens: OAuthTokens): Promise { - await this.storage.put(this.tokenKey(this.clientId), tokens); + saveTokens(tokens: OAuthTokens) { + this.storage.put(this.tokenKey(this.clientId), tokens); } get authUrl() { @@ -122,7 +116,7 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider { * Because this operates on the server side (but we need browser auth), we send this url back to the user * and require user interact to initiate the redirect flow */ - async redirectToAuthorization(authUrl: URL): Promise { + redirectToAuthorization(authUrl: URL) { // Generate secure random token for state parameter const stateToken = nanoid(); authUrl.searchParams.set("state", stateToken); @@ -133,20 +127,20 @@ export class DurableObjectOAuthClientProvider implements AgentsOAuthProvider { return `${this.keyPrefix(clientId)}/code_verifier`; } - async saveCodeVerifier(verifier: string): Promise { + saveCodeVerifier(verifier: string) { const key = this.codeVerifierKey(this.clientId); // Don't overwrite existing verifier to preserve first PKCE verifier - const existing = await this.storage.get(key); + const existing = this.storage.get(key); if (existing) { return; } - await this.storage.put(key, verifier); + this.storage.put(key, verifier); } - async codeVerifier(): Promise { - const codeVerifier = await this.storage.get( + codeVerifier(): string { + const codeVerifier = this.storage.get( this.codeVerifierKey(this.clientId) ); if (!codeVerifier) { diff --git a/packages/agents/src/mcp/index.ts b/packages/agents/src/mcp/index.ts index d7b8e629..534a42bd 100644 --- a/packages/agents/src/mcp/index.ts +++ b/packages/agents/src/mcp/index.ts @@ -446,8 +446,9 @@ export { export type { MCPClientOAuthResult, - MCPClientOAuthCallbackConfig -} from "./client"; + MCPClientOAuthCallbackConfig, + MCPServerOptions +} from "./client-manager"; export { createMcpHandler, diff --git a/packages/agents/src/tests/mcp/client-manager.test.ts b/packages/agents/src/tests/mcp/client-manager.test.ts index d9a3366d..3a0baa66 100644 --- a/packages/agents/src/tests/mcp/client-manager.test.ts +++ b/packages/agents/src/tests/mcp/client-manager.test.ts @@ -1,12 +1,93 @@ import { describe, it, expect, beforeEach, vi } from "vitest"; -import { MCPClientManager } from "../../mcp/client"; +import { MCPClientManager } from "../../mcp/client-manager"; import { MCPClientConnection } from "../../mcp/client-connection"; +import { + AgentMCPStorageAdapter, + type MCPServerRow +} from "../../mcp/client-storage"; +import type { ToolCallOptions } from "ai"; describe("MCPClientManager OAuth Integration", () => { let manager: MCPClientManager; + let mockStorageData: Map; + let mockKVData: Map; beforeEach(() => { - manager = new MCPClientManager("test-client", "1.0.0"); + mockStorageData = new Map(); + mockKVData = new Map(); + + // Create a proper mock storage adapter + const mockStorage = new AgentMCPStorageAdapter( + >( + strings: TemplateStringsArray, + ...values: (string | number | boolean | null)[] + ) => { + const query = strings.join(""); + + if (query.includes("INSERT OR REPLACE")) { + const id = values[0] as string; + mockStorageData.set(id, { + id: values[0] as string, + name: values[1] as string, + server_url: values[2] as string, + client_id: values[3] as string | null, + auth_url: values[4] as string | null, + callback_url: values[5] as string, + server_options: values[6] as string | null + }); + return [] as unknown as T[]; + } + + if (query.includes("DELETE")) { + const id = values[0] as string; + mockStorageData.delete(id); + return [] as unknown as T[]; + } + + if ( + query.includes("UPDATE") && + query.includes("callback_url = ''") && + query.includes("auth_url = NULL") + ) { + // Combined clearOAuthCredentials query + const id = values[0] as string; + const server = mockStorageData.get(id); + if (server) { + server.callback_url = ""; + server.auth_url = null; + mockStorageData.set(id, server); + } + return [] as unknown as T[]; + } + + if (query.includes("SELECT")) { + if (query.includes("WHERE callback_url")) { + const url = values[0] as string; + for (const server of mockStorageData.values()) { + if (server.callback_url === url) { + return [server] as unknown as T[]; + } + } + return [] as unknown as T[]; + } + return Array.from(mockStorageData.values()) as unknown as T[]; + } + + return [] as unknown as T[]; + }, + { + get: (key: string) => mockKVData.get(key) as T | undefined, + put: (key: string, value: unknown) => { + mockKVData.set(key, value); + }, + list: vi.fn(), + delete: vi.fn() + } + ); + + manager = new MCPClientManager("test-client", "1.0.0", { + storage: mockStorage + }); }); describe("Connection Reuse During OAuth", () => { @@ -44,48 +125,60 @@ describe("MCPClientManager OAuth Integration", () => { }); describe("Callback URL Management", () => { - it("should register and unregister callback URLs", () => { + it("should recognize callback URLs from database", async () => { const callbackUrl1 = "http://localhost:3000/callback/server1"; const callbackUrl2 = "http://localhost:3000/callback/server2"; - // Register callback URLs - manager.registerCallbackUrl(callbackUrl1); - manager.registerCallbackUrl(callbackUrl2); + // Save servers with callback URLs to database + manager.saveServer({ + id: "server1", + name: "Test Server 1", + server_url: "http://test1.com", + callback_url: callbackUrl1, + client_id: null, + auth_url: null, + server_options: null + }); + manager.saveServer({ + id: "server2", + name: "Test Server 2", + server_url: "http://test2.com", + callback_url: callbackUrl2, + client_id: null, + auth_url: null, + server_options: null + }); // Test callback recognition expect( - manager.isCallbackRequest(new Request(`${callbackUrl1}?code=test`)) + await manager.isCallbackRequest( + new Request(`${callbackUrl1}?code=test`) + ) ).toBe(true); expect( - manager.isCallbackRequest(new Request(`${callbackUrl2}?code=test`)) + await manager.isCallbackRequest( + new Request(`${callbackUrl2}?code=test`) + ) ).toBe(true); expect( - manager.isCallbackRequest(new Request("http://other.com/callback")) + await manager.isCallbackRequest( + new Request("http://other.com/callback") + ) ).toBe(false); - // Unregister callback URL - manager.unregisterCallbackUrl("server1"); + // Remove server from database + manager.removeServer("server1"); - // Should no longer recognize the unregistered callback + // Should no longer recognize the removed server's callback expect( - manager.isCallbackRequest(new Request(`${callbackUrl1}?code=test`)) + await manager.isCallbackRequest( + new Request(`${callbackUrl1}?code=test`) + ) ).toBe(false); expect( - manager.isCallbackRequest(new Request(`${callbackUrl2}?code=test`)) - ).toBe(true); - }); - - it("should not register duplicate callback URLs", () => { - const callbackUrl = "http://localhost:3000/callback/server1"; - - // Register the same URL multiple times - manager.registerCallbackUrl(callbackUrl); - manager.registerCallbackUrl(callbackUrl); - manager.registerCallbackUrl(callbackUrl); - - // Verify no duplicates by testing callback recognition still works with one registration - expect( - manager.isCallbackRequest(new Request(`${callbackUrl}?code=test`)) + await manager.isCallbackRequest( + new Request(`${callbackUrl2}?code=test`) + ) ).toBe(true); }); @@ -95,8 +188,16 @@ describe("MCPClientManager OAuth Integration", () => { const authCode = "test-auth-code"; const callbackUrl = `http://localhost:3000/callback/${serverId}`; - // Register callback URL - manager.registerCallbackUrl(callbackUrl); + // Save server to database with callback URL + manager.saveServer({ + id: serverId, + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); // Create real connection with authProvider and mock its methods const mockAuthProvider = { @@ -176,7 +277,15 @@ describe("MCPClientManager OAuth Integration", () => { it("should handle OAuth error response from provider", async () => { const callbackUrl = "http://localhost:3000/callback/server1"; - manager.registerCallbackUrl(callbackUrl); + manager.saveServer({ + id: "server1", + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); const callbackRequest = new Request( `${callbackUrl}?error=access_denied&error_description=User%20denied%20access` @@ -191,7 +300,15 @@ describe("MCPClientManager OAuth Integration", () => { it("should throw error for callback without code or error", async () => { const callbackUrl = "http://localhost:3000/callback/server1"; - manager.registerCallbackUrl(callbackUrl); + manager.saveServer({ + id: "server1", + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); const callbackRequest = new Request(`${callbackUrl}?state=test`); @@ -202,7 +319,15 @@ describe("MCPClientManager OAuth Integration", () => { it("should throw error for callback without state", async () => { const callbackUrl = "http://localhost:3000/callback/server1"; - manager.registerCallbackUrl(callbackUrl); + manager.saveServer({ + id: "server1", + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); const callbackRequest = new Request(`${callbackUrl}?code=test`); @@ -213,7 +338,15 @@ describe("MCPClientManager OAuth Integration", () => { it("should throw error for callback with non-existent server", async () => { const callbackUrl = "http://localhost:3000/callback/non-existent"; - manager.registerCallbackUrl(callbackUrl); + manager.saveServer({ + id: "non-existent", + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); const callbackRequest = new Request( `${callbackUrl}?code=test&state=client` @@ -224,12 +357,20 @@ describe("MCPClientManager OAuth Integration", () => { ).rejects.toThrow("Could not find serverId: non-existent"); }); - it("should throw error for callback when not in authenticating state", async () => { + it("should handle duplicate callback when already in ready state", async () => { const serverId = "test-server"; const callbackUrl = `http://localhost:3000/callback/${serverId}`; - manager.registerCallbackUrl(callbackUrl); + manager.saveServer({ + id: serverId, + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); - // Create real connection in ready state (not authenticating) + // Create real connection in ready state (simulates duplicate callback) const connection = new MCPClientConnection( new URL("http://example.com"), { name: "test-client", version: "1.0.0" }, @@ -239,7 +380,7 @@ describe("MCPClientManager OAuth Integration", () => { // Mock methods and set state connection.init = vi.fn().mockResolvedValue(undefined); connection.client.close = vi.fn().mockResolvedValue(undefined); - connection.connectionState = "ready"; // Not authenticating + connection.connectionState = "ready"; // Already authenticated manager.mcpConnections[serverId] = connection; @@ -247,10 +388,779 @@ describe("MCPClientManager OAuth Integration", () => { `${callbackUrl}?code=test&state=client` ); + // Should gracefully handle duplicate callback by returning success + const result = await manager.handleCallbackRequest(callbackRequest); + expect(result.authSuccess).toBe(true); + expect(result.serverId).toBe(serverId); + }); + + it("should error when callback received for connection in failed state", async () => { + const serverId = "test-server"; + const callbackUrl = `http://localhost:3000/callback/${serverId}`; + manager.saveServer({ + id: serverId, + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); + + // Create connection in failed state + const connection = new MCPClientConnection( + new URL("http://example.com"), + { name: "test-client", version: "1.0.0" }, + { transport: {}, client: {} } + ); + + connection.init = vi.fn().mockResolvedValue(undefined); + connection.client.close = vi.fn().mockResolvedValue(undefined); + connection.connectionState = "failed"; // Connection previously failed + + manager.mcpConnections[serverId] = connection; + + const callbackRequest = new Request( + `${callbackUrl}?code=test&state=client` + ); + + // Should error - failed connections need to be recreated, not re-authenticated + await expect( + manager.handleCallbackRequest(callbackRequest) + ).rejects.toThrow( + 'Failed to authenticate: the client is in "failed" state, expected "authenticating"' + ); + }); + }); + + describe("OAuth Security", () => { + it("should clear callback_url and auth_url after successful authentication", async () => { + const serverId = "test-server"; + const callbackUrl = `http://localhost:3000/callback/${serverId}`; + const authUrl = "https://auth.example.com/authorize"; + + // Save server with auth_url and callback_url + manager.saveServer({ + id: serverId, + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: "test-client-id", + auth_url: authUrl, + server_options: null + }); + + // Verify initial state + let server = mockStorageData.get(serverId); + expect(server).toBeDefined(); + expect(server?.callback_url).toBe(callbackUrl); + expect(server?.auth_url).toBe(authUrl); + + // Create connection with auth provider + const mockAuthProvider = { + authUrl: undefined, + clientId: undefined, + serverId: undefined, + redirectUrl: "http://localhost:3000/callback", + clientMetadata: { + client_name: "test-client", + client_uri: "http://localhost:3000", + redirect_uris: ["http://localhost:3000/callback"] + }, + tokens: vi.fn(), + saveTokens: vi.fn(), + clientInformation: vi.fn(), + saveClientInformation: vi.fn(), + redirectToAuthorization: vi.fn(), + saveCodeVerifier: vi.fn(), + codeVerifier: vi.fn() + }; + + const connection = new MCPClientConnection( + new URL("http://test.com"), + { name: "test-client", version: "1.0.0" }, + { + transport: { type: "auto", authProvider: mockAuthProvider }, + client: {} + } + ); + + connection.init = vi.fn().mockResolvedValue(undefined); + connection.client.close = vi.fn().mockResolvedValue(undefined); + connection.connectionState = "authenticating"; + connection.completeAuthorization = vi.fn().mockResolvedValue(undefined); + + manager.mcpConnections[serverId] = connection; + + // Handle callback + const callbackRequest = new Request( + `${callbackUrl}?code=test-code&state=test-state` + ); + const result = await manager.handleCallbackRequest(callbackRequest); + + expect(result.authSuccess).toBe(true); + + // Verify callback_url and auth_url were cleared + server = mockStorageData.get(serverId); + expect(server).toBeDefined(); + expect(server?.callback_url).toBe(""); + expect(server?.auth_url).toBe(null); + }); + + it("should prevent second callback attempt after auth_url is cleared", async () => { + const serverId = "test-server"; + const callbackUrl = `http://localhost:3000/callback/${serverId}`; + + // Save server with cleared callback_url (simulating post-auth state) + manager.saveServer({ + id: serverId, + name: "Test Server", + server_url: "http://test.com", + callback_url: "", // Already cleared + client_id: "test-client-id", + auth_url: null, // Already cleared + server_options: null + }); + + const callbackRequest = new Request( + `${callbackUrl}?code=malicious-code&state=test-state` + ); + + // Request should not be recognized as a callback + const isCallback = await manager.isCallbackRequest(callbackRequest); + expect(isCallback).toBe(false); + + // And handleCallbackRequest should fail await expect( manager.handleCallbackRequest(callbackRequest) + ).rejects.toThrow("No callback URI match found"); + }); + + it("should only match exact callback URLs from database", async () => { + const serverId = "test-server"; + const callbackUrl = `http://localhost:3000/callback/${serverId}`; + + manager.saveServer({ + id: serverId, + name: "Test Server", + server_url: "http://test.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, + server_options: null + }); + + // Exact match should work + expect( + await manager.isCallbackRequest(new Request(`${callbackUrl}?code=test`)) + ).toBe(true); + + // Prefix match should work (URL params) + expect( + await manager.isCallbackRequest( + new Request(`${callbackUrl}?code=test&state=abc`) + ) + ).toBe(true); + + // Different server ID should not match + expect( + await manager.isCallbackRequest( + new Request( + "http://localhost:3000/callback/different-server?code=test" + ) + ) + ).toBe(false); + + // Different host should not match + expect( + await manager.isCallbackRequest( + new Request(`http://evil.com/callback/${serverId}?code=test`) + ) + ).toBe(false); + + // Different path should not match + expect( + await manager.isCallbackRequest( + new Request(`http://localhost:3000/different/${serverId}?code=test`) + ) + ).toBe(false); + }); + }); + + describe("OAuth Connection Restoration", () => { + it("should restore OAuth connections from storage", async () => { + const serverId = "oauth-server"; + const callbackUrl = "http://localhost:3000/callback"; + const clientId = "stored-client-id"; + const authUrl = "https://auth.example.com/authorize"; + + // Save OAuth server to storage + manager.saveServer({ + id: serverId, + name: "OAuth Server", + server_url: "http://oauth-server.com", + callback_url: callbackUrl, + client_id: clientId, + auth_url: authUrl, + server_options: JSON.stringify({ + transport: { type: "auto" }, + client: {} + }) + }); + + await manager.restoreConnectionsFromStorage("test-agent"); + + // Verify connection was created in authenticating state + const connection = manager.mcpConnections[serverId]; + expect(connection).toBeDefined(); + expect(connection.connectionState).toBe("authenticating"); + + // Verify auth provider was set up + expect(connection.options.transport.authProvider).toBeDefined(); + expect(connection.options.transport.authProvider?.serverId).toBe( + serverId + ); + expect(connection.options.transport.authProvider?.clientId).toBe( + clientId + ); + }); + + it("should restore non-OAuth connections from storage", async () => { + const serverId = "regular-server"; + const callbackUrl = "http://localhost:3000/callback"; + + // Save non-OAuth server (no auth_url) + manager.saveServer({ + id: serverId, + name: "Regular Server", + server_url: "http://regular-server.com", + callback_url: callbackUrl, + client_id: null, + auth_url: null, // No OAuth + server_options: JSON.stringify({ + transport: { type: "sse", headers: { "X-Custom": "value" } }, + client: {} + }) + }); + + await manager.restoreConnectionsFromStorage("test-agent"); + + // Verify connection was registered and connected + const connection = manager.mcpConnections[serverId]; + expect(connection).toBeDefined(); + + // Verify auth provider was created (required for all connections) + expect(connection.options.transport.authProvider).toBeDefined(); + }); + + it("should handle empty server list gracefully", async () => { + await manager.restoreConnectionsFromStorage("test-agent"); + + // Should not throw and should have no connections + expect(Object.keys(manager.mcpConnections)).toHaveLength(0); + }); + + it("should restore mixed OAuth and non-OAuth servers", async () => { + // Save OAuth server + manager.saveServer({ + id: "oauth-server", + name: "OAuth Server", + server_url: "http://oauth.com", + callback_url: "http://localhost:3000/callback/oauth", + client_id: "oauth-client", + auth_url: "https://auth.example.com/authorize", + server_options: null + }); + + // Save regular server + manager.saveServer({ + id: "regular-server", + name: "Regular Server", + server_url: "http://regular.com", + callback_url: "http://localhost:3000/callback/regular", + client_id: null, + auth_url: null, + server_options: null + }); + + await manager.restoreConnectionsFromStorage("test-agent"); + + // Verify OAuth server is in authenticating state + expect(manager.mcpConnections["oauth-server"]).toBeDefined(); + expect(manager.mcpConnections["oauth-server"].connectionState).toBe( + "authenticating" + ); + + // Verify regular server was connected + expect(manager.mcpConnections["regular-server"]).toBeDefined(); + }); + }); + + describe("registerServer() and connectToServer()", () => { + it("should register a server and save to storage", () => { + const id = "test-server-1"; + const url = "http://example.com/mcp"; + const name = "Test Server"; + const callbackUrl = "http://localhost:3000/callback"; + + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" } + }); + + // Verify connection was created + expect(manager.mcpConnections[id]).toBeDefined(); + expect(manager.mcpConnections[id].url.toString()).toBe(url); + + // Verify saved to storage + const servers = mockStorageData.get(id); + expect(servers).toBeDefined(); + expect(servers?.name).toBe(name); + expect(servers?.server_url).toBe(url); + expect(servers?.callback_url).toBe(callbackUrl); + }); + + it("should skip registering if server already exists", () => { + const id = "existing-server"; + const url = "http://example.com/mcp"; + const name = "Existing Server"; + const callbackUrl = "http://localhost:3000/callback"; + + // Register once + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" } + }); + const firstConnection = manager.mcpConnections[id]; + + // Try to register again + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" } + }); + const secondConnection = manager.mcpConnections[id]; + + // Should be the same connection object + expect(secondConnection).toBe(firstConnection); + }); + + it("should save auth URL and client ID when registering OAuth server", () => { + const id = "oauth-server"; + const url = "http://oauth.example.com/mcp"; + const name = "OAuth Server"; + const callbackUrl = "http://localhost:3000/callback"; + const authUrl = "https://auth.example.com/authorize"; + const clientId = "test-client-id"; + + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" }, + authUrl, + clientId + }); + + // Verify OAuth info saved to storage + const server = mockStorageData.get(id); + expect(server?.auth_url).toBe(authUrl); + expect(server?.client_id).toBe(clientId); + }); + + it("should throw error when connecting to non-registered server", async () => { + await expect( + manager.connectToServer("non-existent-server") ).rejects.toThrow( - "Failed to authenticate: the client isn't in the `authenticating` state" + "Server non-existent-server is not registered. Call registerServer() first." + ); + }); + + it("should update storage with OAuth info after connection", async () => { + const id = "test-oauth-server"; + const url = "http://oauth.example.com/mcp"; + const name = "OAuth Server"; + const callbackUrl = "http://localhost:3000/callback"; + + // Create a mock auth provider that returns auth URL + const mockAuthProvider = { + serverId: id, + clientId: "mock-client-id", + authUrl: "https://auth.example.com/authorize", + redirectUrl: callbackUrl, + clientMetadata: { + client_name: "test-client", + redirect_uris: [callbackUrl] + }, + tokens: vi.fn(), + saveTokens: vi.fn(), + clientInformation: vi.fn(), + saveClientInformation: vi.fn(), + redirectToAuthorization: vi.fn((url) => { + mockAuthProvider.authUrl = url.toString(); + }), + saveCodeVerifier: vi.fn(), + codeVerifier: vi.fn() + }; + + // Register server with auth provider + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { + type: "auto", + authProvider: mockAuthProvider + } + }); + + // Mock the connection to return authenticating state + const conn = manager.mcpConnections[id]; + conn.init = vi.fn().mockImplementation(async () => { + conn.connectionState = "authenticating"; + }); + + // Connect to server + const result = await manager.connectToServer(id); + + // Verify auth URL is returned + expect(result.authUrl).toBe(mockAuthProvider.authUrl); + expect(result.clientId).toBe(mockAuthProvider.clientId); + + // Verify storage was updated with OAuth info + const server = mockStorageData.get(id); + expect(server?.auth_url).toBe(mockAuthProvider.authUrl); + expect(server?.client_id).toBe(mockAuthProvider.clientId); + }); + + it("should fire onConnected event for non-OAuth servers", async () => { + const id = "non-oauth-server"; + const url = "http://example.com/mcp"; + const name = "Non-OAuth Server"; + const callbackUrl = "http://localhost:3000/callback"; + + const onConnectedSpy = vi.fn(); + manager.onConnected(onConnectedSpy); + + // Register server + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" } + }); + + // Mock connection to go straight to ready state + const conn = manager.mcpConnections[id]; + conn.init = vi.fn().mockImplementation(async () => { + conn.connectionState = "ready"; + }); + + // Connect to server + await manager.connectToServer(id); + + // Verify onConnected was fired + expect(onConnectedSpy).toHaveBeenCalledWith(id); + }); + + it("should not fire onConnected event for OAuth servers in authenticating state", async () => { + const id = "oauth-server"; + const url = "http://oauth.example.com/mcp"; + const name = "OAuth Server"; + const callbackUrl = "http://localhost:3000/callback"; + + const onConnectedSpy = vi.fn(); + manager.onConnected(onConnectedSpy); + + const mockAuthProvider = { + serverId: id, + clientId: "mock-client-id", + authUrl: "https://auth.example.com/authorize", + redirectUrl: callbackUrl, + clientMetadata: { + client_name: "test-client", + redirect_uris: [callbackUrl] + }, + tokens: vi.fn(), + saveTokens: vi.fn(), + clientInformation: vi.fn(), + saveClientInformation: vi.fn(), + redirectToAuthorization: vi.fn(), + saveCodeVerifier: vi.fn(), + codeVerifier: vi.fn() + }; + + // Register server + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { + type: "auto", + authProvider: mockAuthProvider + } + }); + + // Mock connection to stay in authenticating state + const conn = manager.mcpConnections[id]; + conn.init = vi.fn().mockImplementation(async () => { + conn.connectionState = "authenticating"; + }); + + // Connect to server + await manager.connectToServer(id); + + // Verify onConnected was NOT fired (OAuth not complete) + expect(onConnectedSpy).not.toHaveBeenCalled(); + }); + + it("should handle OAuth code reconnection", async () => { + const id = "oauth-reconnect-server"; + const url = "http://oauth.example.com/mcp"; + const name = "OAuth Reconnect Server"; + const callbackUrl = "http://localhost:3000/callback"; + const oauthCode = "test-auth-code"; + + // Register server + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" } + }); + + // Mock connection methods + const conn = manager.mcpConnections[id]; + conn.completeAuthorization = vi.fn().mockResolvedValue(undefined); + conn.establishConnection = vi.fn().mockResolvedValue(undefined); + + // Connect with OAuth code + const result = await manager.connectToServer(id, { oauthCode }); + + // Verify OAuth completion was called + expect(conn.completeAuthorization).toHaveBeenCalledWith(oauthCode); + expect(conn.establishConnection).toHaveBeenCalled(); + + // Result should be empty for successful OAuth completion + expect(result.authUrl).toBeUndefined(); + expect(result.clientId).toBeUndefined(); + }); + }); + + describe("getAITools() integration", () => { + it("should return AI SDK tools after registering and connecting to server", async () => { + const id = "test-mcp-server"; + const url = "http://example.com/mcp"; + const name = "Test MCP Server"; + const callbackUrl = "http://localhost:3000/callback"; + + // Initialize jsonSchema (required for getAITools) + await manager.ensureJsonSchema(); + + // Register server + manager.registerServer(id, { + url, + name, + callbackUrl, + client: {}, + transport: { type: "auto" } + }); + + // Mock the connection to simulate a successful connection with tools + const conn = manager.mcpConnections[id]; + + // Mock init to reach ready state + conn.init = vi.fn().mockImplementation(async () => { + conn.connectionState = "ready"; + + // Simulate discovered tools + conn.tools = [ + { + name: "test_tool", + description: "A test tool", + inputSchema: { + type: "object", + properties: { + message: { + type: "string", + description: "Test message" + } + }, + required: ["message"] + } + } + ]; + }); + + // Mock callTool + conn.client.callTool = vi.fn().mockResolvedValue({ + content: [{ type: "text", text: "Tool result" }] + }); + + // Connect to server + await manager.connectToServer(id); + + // Verify connection is ready + expect(conn.connectionState).toBe("ready"); + expect(conn.tools).toHaveLength(1); + + // Get AI tools + const tools = manager.getAITools(); + + // Verify tools are properly formatted for AI SDK + expect(tools).toBeDefined(); + + // Tool name should be namespaced with server ID + const toolKey = `tool_${id.replace(/-/g, "")}_test_tool`; + expect(tools[toolKey]).toBeDefined(); + + // Verify tool structure + const tool = tools[toolKey]; + expect(tool.description).toBe("A test tool"); + expect(tool.execute).toBeDefined(); + expect(tool.inputSchema).toBeDefined(); + + // Test tool execution + const result = await tool.execute!( + { message: "test" }, + {} as ToolCallOptions + ); + expect(result).toBeDefined(); + expect(conn.client.callTool).toHaveBeenCalledWith( + { + name: "test_tool", + arguments: { message: "test" }, + serverId: id + }, + undefined, + undefined + ); + }); + + it("should aggregate tools from multiple connected servers", async () => { + const server1Id = "server-1"; + const server2Id = "server-2"; + + // Initialize jsonSchema + await manager.ensureJsonSchema(); + + // Register and connect first server + manager.registerServer(server1Id, { + url: "http://server1.com/mcp", + name: "Server 1", + callbackUrl: "http://localhost:3000/callback", + client: {}, + transport: { type: "auto" } + }); + + const conn1 = manager.mcpConnections[server1Id]; + conn1.init = vi.fn().mockImplementation(async () => { + conn1.connectionState = "ready"; + conn1.tools = [ + { + name: "tool_one", + description: "Tool from server 1", + inputSchema: { type: "object", properties: {} } + } + ]; + }); + conn1.client.callTool = vi.fn().mockResolvedValue({ + content: [{ type: "text", text: "Result 1" }] + }); + + await manager.connectToServer(server1Id); + + // Register and connect second server + manager.registerServer(server2Id, { + url: "http://server2.com/mcp", + name: "Server 2", + callbackUrl: "http://localhost:3000/callback", + client: {}, + transport: { type: "auto" } + }); + + const conn2 = manager.mcpConnections[server2Id]; + conn2.init = vi.fn().mockImplementation(async () => { + conn2.connectionState = "ready"; + conn2.tools = [ + { + name: "tool_two", + description: "Tool from server 2", + inputSchema: { type: "object", properties: {} } + } + ]; + }); + conn2.client.callTool = vi.fn().mockResolvedValue({ + content: [{ type: "text", text: "Result 2" }] + }); + + await manager.connectToServer(server2Id); + + // Get AI tools + const tools = manager.getAITools(); + + // Verify both tools are available + const tool1Key = `tool_${server1Id.replace(/-/g, "")}_tool_one`; + const tool2Key = `tool_${server2Id.replace(/-/g, "")}_tool_two`; + + expect(tools[tool1Key]).toBeDefined(); + expect(tools[tool2Key]).toBeDefined(); + expect(tools[tool1Key].description).toBe("Tool from server 1"); + expect(tools[tool2Key].description).toBe("Tool from server 2"); + + // Test both tools execute correctly + await tools[tool1Key].execute!({}, {} as ToolCallOptions); + expect(conn1.client.callTool).toHaveBeenCalledWith( + { + name: "tool_one", + arguments: {}, + serverId: server1Id + }, + undefined, + undefined + ); + + await tools[tool2Key].execute!({}, {} as ToolCallOptions); + expect(conn2.client.callTool).toHaveBeenCalledWith( + { + name: "tool_two", + arguments: {}, + serverId: server2Id + }, + undefined, + undefined + ); + }); + + it("should throw error if jsonSchema not initialized", () => { + // Create a new manager without initializing jsonSchema + const newManager = new MCPClientManager("test-client", "1.0.0", { + storage: new AgentMCPStorageAdapter( + >() => [] as T[], + { + get: () => undefined, + put: () => {}, + list: vi.fn(), + delete: vi.fn() + } + ) + }); + + expect(() => newManager.getAITools()).toThrow( + "jsonSchema not initialized." ); }); }); diff --git a/packages/agents/src/tests/mcp/handler.test.ts b/packages/agents/src/tests/mcp/handler.test.ts index 2bae8098..67125291 100644 --- a/packages/agents/src/tests/mcp/handler.test.ts +++ b/packages/agents/src/tests/mcp/handler.test.ts @@ -1,6 +1,9 @@ import { createExecutionContext, env } from "cloudflare:test"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import type { + CallToolResult, + JSONRPCError +} from "@modelcontextprotocol/sdk/types.js"; import { describe, expect, it } from "vitest"; import { createMcpHandler } from "../../mcp/handler"; import { z } from "zod"; @@ -395,11 +398,11 @@ describe("createMcpHandler", () => { expect(response.status).toBe(500); expect(response.headers.get("Content-Type")).toBe("application/json"); - const body = (await response.json()) as any; - expect(body.jsonrpc).toBe("2.0"); - expect(body.error).toBeDefined(); - expect(body.error.code).toBe(-32603); - expect(body.error.message).toBe("Transport error"); + const body = await response.json(); + expect((body as JSONRPCError)?.jsonrpc).toBe("2.0"); + expect((body as JSONRPCError)?.error).toBeDefined(); + expect((body as JSONRPCError)?.error?.code).toBe(-32603); + expect((body as JSONRPCError)?.error?.message).toBe("Transport error"); }); it("should return generic error message for non-Error exceptions", async () => { @@ -438,8 +441,10 @@ describe("createMcpHandler", () => { const response = await handler(request, env, ctx); expect(response.status).toBe(500); - const body = (await response.json()) as any; - expect(body.error.message).toBe("Internal server error"); + const body = await response.json(); + expect((body as JSONRPCError)?.error?.message).toBe( + "Internal server error" + ); }); }); }); diff --git a/packages/agents/src/tests/mcp/oauth2-mcp-client.test.ts b/packages/agents/src/tests/mcp/oauth2-mcp-client.test.ts index 4a4993ce..4d9d512b 100644 --- a/packages/agents/src/tests/mcp/oauth2-mcp-client.test.ts +++ b/packages/agents/src/tests/mcp/oauth2-mcp-client.test.ts @@ -7,438 +7,257 @@ declare module "cloudflare:test" { interface ProvidedEnv extends Env {} } -describe("OAuth2 MCP Client", () => { - it("hibernated durable object should restore MCP state from database during OAuth callback", async () => { - // Use idFromName to ensure we get the same DO instance across requests +describe("OAuth2 MCP Client - Hibernation", () => { + it("should restore MCP connections from database on wake-up", async () => { const agentId = env.TestOAuthAgent.idFromName("test-oauth-hibernation"); const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const authUrl = "http://example.com/oauth/authorize"; + const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; - // Initialize the agent + // Insert persisted MCP server (simulating pre-hibernation state) + agentStub.sql` + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test-oauth-server"}, ${"http://example.com/mcp"}, ${"test-client-id"}, ${authUrl}, ${fullCallbackUrl}, ${null}) + `; + + // Simulate DO wake-up await agentStub.setName("default"); await agentStub.onStart(); - // Reset the restoration flag to simulate fresh DO wake-up - await agentStub.resetMcpStateRestoredFlag(); + // Verify connection restored with authenticating state + expect(await agentStub.hasMcpConnection(serverId)).toBe(true); + }); - // Setup: Simulate a persisted MCP server that was saved before hibernation + it("should handle OAuth callback after hibernation", async () => { + const agentId = env.TestOAuthAgent.idFromName("test-oauth-callback"); + const agentStub = env.TestOAuthAgent.get(agentId); const serverId = nanoid(8); - const serverName = "test-oauth-server"; - const serverUrl = "http://example.com/mcp"; - const clientId = "test-client-id"; - const authUrl = "http://example.com/oauth/authorize"; const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; - - // Insert the MCP server record into the database (simulating pre-OAuth persistence) - agentStub.sql` - INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES ( - ${serverId}, - ${serverName}, - ${serverUrl}, - ${clientId}, - ${authUrl}, - ${callbackBaseUrl}, - ${null} - ) - `; - - // At this point, the DO has internal state only from database - // When it wakes up for the OAuth callback, it should restore state from the database - - // Verify callback URL is NOT registered before the callback (simulating hibernation) const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; - const isRegisteredBefore = await agentStub.isCallbackUrlRegistered( - `${fullCallbackUrl}?code=test&state=test` - ); - expect(isRegisteredBefore).toBe(false); - - // Simulate the OAuth callback request - const authCode = "test-auth-code"; - const state = "test-state"; - const callbackUrl = `${callbackBaseUrl}/${serverId}?code=${authCode}&state=${state}`; - const request = new Request(callbackUrl, { method: "GET" }); - const response = await agentStub.fetch(request); - - // The restoration worked if we get past the "Server not found" error - // The server should be found in the database and the callback URL should be restored - const responseText = await response.text(); + agentStub.sql` + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${fullCallbackUrl}, ${null}) + `; - // We should NOT get a 404 (that would mean restoration failed) - expect(response.status).not.toBe(404); - expect(responseText).not.toContain("not found in database"); + await agentStub.setName("default"); + await agentStub.onStart(); - // Verify the callback URL was restored/registered in memory during the request processing - const isRegisteredAfter = await agentStub.isCallbackUrlRegistered( - `${fullCallbackUrl}?code=test&state=test` + const response = await agentStub.fetch( + new Request(`${fullCallbackUrl}?code=test-code&state=test-state`) ); - expect(isRegisteredAfter).toBe(true); - - // Verify connection was created in authenticating state - const hasConnection = await agentStub.hasMcpConnection(serverId); - expect(hasConnection).toBe(true); - // Verify database record still exists after callback - const serverAfter = await agentStub.getMcpServerFromDb(serverId); - expect(serverAfter).not.toBeNull(); - expect(serverAfter?.id).toBe(serverId); + expect(response.status).not.toBe(404); + expect(await response.text()).not.toContain("Could not find serverId"); }); +}); - it("should restore connection when callback URL is registered but connection is missing", async () => { - // Edge case: callback URL exists in memory but connection object is missing - const agentId = env.TestOAuthAgent.idFromName("test-partial-state"); +describe("OAuth2 MCP Client - Callback Handling", () => { + it("should process OAuth callback with valid connection", async () => { + const agentId = env.TestOAuthAgent.newUniqueId(); const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; await agentStub.setName("default"); await agentStub.onStart(); - await agentStub.resetMcpStateRestoredFlag(); - - const serverId = nanoid(8); - const serverName = "test-server"; - const serverUrl = "http://example.com/mcp"; - const clientId = "test-client-id"; - const authUrl = "http://example.com/oauth/authorize"; - const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; - // Insert server record in database agentStub.sql` - INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES ( - ${serverId}, - ${serverName}, - ${serverUrl}, - ${clientId}, - ${authUrl}, - ${callbackBaseUrl}, - ${null} - ) - `; - - // Simulate partial state: callback URL is registered but connection is missing - const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client-id"}, ${"http://example.com/auth"}, ${fullCallbackUrl}, ${null}) + `; + await agentStub.setupMockMcpConnection( serverId, - serverName, - serverUrl, - callbackBaseUrl + "test", + "http://example.com/mcp", + callbackBaseUrl, + "client-id" ); + await agentStub.setupMockOAuthState(serverId, "test-code", "test-state"); - // Verify callback URL IS registered - const isRegisteredBefore = await agentStub.isCallbackUrlRegistered( - `${fullCallbackUrl}?code=test&state=test` + const response = await agentStub.fetch( + new Request(`${fullCallbackUrl}?code=test-code&state=test-state`) ); - expect(isRegisteredBefore).toBe(true); - - // Now REMOVE the connection from mcpConnections to simulate the bug scenario - await agentStub.removeMcpConnection(serverId); - - // Verify connection is missing - const connectionExists = await agentStub.hasMcpConnection(serverId); - expect(connectionExists).toBe(false); - - const authCode = "test-code"; - const state = "test-state"; - const callbackUrl = `${callbackBaseUrl}/${serverId}?code=${authCode}&state=${state}`; - const request = new Request(callbackUrl, { method: "GET" }); - const response = await agentStub.fetch(request); - - // Should not fail with "Could not find serverId: xxx" - const responseText = await response.text(); - expect(responseText).not.toContain("Could not find serverId"); - expect(response.status).not.toBe(404); - - // Verify the callback URL is still registered after restoration - const isRegisteredAfter = await agentStub.isCallbackUrlRegistered( - `${fullCallbackUrl}?code=test&state=test` - ); - expect(isRegisteredAfter).toBe(true); + expect(response.status).toBe(200); }); - it("should handle callback when server record exists and connection is still in memory", async () => { + it("should clear auth_url after successful OAuth", async () => { const agentId = env.TestOAuthAgent.newUniqueId(); const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const authUrl = "http://example.com/oauth/authorize"; + const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; await agentStub.setName("default"); await agentStub.onStart(); - const serverId = nanoid(8); - const serverName = "test-server"; - const serverUrl = "http://example.com/mcp"; - const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; - - // Insert server record in database agentStub.sql` - INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES ( - ${serverId}, - ${serverName}, - ${serverUrl}, - ${"client-id"}, - ${"http://example.com/auth"}, - ${callbackBaseUrl}, - ${null} - ) - `; - - // Setup in-memory state (simulates non-hibernated DO) + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client-id"}, ${authUrl}, ${fullCallbackUrl}, ${null}) + `; + await agentStub.setupMockMcpConnection( serverId, - serverName, - serverUrl, - callbackBaseUrl + "test", + "http://example.com/mcp", + callbackBaseUrl, + "client-id" ); + await agentStub.setupMockOAuthState(serverId, "test-code", "test-state"); - // Verify callback URL is already registered - const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; - const isRegisteredBefore = await agentStub.isCallbackUrlRegistered( - `${fullCallbackUrl}?code=test&state=test` + await agentStub.fetch( + new Request(`${fullCallbackUrl}?code=test-code&state=test-state`) ); - expect(isRegisteredBefore).toBe(true); - // Set up mock OAuth state - const authCode = "test-code"; - const state = "test-state"; - await agentStub.setupMockOAuthState(serverId, authCode, state); + const serverAfter = await agentStub.getMcpServerFromDb(serverId); + expect(serverAfter?.auth_url).toBeNull(); + }); +}); - const callbackUrl = `${callbackBaseUrl}/${serverId}?code=${authCode}&state=${state}`; - const request = new Request(callbackUrl, { method: "GET" }); +describe("OAuth2 MCP Client - Error Handling", () => { + it("should reject callback without code parameter", async () => { + const agentId = env.TestOAuthAgent.newUniqueId(); + const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; - const response = await agentStub.fetch(request); + await agentStub.setName("default"); + await agentStub.onStart(); - // Should succeed - the restoration is idempotent - expect(response.status).toBe(200); + agentStub.sql` + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${fullCallbackUrl}, ${null}) + `; - // Verify callback URL is still registered (idempotent) - const isRegisteredAfter = await agentStub.isCallbackUrlRegistered( - `${fullCallbackUrl}?code=test&state=test` + const response = await agentStub.fetch( + new Request(`${fullCallbackUrl}?state=test-state`) ); - expect(isRegisteredAfter).toBe(true); + expect(response.status).toBeGreaterThanOrEqual(400); }); - it("should not restore state for non-callback requests", async () => { - const ctx = createExecutionContext(); - + it("should reject callback without state parameter", async () => { const agentId = env.TestOAuthAgent.newUniqueId(); const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; await agentStub.setName("default"); await agentStub.onStart(); - const regularUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}`; - const request = new Request(regularUrl, { method: "GET" }); - - const response = await worker.fetch(request, env, ctx); + agentStub.sql` + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${fullCallbackUrl}, ${null}) + `; - expect(response.status).toBe(200); - const text = await response.text(); - expect(text).toBe("Test OAuth Agent"); + const response = await agentStub.fetch( + new Request(`${fullCallbackUrl}?code=test-code`) + ); + expect(response.status).toBeGreaterThanOrEqual(400); }); +}); - describe("OAuth Error Handling", () => { - it("should handle callback with missing code parameter", async () => { - const agentId = env.TestOAuthAgent.newUniqueId(); - const agentStub = env.TestOAuthAgent.get(agentId); - - await agentStub.setName("default"); - await agentStub.onStart(); - await agentStub.resetMcpStateRestoredFlag(); - - const serverId = nanoid(8); - const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; - - // Insert OAuth server - agentStub.sql` - INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${callbackBaseUrl}, ${null}) - `; - - // Make callback request without code parameter - const callbackUrl = `${callbackBaseUrl}/${serverId}?state=test-state`; - const request = new Request(callbackUrl, { method: "GET" }); - - const response = await agentStub.fetch(request); - - // Should return an error (not crash) - expect(response.status).toBeGreaterThanOrEqual(400); - }); - - it("should handle callback with missing state parameter", async () => { - const agentId = env.TestOAuthAgent.newUniqueId(); - const agentStub = env.TestOAuthAgent.get(agentId); - - await agentStub.setName("default"); - await agentStub.onStart(); - await agentStub.resetMcpStateRestoredFlag(); +describe("OAuth2 MCP Client - Redirect Behavior", () => { + it("should redirect to success URL after OAuth", async () => { + const agentId = env.TestOAuthAgent.newUniqueId(); + const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const callbackBaseUrl = `http://example.com/agents/oauth/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; - const serverId = nanoid(8); - const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; + await agentStub.setName("default"); + await agentStub.onStart(); + await agentStub.configureOAuthForTest({ successRedirect: "/dashboard" }); - // Insert OAuth server - agentStub.sql` - INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${callbackBaseUrl}, ${null}) - `; + agentStub.sql` + INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${fullCallbackUrl}, ${null}) + `; - // Make callback request without state parameter - const callbackUrl = `${callbackBaseUrl}/${serverId}?code=test-code`; - const request = new Request(callbackUrl, { method: "GET" }); + await agentStub.setupMockMcpConnection( + serverId, + "test", + "http://example.com/mcp", + callbackBaseUrl, + "client" + ); + await agentStub.setupMockOAuthState(serverId, "test-code", "test-state"); - const response = await agentStub.fetch(request); + const response = await agentStub.fetch( + new Request(`${fullCallbackUrl}?code=test-code&state=test-state`, { + redirect: "manual" + }) + ); - // Should return an error (not crash) - expect(response.status).toBeGreaterThanOrEqual(400); - }); + expect(response.status).toBe(302); + expect(response.headers.get("Location")).toBe( + "http://example.com/dashboard" + ); }); - it("should clear auth_url from database after successful OAuth callback", async () => { + it("should redirect to error URL on OAuth failure", async () => { const agentId = env.TestOAuthAgent.newUniqueId(); const agentStub = env.TestOAuthAgent.get(agentId); + const serverId = nanoid(8); + const callbackBaseUrl = `http://example.com/agents/oauth/${agentId.toString()}/callback`; + const fullCallbackUrl = `${callbackBaseUrl}/${serverId}`; await agentStub.setName("default"); await agentStub.onStart(); + await agentStub.configureOAuthForTest({ errorRedirect: "/error" }); - const serverId = nanoid(8); - const serverName = "test-oauth-server"; - const serverUrl = "http://example.com/mcp"; - const clientId = "test-client-id"; - const authUrl = "http://example.com/oauth/authorize"; - const callbackBaseUrl = `http://example.com/agents/test-o-auth-agent/${agentId.toString()}/callback`; - - // Insert MCP server with auth_url agentStub.sql` INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES ( - ${serverId}, - ${serverName}, - ${serverUrl}, - ${clientId}, - ${authUrl}, - ${callbackBaseUrl}, - ${null} - ) + VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${fullCallbackUrl}, ${null}) `; - // Verify auth_url exists before callback - const serverBefore = await agentStub.getMcpServerFromDb(serverId); - expect(serverBefore).not.toBeNull(); - expect(serverBefore?.auth_url).toBe(authUrl); - - // Setup mock connection and OAuth state await agentStub.setupMockMcpConnection( serverId, - serverName, - serverUrl, - callbackBaseUrl + "test", + "http://example.com/mcp", + callbackBaseUrl, + "client" ); await agentStub.setupMockOAuthState(serverId, "test-code", "test-state"); - // Simulate successful OAuth callback - const authCode = "test-auth-code"; - const state = "test-state"; - const callbackUrl = `${callbackBaseUrl}/${serverId}?code=${authCode}&state=${state}`; - const request = new Request(callbackUrl, { method: "GET" }); + const response = await agentStub.fetch( + new Request(`${fullCallbackUrl}?error=access_denied&state=test-state`, { + redirect: "manual" + }) + ); - const response = await agentStub.fetch(request); - expect(response.status).toBe(200); + expect(response.status).toBe(302); + expect(response.headers.get("Location")).toMatch( + /^http:\/\/example\.com\/error\?error=/ + ); + }); +}); - // Verify auth_url is cleared after successful callback - const serverAfter = await agentStub.getMcpServerFromDb(serverId); - expect(serverAfter).not.toBeNull(); - expect(serverAfter?.auth_url).toBeNull(); +describe("OAuth2 MCP Client - Basic Functionality", () => { + it("should handle non-callback requests normally", async () => { + const ctx = createExecutionContext(); + const agentId = env.TestOAuthAgent.newUniqueId(); + const agentStub = env.TestOAuthAgent.get(agentId); - // Verify the server record still exists with other data intact - expect(serverAfter?.id).toBe(serverId); - expect(serverAfter?.name).toBe(serverName); - expect(serverAfter?.server_url).toBe(serverUrl); - expect(serverAfter?.client_id).toBe(clientId); - }); + await agentStub.setName("default"); + await agentStub.onStart(); + + const response = await worker.fetch( + new Request( + `http://example.com/agents/test-o-auth-agent/${agentId.toString()}` + ), + env, + ctx + ); - describe("OAuth Redirect Behavior", () => { - async function setupOAuthTest(config: { - successRedirect?: string; - errorRedirect?: string; - origin?: string; - }) { - const agentId = env.TestOAuthAgent.newUniqueId(); - const agentStub = env.TestOAuthAgent.get(agentId); - await agentStub.setName("default"); - await agentStub.onStart(); - await agentStub.configureOAuthForTest(config); - - const serverId = nanoid(8); - const origin = config.origin || "http://example.com"; - const callbackBaseUrl = `${origin}/agents/oauth/${agentId.toString()}/callback`; - - agentStub.sql` - INSERT INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) - VALUES (${serverId}, ${"test"}, ${"http://example.com/mcp"}, ${"client"}, ${"http://example.com/auth"}, ${callbackBaseUrl}, ${null}) - `; - - await agentStub.setupMockMcpConnection( - serverId, - "test", - "http://example.com/mcp", - callbackBaseUrl - ); - await agentStub.setupMockOAuthState(serverId, "test-code", "test-state"); - - return { agentStub, serverId, callbackBaseUrl }; - } - - it("should return 302 redirect with Location header on successful OAuth callback", async () => { - const { agentStub, serverId, callbackBaseUrl } = await setupOAuthTest({ - successRedirect: "/dashboard" - }); - - const response = await agentStub.fetch( - new Request( - `${callbackBaseUrl}/${serverId}?code=test-code&state=test-state`, - { method: "GET", redirect: "manual" } - ) - ); - - expect(response.status).toBe(302); - expect(response.headers.get("Location")).toBe( - "http://example.com/dashboard" - ); - }); - - it("should handle relative URLs in successRedirect", async () => { - const { agentStub, serverId, callbackBaseUrl } = await setupOAuthTest({ - successRedirect: "/success", - origin: "http://test.local" - }); - - const response = await agentStub.fetch( - new Request( - `${callbackBaseUrl}/${serverId}?code=test-code&state=test-state`, - { method: "GET", redirect: "manual" } - ) - ); - - expect(response.status).toBe(302); - expect(response.headers.get("Location")).toBe( - "http://test.local/success" - ); - }); - - it("should redirect to errorRedirect with error parameter on OAuth failure", async () => { - const { agentStub, serverId, callbackBaseUrl } = await setupOAuthTest({ - errorRedirect: "/error" - }); - - const response = await agentStub.fetch( - new Request( - `${callbackBaseUrl}/${serverId}?error=access_denied&state=test-state`, - { method: "GET", redirect: "manual" } - ) - ); - - expect(response.status).toBe(302); - expect(response.headers.get("Location")).toMatch( - /^http:\/\/example\.com\/error\?error=/ - ); - }); + expect(response.status).toBe(200); + expect(await response.text()).toBe("Test OAuth Agent"); }); }); diff --git a/packages/agents/src/tests/mcp/worker-transport.test.ts b/packages/agents/src/tests/mcp/worker-transport.test.ts index 1b83b907..0a580959 100644 --- a/packages/agents/src/tests/mcp/worker-transport.test.ts +++ b/packages/agents/src/tests/mcp/worker-transport.test.ts @@ -1,8 +1,12 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import type { + CallToolResult, + JSONRPCMessage +} from "@modelcontextprotocol/sdk/types.js"; import { describe, expect, it } from "vitest"; import { WorkerTransport, + type TransportState, type WorkerTransportOptions } from "../../mcp/worker-transport"; import { z } from "zod"; @@ -655,11 +659,11 @@ describe("WorkerTransport", () => { describe("Storage API - State Persistence", () => { it("should persist session state to storage", async () => { const server = createTestServer(); - let storedState: any = undefined; + let storedState: TransportState | undefined; const mockStorage = { get: async () => storedState, - set: async (state: any) => { + set: async (state: TransportState) => { storedState = state; } }; @@ -690,9 +694,9 @@ describe("WorkerTransport", () => { await transport.handleRequest(request); expect(storedState).toBeDefined(); - expect(storedState.sessionId).toBe("persistent-session"); - expect(storedState.initialized).toBe(true); - expect(storedState.protocolVersion).toBe("2025-06-18"); + expect(storedState?.sessionId).toBe("persistent-session"); + expect(storedState?.initialized).toBe(true); + expect(storedState?.protocolVersion).toBe("2025-06-18"); }); it("should restore session state from storage", async () => { @@ -969,7 +973,7 @@ describe("WorkerTransport", () => { const body = await response.json(); expect(body).toBeDefined(); - expect((body as any).jsonrpc).toBe("2.0"); + expect((body as JSONRPCMessage).jsonrpc).toBe("2.0"); }); it("should return SSE stream when enableJsonResponse is false", async () => { diff --git a/packages/agents/src/tests/worker.ts b/packages/agents/src/tests/worker.ts index bbc16e55..7f043b22 100644 --- a/packages/agents/src/tests/worker.ts +++ b/packages/agents/src/tests/worker.ts @@ -232,11 +232,21 @@ export class TestOAuthAgent extends Agent { async setupMockMcpConnection( serverId: string, - _serverName: string, + serverName: string, serverUrl: string, - callbackUrl: string + callbackUrl: string, + clientId?: string | null ): Promise { - this.mcp.registerCallbackUrl(`${callbackUrl}/${serverId}`); + // Save server to database with callback URL + this.mcp.saveServer({ + id: serverId, + name: serverName, + server_url: serverUrl, + callback_url: `${callbackUrl}/${serverId}`, + client_id: clientId ?? null, + auth_url: null, + server_options: null + }); this.mcp.mcpConnections[serverId] = this.createMockMcpConnection( serverId, serverUrl, @@ -292,8 +302,8 @@ export class TestOAuthAgent extends Agent { return servers.length > 0 ? servers[0] : null; } - isCallbackUrlRegistered(callbackUrl: string): boolean { - return this.mcp.isCallbackRequest(new Request(callbackUrl)); + async isCallbackUrlRegistered(callbackUrl: string): Promise { + return await this.mcp.isCallbackRequest(new Request(callbackUrl)); } removeMcpConnection(serverId: string): void { @@ -306,7 +316,7 @@ export class TestOAuthAgent extends Agent { resetMcpStateRestoredFlag(): void { // @ts-expect-error - accessing private property for testing - this._mcpStateRestored = false; + this._mcpConnectionsInitialized = false; } } diff --git a/site/ai-playground/src/app.tsx b/site/ai-playground/src/app.tsx index a42788e5..336520d7 100644 --- a/site/ai-playground/src/app.tsx +++ b/site/ai-playground/src/app.tsx @@ -8,6 +8,7 @@ import { McpServers } from "./components/McpServers"; import ModelSelector from "./components/ModelSelector"; import ViewCodeModal from "./components/ViewCodeModal"; import { ToolCallCard } from "./components/ToolCallCard"; +import { ReasoningCard } from "./components/ReasoningCard"; import { isToolUIPart, type UIMessage } from "ai"; import { useAgent } from "agents/react"; import type { MCPServersState } from "agents"; @@ -18,7 +19,7 @@ import type { McpComponentState } from "./components/McpServers"; const STORAGE_KEY = "playground_session_id"; const DEFAULT_PARAMS = { - model: "@hf/nousresearch/hermes-2-pro-mistral-7b", + model: "@cf/qwen/qwen3-30b-a3b-fp8", temperature: 0, stream: true, system: @@ -425,9 +426,9 @@ const App = () => { >
    - {messages.map((message) => ( -
    - {message.parts.map((part, i) => { + {messages.map((message) => { + const renderedParts = message.parts + .map((part, i) => { // Render text messages if (part.type === "text") { // Skip empty text parts (e.g., when message only contains tool calls) @@ -465,6 +466,24 @@ const App = () => { ); } + // Render reasoning + if (part.type === "reasoning") { + // Skip empty reasoning parts + if (!part.text || part.text.trim() === "") { + return null; + } + + return ( +
  • + +
  • + ); + } + // Render tool calls if (isToolUIPart(part)) { return ( @@ -500,9 +519,16 @@ const App = () => { } return null; - })} -
    - ))} + }) + .filter(Boolean); + + // Only render the message wrapper if there are actual parts to show + if (renderedParts.length === 0) { + return null; + } + + return
    {renderedParts}
    ; + })} {(loading || streaming) && (messages[messages.length - 1].role !== "assistant" || @@ -516,13 +542,21 @@ const App = () => { Assistant
-
- +
+
+
+
+
+
) : null} diff --git a/site/ai-playground/src/components/McpServers.tsx b/site/ai-playground/src/components/McpServers.tsx index db6e6d7a..0e94298a 100644 --- a/site/ai-playground/src/components/McpServers.tsx +++ b/site/ai-playground/src/components/McpServers.tsx @@ -24,9 +24,7 @@ type McpServersProps = { }; export function McpServers({ agent, mcpState, mcpLogs }: McpServersProps) { - const [serverUrl, setServerUrl] = useState(() => { - return sessionStorage.getItem("mcpServerUrl") || ""; - }); + const [serverUrl, setServerUrl] = useState(""); const [_transportType, _setTransportType] = useState<"auto" | "http" | "sse">( () => { return ( @@ -53,6 +51,12 @@ export function McpServers({ agent, mcpState, mcpLogs }: McpServersProps) { setIsActive(false); } }, [mcpState?.state, mcpState]); + + useEffect(() => { + if (mcpState?.url) { + setServerUrl(mcpState.url); + } + }, [mcpState?.url]); const [error, setError] = useState(""); const [isConnecting, setIsConnecting] = useState(false); @@ -377,9 +381,7 @@ export function McpServers({ agent, mcpState, mcpLogs }: McpServersProps) { placeholder="Enter MCP server URL" value={serverUrl} onChange={(e) => { - const newValue = e.target.value; - setServerUrl(newValue); - sessionStorage.setItem("mcpServerUrl", newValue); + setServerUrl(e.target.value); }} disabled={isActive} /> @@ -395,10 +397,21 @@ export function McpServers({ agent, mcpState, mcpLogs }: McpServersProps) { )}
diff --git a/site/ai-playground/src/components/ReasoningCard.tsx b/site/ai-playground/src/components/ReasoningCard.tsx new file mode 100644 index 00000000..5b8352ba --- /dev/null +++ b/site/ai-playground/src/components/ReasoningCard.tsx @@ -0,0 +1,58 @@ +import { useState } from "react"; + +interface ReasoningCardProps { + part: { + type: "reasoning"; + text: string; + state?: "streaming" | "done"; + }; +} + +export const ReasoningCard = ({ part }: ReasoningCardProps) => { + const [isExpanded, setIsExpanded] = useState(true); + + return ( +
+ + +
+
+          {part.text}
+        
+
+
+ ); +}; diff --git a/site/ai-playground/src/server.ts b/site/ai-playground/src/server.ts index 2efeecdc..a267956d 100644 --- a/site/ai-playground/src/server.ts +++ b/site/ai-playground/src/server.ts @@ -30,7 +30,7 @@ export interface PlaygroundState { */ export class Playground extends AIChatAgent { initialState = { - model: "@hf/nousresearch/hermes-2-pro-mistral-7b", + model: "@cf/qwen/qwen3-30b-a3b-fp8", temperature: 1, stream: true, system: @@ -75,6 +75,7 @@ export class Playground extends AIChatAgent { }); await this.ensureDestroy(); + const stream = createUIMessageStream({ execute: async ({ writer }) => { // Clean up incomplete tool calls to prevent API errors @@ -143,9 +144,9 @@ export class Playground extends AIChatAgent { await this.removeMcpServer(serverId); } else { // Disconnect all servers if no serverId provided - const mcpState = this.getMcpServers(); - for (const id of Object.keys(mcpState.servers)) { - await this.removeMcpServer(id); + const servers = this.mcp.listServers(); + for (const server of servers) { + await this.removeMcpServer(server.id); } } }