diff --git a/examples/clients/typescript/auth-test-broken1.ts b/examples/clients/typescript/auth-test-broken1.ts new file mode 100644 index 0000000..6fd6e5c --- /dev/null +++ b/examples/clients/typescript/auth-test-broken1.ts @@ -0,0 +1,99 @@ +#!/usr/bin/env node + +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { handle401, withOAuthRetry } from './helpers/withOAuthRetry.js'; +import { ConformanceOAuthProvider } from './helpers/ConformanceOAuthProvider.js'; +import { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js'; +import { + auth, + UnauthorizedError +} from '@modelcontextprotocol/sdk/client/auth.js'; + +export const handle401Broken = async ( + response: Response, + provider: ConformanceOAuthProvider, + next: FetchLike, + serverUrl: string | URL +): Promise => { + // BROKEN: Use root-based PRM discovery exclusively, regardless of input. + const resourceMetadataUrl = new URL( + '/.well-known/oauth-protected-resource', + typeof serverUrl === 'string' ? serverUrl : serverUrl.origin + ); + + let result = await auth(provider, { + serverUrl, + resourceMetadataUrl, + fetchFn: next + }); + + if (result === 'REDIRECT') { + // Ordinarily, we'd wait for the callback to be handled here, + // but in our conformance provider, we get the authorization code + // during the redirect handling, so we can go straight to + // retrying the auth step. + // await provider.waitForCallback(); + + const authorizationCode = await provider.getAuthCode(); + + // TODO: this retry logic should be incorporated into the typescript SDK + result = await auth(provider, { + serverUrl, + resourceMetadataUrl, + authorizationCode, + fetchFn: next + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError( + `Authentication failed with result: ${result}` + ); + } + } +}; + +async function main(): Promise { + const serverUrl = process.argv[2]; + + if (!serverUrl) { + console.error('Usage: auth-test '); + process.exit(1); + } + + console.log(`Connecting to MCP server at: ${serverUrl}`); + + const client = new Client( + { + name: 'test-auth-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + // Create a custom fetch that uses the OAuth middleware with retry logic + const oauthFetch = withOAuthRetry( + 'test-auth-client', + new URL(serverUrl), + handle401Broken + )(fetch); + + const transport = new StreamableHTTPClientTransport(new URL(serverUrl), { + fetch: oauthFetch + }); + + // Connect to the server - OAuth is handled automatically by the middleware + await client.connect(transport); + console.log('✅ Successfully connected to MCP server'); + + await client.listTools(); + console.log('✅ Successfully listed tools'); + + await transport.close(); + console.log('✅ Connection closed successfully'); + + process.exit(0); +} + +main(); diff --git a/examples/clients/typescript/auth-test.ts b/examples/clients/typescript/auth-test.ts index dee052b..cb0fd2d 100644 --- a/examples/clients/typescript/auth-test.ts +++ b/examples/clients/typescript/auth-test.ts @@ -2,8 +2,7 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; -import { ConformanceOAuthProvider } from './helpers/ConformanceOAuthProvider.js'; -import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'; +import { withOAuthRetry } from './helpers/withOAuthRetry.js'; async function main(): Promise { const serverUrl = process.argv[2]; @@ -25,47 +24,19 @@ async function main(): Promise { } ); - const authProvider = new ConformanceOAuthProvider( - 'http://localhost:3000/callback', - { - client_name: 'test-auth-client', - redirect_uris: ['http://localhost:3000/callback'] - } - ); + // Create a custom fetch that uses the OAuth middleware with retry logic + const oauthFetch = withOAuthRetry( + 'test-auth-client', + new URL(serverUrl) + )(fetch); - let transport = new StreamableHTTPClientTransport(new URL(serverUrl), { - authProvider + const transport = new StreamableHTTPClientTransport(new URL(serverUrl), { + fetch: oauthFetch }); - // Try to connect - handle OAuth if needed - try { - await client.connect(transport); - console.log('✅ Successfully connected to MCP server'); - } catch (error) { - if (error instanceof UnauthorizedError) { - console.log('🔐 OAuth required - handling authorization...'); - - // The provider will automatically fetch the auth code - const authCode = await authProvider.getAuthCode(); - - // Complete the auth flow - await transport.finishAuth(authCode); - - // Close the old transport - await transport.close(); - - // Create a new transport with the authenticated provider - transport = new StreamableHTTPClientTransport(new URL(serverUrl), { - authProvider: authProvider - }); - - // Connect with the new transport - await client.connect(transport); - console.log('✅ Successfully connected with authentication'); - } else { - throw error; - } - } + // Connect to the server - OAuth is handled automatically by the middleware + await client.connect(transport); + console.log('✅ Successfully connected to MCP server'); await client.listTools(); console.log('✅ Successfully listed tools'); diff --git a/examples/clients/typescript/helpers/withOAuthRetry.ts b/examples/clients/typescript/helpers/withOAuthRetry.ts new file mode 100644 index 0000000..fbdda08 --- /dev/null +++ b/examples/clients/typescript/helpers/withOAuthRetry.ts @@ -0,0 +1,109 @@ +import { + auth, + extractResourceMetadataUrl, + UnauthorizedError +} from '@modelcontextprotocol/sdk/client/auth.js'; +import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js'; +import type { Middleware } from '@modelcontextprotocol/sdk/client/middleware.js'; +import { ConformanceOAuthProvider } from './ConformanceOAuthProvider'; + +export const handle401 = async ( + response: Response, + provider: ConformanceOAuthProvider, + next: FetchLike, + serverUrl: string | URL +): Promise => { + const resourceMetadataUrl = extractResourceMetadataUrl(response); + + let result = await auth(provider, { + serverUrl, + resourceMetadataUrl, + fetchFn: next + }); + + if (result === 'REDIRECT') { + // Ordinarily, we'd wait for the callback to be handled here, + // but in our conformance provider, we get the authorization code + // during the redirect handling, so we can go straight to + // retrying the auth step. + // await provider.waitForCallback(); + + const authorizationCode = await provider.getAuthCode(); + + // TODO: this retry logic should be incorporated into the typescript SDK + result = await auth(provider, { + serverUrl, + resourceMetadataUrl, + authorizationCode, + fetchFn: next + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError( + `Authentication failed with result: ${result}` + ); + } + } +}; +/** + * Creates a fetch wrapper that handles OAuth authentication with retry logic. + * + * Unlike the SDK's withOAuth, this version: + * - Automatically handles authorization redirects by retrying with fresh tokens + * - Does not throw UnauthorizedError on redirect, but instead retries + * - Calls next() instead of throwing for redirect-based auth + * + * @param provider - OAuth client provider for authentication + * @param baseUrl - Base URL for OAuth server discovery (defaults to request URL domain) + * @returns A fetch middleware function + */ +export const withOAuthRetry = ( + clientName: string, + baseUrl?: string | URL, + handle401Fn: typeof handle401 = handle401 +): Middleware => { + const provider = new ConformanceOAuthProvider( + 'http://localhost:3000/callback', + { + client_name: clientName, + redirect_uris: ['http://localhost:3000/callback'] + } + ); + return (next: FetchLike) => { + return async ( + input: string | URL, + init?: RequestInit + ): Promise => { + const makeRequest = async (): Promise => { + const headers = new Headers(init?.headers); + + // Add authorization header if tokens are available + const tokens = await provider.tokens(); + if (tokens) { + headers.set('Authorization', `Bearer ${tokens.access_token}`); + } + + return await next(input, { ...init, headers }); + }; + + let response = await makeRequest(); + + // Handle 401 responses by attempting re-authentication + if (response.status === 401) { + const serverUrl = + baseUrl || + (typeof input === 'string' ? new URL(input).origin : input.origin); + await handle401Fn(response, provider, next, serverUrl); + + response = await makeRequest(); + } + + // If we still have a 401 after re-auth attempt, throw an error + if (response.status === 401) { + const url = typeof input === 'string' ? input : input.toString(); + throw new UnauthorizedError(`Authentication failed for ${url}`); + } + + return response; + }; + }; +}; diff --git a/src/scenarios/client/auth/basic-dcr.test.ts b/src/scenarios/client/auth/basic-dcr.test.ts index 05fba3f..3f18e1f 100644 --- a/src/scenarios/client/auth/basic-dcr.test.ts +++ b/src/scenarios/client/auth/basic-dcr.test.ts @@ -1,5 +1,8 @@ import { describe, test } from '@jest/globals'; -import { runClientAgainstScenario } from './helpers/testClient.js'; +import { + runClientAgainstScenario, + SpawnedClientRunner +} from './test_helpers/testClient.js'; import path from 'path'; describe('PRM Path-Based Discovery', () => { @@ -8,6 +11,19 @@ describe('PRM Path-Based Discovery', () => { process.cwd(), 'examples/clients/typescript/auth-test.ts' ); - await runClientAgainstScenario(clientPath, 'auth/basic-dcr'); + const runner = new SpawnedClientRunner(clientPath); + await runClientAgainstScenario(runner, 'auth/basic-dcr'); + }); + + test('bad client requests root PRM location', async () => { + const clientPath = path.join( + process.cwd(), + 'examples/clients/typescript/auth-test-broken1.ts' + ); + const runner = new SpawnedClientRunner(clientPath); + await runClientAgainstScenario(runner, 'auth/basic-dcr', [ + // There will be other failures, but this is the one that matters + 'prm-priority-order' + ]); }); }); diff --git a/src/scenarios/client/auth/basic-dcr.ts b/src/scenarios/client/auth/basic-dcr.ts index 579df5a..12d35a2 100644 --- a/src/scenarios/client/auth/basic-dcr.ts +++ b/src/scenarios/client/auth/basic-dcr.ts @@ -3,6 +3,7 @@ import { ScenarioUrls } from '../../../types.js'; import { createAuthServer } from './helpers/createAuthServer.js'; import { createServer } from './helpers/createServer.js'; import { ServerLifecycle } from './helpers/serverLifecycle.js'; +import { Request, Response } from 'express'; export class AuthBasicDCRScenario implements Scenario { name = 'auth-basic-dcr'; @@ -25,6 +26,38 @@ export class AuthBasicDCRScenario implements Scenario { () => this.baseUrl, () => this.authBaseUrl ); + + // For this scenario, reject PRM requests at root location since we have the path-based PRM. + app.get( + '/.well-known/oauth-protected-resource', + (req: Request, res: Response) => { + this.checks.push({ + id: 'prm-priority-order', + name: 'PRM Priority Order', + description: + 'Client requested PRM metadata at root location on a server with path-based PRM', + status: 'FAILURE', + timestamp: new Date().toISOString(), + specReferences: [ + { + id: 'mcp-authorization-prm', + url: 'https://modelcontextprotocol.io/specification/draft/basic/authorization#protected-resource-metadata-discovery-requirements' + } + ], + details: { + url: req.url, + path: req.path + } + }); + + // Return 404 to indicate PRM is not available at root location + res.status(404).json({ + error: 'not_found', + error_description: 'PRM metadata not available at root location' + }); + } + ); + this.baseUrl = await this.server.start(app); return { serverUrl: `${this.baseUrl}/mcp` }; diff --git a/src/scenarios/client/auth/basic-metadata-var1.test.ts b/src/scenarios/client/auth/basic-metadata-var1.test.ts index 99e3b0a..4cdbb01 100644 --- a/src/scenarios/client/auth/basic-metadata-var1.test.ts +++ b/src/scenarios/client/auth/basic-metadata-var1.test.ts @@ -1,5 +1,8 @@ import { describe, test } from '@jest/globals'; -import { runClientAgainstScenario } from './helpers/testClient.js'; +import { + runClientAgainstScenario, + SpawnedClientRunner +} from './test_helpers/testClient.js'; import path from 'path'; describe('OAuth Metadata at OpenID Configuration Path', () => { @@ -8,6 +11,7 @@ describe('OAuth Metadata at OpenID Configuration Path', () => { process.cwd(), 'examples/clients/typescript/auth-test.ts' ); - await runClientAgainstScenario(clientPath, 'auth/basic-metadata-var1'); + const runner = new SpawnedClientRunner(clientPath); + await runClientAgainstScenario(runner, 'auth/basic-metadata-var1'); }); }); diff --git a/src/scenarios/client/auth/helpers/testClient.ts b/src/scenarios/client/auth/test_helpers/testClient.ts similarity index 67% rename from src/scenarios/client/auth/helpers/testClient.ts rename to src/scenarios/client/auth/test_helpers/testClient.ts index 6fa7f50..47a2891 100644 --- a/src/scenarios/client/auth/helpers/testClient.ts +++ b/src/scenarios/client/auth/test_helpers/testClient.ts @@ -1,26 +1,29 @@ -import { getScenario } from '../../../../scenarios/index.js'; +import { getScenario } from '../../../index.js'; import { spawn } from 'child_process'; const CLIENT_TIMEOUT = 10000; // 10 seconds for client to complete -export async function runClientAgainstScenario( - clientPath: string, - scenarioName: string, - expectedFailureSlugs: string[] = [] -): Promise { - const scenario = getScenario(scenarioName); - if (!scenario) { - throw new Error(`Scenario ${scenarioName} not found`); - } +/** + * Represents a client that can be executed against a scenario. + * Implementations can run client code inline or by spawning a process. + */ +export interface ClientRunner { + /** + * Run the client against the given server URL. + * Should reject if the client fails. + */ + run(serverUrl: string): Promise; +} - // Start the scenario server - const urls = await scenario.start(); - const serverUrl = urls.serverUrl; +/** + * Client runner that spawns a shell process to execute a client file. + */ +export class SpawnedClientRunner implements ClientRunner { + constructor(private clientPath: string) {} - try { - // Run the client + async run(serverUrl: string): Promise { await new Promise((resolve, reject) => { - const clientProcess = spawn('npx', ['tsx', clientPath, serverUrl], { + const clientProcess = spawn('npx', ['tsx', this.clientPath, serverUrl], { stdio: ['ignore', 'pipe', 'pipe'] }); @@ -66,6 +69,50 @@ export async function runClientAgainstScenario( ); }); }); + } +} + +/** + * Client runner that executes a client function inline without spawning a shell. + */ +export class InlineClientRunner implements ClientRunner { + constructor(private clientFn: (serverUrl: string) => Promise) {} + + async run(serverUrl: string): Promise { + await this.clientFn(serverUrl); + } +} + +export async function runClientAgainstScenario( + clientRunner: ClientRunner | string, + scenarioName: string, + expectedFailureSlugs: string[] = [] +): Promise { + // Handle backward compatibility: if string is passed, treat as file path + const runner = + typeof clientRunner === 'string' + ? new SpawnedClientRunner(clientRunner) + : clientRunner; + + const scenario = getScenario(scenarioName); + if (!scenario) { + throw new Error(`Scenario ${scenarioName} not found`); + } + + // Start the scenario server + const urls = await scenario.start(); + const serverUrl = urls.serverUrl; + + try { + // Run the client + try { + await runner.run(serverUrl); + } catch (err) { + if (expectedFailureSlugs.length === 0) { + throw err; // Unexpected failure + } + // Otherwise, expected failure - continue to checks verification + } // Get checks from the scenario const checks = scenario.getChecks(); @@ -91,13 +138,10 @@ export async function runClientAgainstScenario( // Verify that only the expected checks failed const failures = nonInfoChecks.filter((c) => c.status === 'FAILURE'); const failureSlugs = failures.map((c) => c.id); - if ( - failureSlugs.sort().join(',') !== expectedFailureSlugs.sort().join(',') - ) { - throw new Error( - `Expected failures ${expectedFailureSlugs.sort().join(', ')} but got ${failureSlugs.sort().join(', ')}` - ); - } + // Check that failureSlugs contains all expectedFailureSlugs + expect(failureSlugs).toEqual( + expect.arrayContaining(expectedFailureSlugs) + ); } else { // Default: expect all checks to pass const failures = nonInfoChecks.filter((c) => c.status === 'FAILURE');