diff --git a/.gitignore b/.gitignore index 694735b68..4b13b8453 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,6 @@ out .DS_Store dist/ + +# ide +.idea/ diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index f28163d14..1b6acef16 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -1,19 +1,20 @@ import { LATEST_PROTOCOL_VERSION } from '../types.js'; import { - discoverOAuthMetadata, - discoverAuthorizationServerMetadata, - buildDiscoveryUrls, - startAuthorization, - exchangeAuthorization, - refreshAuthorization, - registerClient, - discoverOAuthProtectedResourceMetadata, - extractResourceMetadataUrl, - auth, - type OAuthClientProvider, + discoverOAuthMetadata, + discoverAuthorizationServerMetadata, + buildDiscoveryUrls, + startAuthorization, + exchangeAuthorization, + refreshAuthorization, + registerClient, + discoverOAuthProtectedResourceMetadata, + extractResourceMetadataUrl, + auth, + type OAuthClientProvider, startClientCredentialAuthorization, } from "./auth.js"; import {ServerError} from "../server/auth/errors.js"; import { AuthorizationServerMetadata } from '../shared/auth.js'; +import {describe} from "@jest/globals"; // Mock fetch globally const mockFetch = jest.fn(); @@ -1265,6 +1266,169 @@ describe("OAuth Authorization", () => { }); }); + describe("startClientCredentialAuthorization", () => { + const validTokens = { + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }; + + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"] + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + it("retrieve tokens with client credentials", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await startClientCredentialAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + scope: "openid", + resource: new URL("https://api.example.com/mcp-server"), + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }), + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("client_credentials"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + expect(body.get("scope")).toBe("openid"); + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + }); + + it("retrieve tokens with client credentials with auth metadata", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await startClientCredentialAuthorization("https://auth2.example.com", { + metadata: validMetadata, + clientInformation: validClientInfo, + scope: "openid", + resource: new URL("https://api.example.com/mcp-server"), + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }), + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("client_credentials"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + expect(body.get("scope")).toBe("openid"); + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + }); + + it("validates token response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: "access123", + }), + }); + + await expect( + startClientCredentialAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + scope: "openid", + resource: new URL("https://api.example.com/mcp-server"), + }) + ).rejects.toThrow(); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Authorization failed").toResponseObject(), + { status: 400 } + ) + ); + + await expect( + startClientCredentialAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + scope: "openid", + resource: new URL("https://api.example.com/mcp-server"), + }) + ).rejects.toThrow("Authorization failed"); + }); + + it("supports overriding the fetch function used for requests", async () => { + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await startClientCredentialAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + scope: "openid", + resource: new URL("https://api.example.com/mcp-server"), + fetchFn: customFetch + }); + + expect(tokens).toEqual(validTokens); + expect(customFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + }), + }) + ); + + const body = customFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("client_credentials"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + expect(body.get("scope")).toBe("openid"); + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + }); + }) + describe("refreshAuthorization", () => { const validTokens = { access_token: "newaccess123", @@ -1506,7 +1670,7 @@ describe("OAuth Authorization", () => { }); }); - describe("auth function", () => { + describe("auth function - authorization flow", () => { const mockProvider: OAuthClientProvider = { get redirectUrl() { return "http://localhost:3000/callback"; }, get clientMetadata() { @@ -2234,125 +2398,705 @@ describe("OAuth Authorization", () => { }); }); - describe("exchangeAuthorization with multiple client authentication methods", () => { - const validTokens = { - access_token: "access123", - token_type: "Bearer", - expires_in: 3600, - refresh_token: "refresh123", - }; - - const validClientInfo = { - client_id: "client123", - client_secret: "secret123", - redirect_uris: ["http://localhost:3000/callback"], - client_name: "Test Client", + describe("auth function - credential flow", () => { + const mockProvider: OAuthClientProvider = { + get redirectUrl() { return "http://localhost:3000/callback"; }, + get clientMetadata() { + return { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + }, + authFlow() { return "client_credentials"; }, + clientInformation: jest.fn(), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), }; - const metadataWithBasicOnly = { - issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/auth", - token_endpoint: "https://auth.example.com/token", - response_types_supported: ["code"], - code_challenge_methods_supported: ["S256"], - token_endpoint_auth_methods_supported: ["client_secret_basic"], - }; + beforeEach(() => { + jest.clearAllMocks(); + }); - const metadataWithPostOnly = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["client_secret_post"], - }; + it("falls back to /.well-known/oauth-authorization-server when no protected-resource-metadata", async () => { + // Setup: First call to protected resource metadata fails (404) + // Second call to auth server metadata succeeds + let callCount = 0; + mockFetch.mockImplementation((url) => { + callCount++; - const metadataWithNoneOnly = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["none"], - }; + const urlString = url.toString(); - const metadataWithAllBuiltinMethods = { - ...metadataWithBasicOnly, - token_endpoint_auth_methods_supported: ["client_secret_basic", "client_secret_post", "none"], - }; + if (callCount === 1 && urlString.includes("/.well-known/oauth-protected-resource")) { + // First call - protected resource metadata fails with 404 + return Promise.resolve({ + ok: false, + status: 404, + }); + } else if (callCount === 2 && urlString.includes("/.well-known/oauth-authorization-server")) { + // Second call - auth server metadata succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (callCount === 3 && urlString.includes("/register")) { + // Third call - client registration succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + client_id: "test-client-id", + client_secret: "test-client-secret", + client_id_issued_at: 1612137600, + client_secret_expires_at: 1612224000, + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }), + }); + } else if (callCount ===4 && urlString.includes('/token')) { + // Fourth call - token retrieval succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "test-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "test-refresh-token", + }), + }) + } - it("uses HTTP Basic authentication when client_secret_basic is supported", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, + return Promise.reject(new Error(`Unexpected fetch call: ${urlString}`)); }); - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithBasicOnly, - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue(undefined); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + mockProvider.saveClientInformation = jest.fn(); - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; + // Call the auth function + const result = await auth(mockProvider, { + serverUrl: "https://resource.example.com", + }); - // Check Authorization header - const authHeader = request.headers.get("Authorization"); - const expected = "Basic " + btoa("client123:secret123"); - expect(authHeader).toBe(expected); + // Verify the result + expect(result).toBe("AUTHORIZED"); - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBeNull(); - expect(body.get("client_secret")).toBeNull(); - }); + // Verify the sequence of calls + expect(mockFetch).toHaveBeenCalledTimes(4); - it("includes credentials in request body when client_secret_post is supported", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, - }); + // First call should be to protected resource metadata + expect(mockFetch.mock.calls[0][0].toString()).toBe( + "https://resource.example.com/.well-known/oauth-protected-resource" + ); - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithPostOnly, - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", - }); + // Second call should be to oauth metadata + expect(mockFetch.mock.calls[1][0].toString()).toBe( + "https://resource.example.com/.well-known/oauth-authorization-server" + ); + }); - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; + it("includes resource in token request", async () => { + // Mock successful metadata discovery and token exchange - need protected resource metadata + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); - // Check no Authorization header - expect(request.headers.get("Authorization")).toBeNull(); + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://api.example.com/mcp-server", + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }), + }); + } - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBe("client123"); - expect(body.get("client_secret")).toBe("secret123"); - }); + return Promise.resolve({ ok: false, status: 404 }); + }); - it("it picks client_secret_basic when all builtin methods are supported", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => validTokens, + // Mock provider methods for token request + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); - const tokens = await exchangeAuthorization("https://auth.example.com", { - metadata: metadataWithAllBuiltinMethods, - clientInformation: validClientInfo, - authorizationCode: "code123", - redirectUri: "http://localhost:3000/callback", - codeVerifier: "verifier123", + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", }); - expect(tokens).toEqual(validTokens); - const request = mockFetch.mock.calls[0][1]; + expect(result).toBe("AUTHORIZED"); - // Check Authorization header - should use Basic auth as it's the most secure - const authHeader = request.headers.get("Authorization"); - const expected = "Basic " + btoa("client123:secret123"); - expect(authHeader).toBe(expected); + // Find the token exchange call + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); - // Credentials should not be in body when using Basic auth - const body = request.body as URLSearchParams; - expect(body.get("client_id")).toBeNull(); + const body = tokenCall![1].body as URLSearchParams; + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + }); + + it("includes resource in token refresh", async () => { + // Mock successful metadata discovery and token refresh - need protected resource metadata + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://api.example.com/mcp-server", + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "new-access123", + token_type: "Bearer", + expires_in: 3600, + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token refresh + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue({ + access_token: "old-access", + refresh_token: "refresh123", + }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with existing tokens (should trigger refresh) + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("AUTHORIZED"); + + // Find the token refresh call + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + }); + + it("skips default PRM resource validation when custom validateResourceURL is provided", async () => { + const mockValidateResourceURL = jest.fn().mockResolvedValue(undefined); + const providerWithCustomValidation = { + ...mockProvider, + validateResourceURL: mockValidateResourceURL, + }; + + // Mock protected resource metadata with mismatched resource URL + // This would normally throw an error in default validation, but should be skipped + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://different-resource.example.com/mcp-server", // Mismatched resource + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes('/token')) { + // Fourth call - token retrieval succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "test-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "test-refresh-token", + }), + }) + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (providerWithCustomValidation.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (providerWithCustomValidation.tokens as jest.Mock).mockResolvedValue(undefined); + (providerWithCustomValidation.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + + // Call auth - should succeed despite resource mismatch because custom validation overrides default + const result = await auth(providerWithCustomValidation, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("AUTHORIZED"); + + // Verify custom validation method was called + expect(mockValidateResourceURL).toHaveBeenCalledWith( + new URL("https://api.example.com/mcp-server"), + "https://different-resource.example.com/mcp-server" + ); + }); + + it("excludes resource parameter in token exchange when Protected Resource Metadata is not present", async () => { + // Mock metadata discovery - no protected resource metadata, but auth server metadata available + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: false, + status: 404, + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token exchange + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("AUTHORIZED"); + + // Find the token exchange call + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + // Resource parameter should not be present when PRM is not available + expect(body.has("resource")).toBe(false); + expect(body.get("grant_type")).toBe("client_credentials"); + }); + + it("excludes resource parameter in token refresh when Protected Resource Metadata is not present", async () => { + // Mock metadata discovery - no protected resource metadata, but auth server metadata available + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: false, + status: 404, + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "new-access123", + token_type: "Bearer", + expires_in: 3600, + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token refresh + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue({ + access_token: "old-access", + refresh_token: "refresh123", + }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with existing tokens (should trigger refresh) + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("AUTHORIZED"); + + // Find the token refresh call + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + // Resource parameter should not be present when PRM is not available + expect(body.has("resource")).toBe(false); + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + }); + + it("fetches AS metadata with path from serverUrl when PRM returns external AS", async () => { + // Mock PRM discovery that returns an external AS + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString === "https://my.resource.com/.well-known/oauth-protected-resource/path/name") { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://my.resource.com/", + authorization_servers: ["https://auth.example.com/oauth"], + }), + }); + } else if (urlString === "https://auth.example.com/.well-known/oauth-authorization-server/path/name") { + // Path-aware discovery on AS with path from serverUrl + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes('/token')) { + // Fourth call - token retrieval succeeds + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "test-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "test-refresh-token", + }), + }) + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + + // Call auth with serverUrl that has a path + const result = await auth(mockProvider, { + serverUrl: "https://my.resource.com/path/name", + }); + + expect(result).toBe("AUTHORIZED"); + + // Verify the correct URLs were fetched + const calls = mockFetch.mock.calls; + + // First call should be to PRM + expect(calls[0][0].toString()).toBe("https://my.resource.com/.well-known/oauth-protected-resource/path/name"); + + // Second call should be to AS metadata with the path from authorization server + expect(calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/oauth"); + }); + + it("supports overriding the fetch function used for requests", async () => { + const customFetch = jest.fn(); + + // Mock PRM discovery + customFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://resource.example.com", + authorization_servers: ["https://auth.example.com"], + }), + }); + + // Mock AS metadata discovery + customFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + + const mockProvider: OAuthClientProvider = { + get redirectUrl() { return "http://localhost:3000/callback"; }, + get clientMetadata() { + return { + client_name: "Test Client", + redirect_uris: ["http://localhost:3000/callback"], + }; + }, + clientInformation: jest.fn().mockResolvedValue({ + client_id: "client123", + client_secret: "secret123", + }), + tokens: jest.fn().mockResolvedValue(undefined), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn().mockResolvedValue("verifier123"), + }; + + const result = await auth(mockProvider, { + serverUrl: "https://resource.example.com", + fetchFn: customFetch, + }); + + expect(result).toBe("REDIRECT"); + expect(customFetch).toHaveBeenCalledTimes(2); + expect(mockFetch).not.toHaveBeenCalled(); + + // Verify custom fetch was called for PRM discovery + expect(customFetch.mock.calls[0][0].toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + + // Verify custom fetch was called for AS metadata discovery + expect(customFetch.mock.calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); + }); + }); + + describe("exchangeAuthorization with multiple client authentication methods", () => { + const validTokens = { + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const metadataWithBasicOnly = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/auth", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + token_endpoint_auth_methods_supported: ["client_secret_basic"], + }; + + const metadataWithPostOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["client_secret_post"], + }; + + const metadataWithNoneOnly = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["none"], + }; + + const metadataWithAllBuiltinMethods = { + ...metadataWithBasicOnly, + token_endpoint_auth_methods_supported: ["client_secret_basic", "client_secret_post", "none"], + }; + + it("uses HTTP Basic authentication when client_secret_basic is supported", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithBasicOnly, + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header + const authHeader = request.headers.get("Authorization"); + const expected = "Basic " + btoa("client123:secret123"); + expect(authHeader).toBe(expected); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBeNull(); + expect(body.get("client_secret")).toBeNull(); + }); + + it("includes credentials in request body when client_secret_post is supported", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithPostOnly, + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check no Authorization header + expect(request.headers.get("Authorization")).toBeNull(); + + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + }); + + it("it picks client_secret_basic when all builtin methods are supported", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + metadata: metadataWithAllBuiltinMethods, + clientInformation: validClientInfo, + authorizationCode: "code123", + redirectUri: "http://localhost:3000/callback", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + const request = mockFetch.mock.calls[0][1]; + + // Check Authorization header - should use Basic auth as it's the most secure + const authHeader = request.headers.get("Authorization"); + const expected = "Basic " + btoa("client123:secret123"); + expect(authHeader).toBe(expected); + + // Credentials should not be in body when using Basic auth + const body = request.body as URLSearchParams; + expect(body.get("client_id")).toBeNull(); expect(body.get("client_secret")).toBeNull(); }); @@ -2415,6 +3159,8 @@ describe("OAuth Authorization", () => { }); }); + describe.skip("startClientCredentialAuthorization with multiple client authentication methods", () => {}) + describe("refreshAuthorization with multiple client authentication methods", () => { const validTokens = { access_token: "newaccess123", diff --git a/src/client/auth.ts b/src/client/auth.ts index fcc320f17..9017bf4cf 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -23,6 +23,7 @@ import { } from "../server/auth/errors.js"; import { FetchLike } from "../shared/transport.js"; +type SupportedAuthorizationFlow = 'authorization_code' | 'client_credentials'; /** * Implements an end-to-end OAuth client to be used with one MCP server. * @@ -46,6 +47,11 @@ export interface OAuthClientProvider { */ state?(): string | Promise; + /** + * Returns the authorization flow to use. + */ + authFlow?(): SupportedAuthorizationFlow; + /** * Loads information about this OAuth client, as registered already with the * server, or returns `undefined` if the client is not registered with the @@ -320,6 +326,7 @@ async function authInternal( }, ): Promise { + const authFlow = provider.authFlow ? provider.authFlow() : 'authorization_code'; let resourceMetadata: OAuthProtectedResourceMetadata | undefined; let authorizationServerUrl: string | URL | undefined; try { @@ -414,6 +421,16 @@ async function authInternal( const state = provider.state ? await provider.state() : undefined; + if(authFlow === 'client_credentials') { + const newTokens = await startClientCredentialAuthorization(authorizationServerUrl, { + metadata, + clientInformation, + scope: scope || provider.clientMetadata.scope, + resource + }) + await provider.saveTokens(newTokens); + return "AUTHORIZED" + } // Start new authorization flow const { authorizationUrl, codeVerifier } = await startAuthorization(authorizationServerUrl, { metadata, @@ -877,6 +894,78 @@ export async function startAuthorization( return { authorizationUrl, codeVerifier }; } +export async function startClientCredentialAuthorization( + authorizationServerUrl: string | URL, + { + metadata, + clientInformation, + scope, + resource, + addClientAuthentication, + fetchFn, + }: { + metadata?: AuthorizationServerMetadata; + clientInformation: OAuthClientInformation; + scope?: string; + resource?: URL; + addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; + fetchFn?: FetchLike; + }, +): Promise { + const grantType = "client_credentials"; + + const tokenUrl = metadata?.token_endpoint + ? new URL(metadata.token_endpoint) + : new URL("/token", authorizationServerUrl); + + if ( + metadata?.grant_types_supported && + !metadata.grant_types_supported.includes(grantType) + ) { + throw new Error( + `Incompatible auth server: does not support grant type ${grantType}`, + ); + } + + // Exchange code for tokens + const headers = new Headers({ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }); + const params = new URLSearchParams({ + grant_type: grantType, + }); + + if (addClientAuthentication) { + addClientAuthentication(headers, params, authorizationServerUrl, metadata); + } else { + // Determine and apply client authentication method + const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? []; + const authMethod = selectClientAuthMethod(clientInformation, supportedMethods); + + applyClientAuthentication(authMethod, clientInformation, headers, params); + } + + if (resource) { + params.set("resource", resource.href); + } + + if (scope) { + params.set("scope", scope); + } + + const response = await (fetchFn ?? fetch)(tokenUrl, { + method: "POST", + headers, + body: params, + }); + + if (!response.ok) { + throw await parseErrorResponse(response); + } + + return OAuthTokensSchema.parse(await response.json()); +} /** * Exchanges an authorization code for an access token with the given server. *