diff --git a/src/lib/oauth.test.ts b/src/lib/oauth.test.ts new file mode 100644 index 0000000..3da7955 --- /dev/null +++ b/src/lib/oauth.test.ts @@ -0,0 +1,283 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { connectToRemoteServer, REASON_AUTH_NEEDED } from './utils' +import { NodeOAuthClientProvider } from './node-oauth-client-provider' +import type { OAuthProviderOptions } from './types' +import express from 'express' +import { Server } from 'http' + +describe('OAuth Authorization', () => { + let mcpServer: MockServer + let idpServer: MockServer + + beforeEach(async () => { + mcpServer = new MockServer() + await mcpServer.start() + + idpServer = new MockServer() + await idpServer.start() + }) + + afterEach(async () => { + await mcpServer.stop() + await idpServer.stop() + }) + + it('uses the protected resource metadata URL from the WWW-Authenticate response header to find authorization server', async () => { + // Setup mocked mcp server + const mcpServerUrl = mcpServer.url('/test/mcp') + mcpServer.addRoute('POST', '/test/mcp', (req, res) => { + res.status(401)// Since we are testing only the login flow, we ignore the post-auth stuff. + .header('WWW-Authenticate', `Bearer realm="mcp", resource_metadata="${mcpServer.url('/test/mcp/.well-known/oauth-protected-resource')}"`) + .json({ error: 'Unauthorized' }) + }) + mcpServer.addRoute('GET', '/test/mcp/.well-known/oauth-protected-resource', (req, res) => { + res.json({ + resource: mcpServer.url('/test/mcp'), + authorization_servers: [idpServer.url('/test/auth')] + }) + }); + + // Setup mocked idp server + idpServer.addRoute('GET', '/test/auth/.well-known/openid-configuration', (req, res) => { + res.json({ + issuer: idpServer.url('/test/auth'), + authorization_endpoint: idpServer.url('/test/auth/authorize'), + token_endpoint: idpServer.url('/test/auth/token'), + jwks_uri: idpServer.url('/test/auth/jwks'), + response_types_supported: ['code', 'token', 'id_token'], + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + code_challenge_methods_supported: ['plain', 'S256'], + token_endpoint_auth_methods_supported: ['client_secret_basic', 'none'], + scopes_supported: ['openid', 'profile', 'email'], + claims_supported: ['sub', 'iss', 'aud', 'exp', 'iat', 'name', 'email'], + }) + }); + idpServer.addRoute('POST', '/test/auth/token', (req, res) => { + res.json({ + "access_token": "abc123def456ghi789", + "token_type": "Bearer" + }) + }); + + // Setup OAuth client + const authProvider = new NodeOAuthClientProvider({ + serverUrl: mcpServerUrl, + callbackPort: 0, + host: 'localhost', + callbackPath: '/oauth/callback', + staticOAuthClientInfo: { + client_id: 'mock-client-id', + client_name: 'Mock Client', + redirect_uris: ['http://localhost/callback'], + token_endpoint_auth_method: 'none', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'] + } + }) + + // Mock auth code acquisition + vi.spyOn(authProvider, 'redirectToAuthorization').mockImplementation(() => Promise.resolve()) + const mockAuthInitializer = vi.fn().mockResolvedValue({ + waitForAuthCode: vi.fn().mockResolvedValue('mocked-auth-code'), + skipBrowserAuth: false, + }) + + // Initiate login flow + try { + const transport = await connectToRemoteServer( + null, + mcpServerUrl, + authProvider, + {}, + mockAuthInitializer, + 'http-first', + new Set([REASON_AUTH_NEEDED])/*dont do the recursion thing, because we want to stop after the first auth attempt*/ + ) + } catch (e: Error | any) { + // Since we only do one run and skip the recursion, we expect it to give up early. + expect(e).toBeInstanceOf(Error); + expect(e.message).contains("Giving up."); + } + + // Verify we successfully acquired the accesstoken + const tokens = await authProvider.tokens(); + expect(tokens?.access_token).toBe('abc123def456ghi789') + }) +}) + +// Test utility for controlling fetch responses. +class MockServer { + private app: express.Application + private server: Server | null = null + private port: number = 0 + public baseUrl: string = '' + + constructor() { + this.app = express() + this.app.use(express.json()) + } + + addRoute(method: 'GET' | 'POST' | 'PUT' | 'DELETE', path: string, handler: express.RequestHandler) { + this.app[method.toLowerCase() as 'get' | 'post' | 'put' | 'delete'](path, handler) + } + + async start(): Promise { + return new Promise((resolve, reject) => { + this.server = this.app.listen(0, 'localhost', () => { + if (this.server) { + const address = this.server.address() + if (address && typeof address === 'object') { + this.port = address.port + this.baseUrl = `http://localhost:${this.port}` + resolve() + } else { + reject(new Error('Failed to get server address')) + } + } else { + reject(new Error('Server failed to start')) + } + }) + + this.server.on('error', reject) + }) + } + + async stop(): Promise { + return new Promise((resolve, reject) => { + if (this.server) { + this.server.close((err) => { + if (err) { + reject(err) + } else { + this.server = null + resolve() + } + }) + } else { + resolve() + } + }) + } + + url(path: string): string { + return `${this.baseUrl}${path}` + } +} + +describe('Test Infrastructure: Mock Server', () => { + let mockServer: MockServer + + beforeEach(async () => { + mockServer = new MockServer() + await mockServer.start() + }) + + afterEach(async () => { + await mockServer.stop() + }) + + it('responds with custom data', async () => { + // Given a mock server with a custom route + mockServer.addRoute('GET', '/test', (req, res) => { + res.json({ message: 'Hello from mock server!' }) + }) + + // When making a fetch request to the mock server + const response = await fetch(mockServer.url('/test')) + const data = await response.json() + + // Then the mock server should respond with the expected data + expect(data).toEqual({ message: 'Hello from mock server!' }) + expect(response.status).toBe(200) + }) + + it('handles POST requests', async () => { + // Given a mock server with an echo endpoint + mockServer.addRoute('POST', '/echo', (req, res) => { + res.json({ + received: req.body, + headers: req.headers, + method: req.method + }) + }) + + // When posting data to the mock server + const testData = { test: 'data', number: 42 } + const response = await fetch(mockServer.url('/echo'), { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(testData) + }) + const responseData = await response.json() + + // Then the server should echo back the data + expect(responseData.received).toEqual(testData) + expect(responseData.method).toBe('POST') + expect(responseData.headers['content-type']).toBe('application/json') + }) + + it('can simulate different HTTP status codes', async () => { + // Given a mock server that returns different status codes + mockServer.addRoute('GET', '/success', (req, res) => { + res.status(200).json({ status: 'ok' }) + }) + + mockServer.addRoute('GET', '/error', (req, res) => { + res.status(500).json({ error: 'Internal Server Error' }) + }) + + mockServer.addRoute('GET', '/not-found', (req, res) => { + res.status(404).json({ error: 'Not Found' }) + }) + + // When making requests to different endpoints + const successResponse = await fetch(mockServer.url('/success')) + const errorResponse = await fetch(mockServer.url('/error')) + const notFoundResponse = await fetch(mockServer.url('/not-found')) + + // Then the appropriate status codes should be returned + expect(successResponse.status).toBe(200) + expect(errorResponse.status).toBe(500) + expect(notFoundResponse.status).toBe(404) + }) + + it('can simulate authentication flow', async () => { + // Given a mock server that checks for authentication + mockServer.addRoute('GET', '/protected', (req, res) => { + const authHeader = req.headers.authorization + + if (!authHeader || !authHeader.startsWith('Bearer ')) { + res.status(401).json({ error: 'Unauthorized' }) + return + } + + const token = authHeader.substring(7) // Remove 'Bearer ' prefix + + if (token === 'valid-token') { + res.json({ message: 'Access granted', user: 'test-user' }) + } else { + res.status(401).json({ error: 'Invalid token' }) + } + }) + + // When making a request without authentication + const unauthResponse = await fetch(mockServer.url('/protected')) + expect(unauthResponse.status).toBe(401) + + // When making a request with invalid token + const invalidTokenResponse = await fetch(mockServer.url('/protected'), { + headers: { Authorization: 'Bearer invalid-token' } + }) + expect(invalidTokenResponse.status).toBe(401) + + // When making a request with valid token + const validTokenResponse = await fetch(mockServer.url('/protected'), { + headers: { Authorization: 'Bearer valid-token' } + }) + expect(validTokenResponse.status).toBe(200) + + const data = await validTokenResponse.json() + expect(data).toEqual({ message: 'Access granted', user: 'test-user' }) + }) +}) diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 3824903..67e6fb1 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -308,7 +308,8 @@ export async function connectToRemoteServer( authProvider, requestInit: { headers }, }) - + + let rememberedTestTransport: any | null = null try { debugLog('Attempting to connect to remote server', { sseTransport }) @@ -325,6 +326,7 @@ export async function connectToRemoteServer( // out if we're actually talking to an HTTP server or not. debugLog('Creating test transport for HTTP-only connection test') const testTransport = new StreamableHTTPClientTransport(url, { authProvider, requestInit: { headers } }) + rememberedTestTransport = testTransport const testClient = new Client({ name: 'mcp-remote-fallback-test', version: '0.0.0' }, { capabilities: {} }) await testClient.connect(testTransport) } @@ -374,6 +376,21 @@ export async function connectToRemoteServer( stack: error.stack, }) + // Preserve the resource metadata URL if we discovered it during the test connection + try { + // Because we have this "Extremely hacky" test transport above, which with we have discovered the + // resource metadata URL if it was available, we need to "inject" the resourceMetadataUrl into the + // actual transportso used for the oauth login so it has the correct URL to use when fetching tokens. + // This is important for cases where the authorization server is on a different URL than the MCP server. + // A better solution would use only a single transport for everything and not need recursion! + let injectableTransport: any = transport; + if (!injectableTransport._resourceMetadataUrl) {//only if it wasn't already set + injectableTransport._resourceMetadataUrl = rememberedTestTransport?._resourceMetadataUrl;//will break when the transport's internals are refactored + } + } catch (e) { + debugLog('Unable to preserve resource metadata URL from testTransport') + } + // Initialize authentication on-demand debugLog('Calling authInitializer to start auth flow') const { waitForAuthCode, skipBrowserAuth } = await authInitializer()