diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index c3049124e..34f63ec8f 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -924,7 +924,7 @@ describe("OAuth Authorization", () => { metadata: undefined, clientInformation: validClientInfo, redirectUrl: "http://localhost:3000/callback", - resource: new URL("https://api.example.com/mcp-server"), + resource: "https://api.example.com/mcp-server", } ); @@ -1088,7 +1088,7 @@ describe("OAuth Authorization", () => { authorizationCode: "code123", codeVerifier: "verifier123", redirectUri: "http://localhost:3000/callback", - resource: new URL("https://api.example.com/mcp-server"), + resource: "https://api.example.com/mcp-server", }); expect(tokens).toEqual(validTokens); @@ -1210,7 +1210,7 @@ describe("OAuth Authorization", () => { authorizationCode: "code123", codeVerifier: "verifier123", redirectUri: "http://localhost:3000/callback", - resource: new URL("https://api.example.com/mcp-server"), + resource: "https://api.example.com/mcp-server", fetchFn: customFetch, }); @@ -1274,7 +1274,7 @@ describe("OAuth Authorization", () => { const tokens = await refreshAuthorization("https://auth.example.com", { clientInformation: validClientInfo, refreshToken: "refresh123", - resource: new URL("https://api.example.com/mcp-server"), + resource: "https://api.example.com/mcp-server", }); expect(tokens).toEqual(validTokensWithNewRefreshToken); @@ -1497,6 +1497,183 @@ describe("OAuth Authorization", () => { codeVerifier: jest.fn(), }; + describe("resource URL handling (trailing slash preservation)", () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it("preserves server URLs without trailing slash in resource parameter", async () => { + // Mock successful metadata discovery + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://api.example.com/mcp-server", // No trailing slash + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for authorization flow + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth with URL that has no trailing slash + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", // No trailing slash + }); + + expect(result).toBe("REDIRECT"); + + // Verify the authorization URL includes the resource parameter WITHOUT trailing slash + const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; + const authUrl: URL = redirectCall[0]; + expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server"); // No trailing slash + }); + + it("preserves server URLs with trailing slash in resource parameter", async () => { + // Mock successful metadata discovery + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://api.example.com/mcp-server/", // With trailing slash + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth with URL that has trailing slash + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server/", // With trailing slash + }); + + expect(result).toBe("REDIRECT"); + + // Verify the authorization URL includes the resource parameter WITH trailing slash + const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; + const authUrl: URL = redirectCall[0]; + expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server/"); // With trailing slash + }); + + it("handles token exchange with preserved resource URL format", async () => { + // Mock successful metadata discovery and token exchange + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://api.example.com/mcp-server", // No trailing slash + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token exchange + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier"); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with authorization code and URL without trailing slash + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", // No trailing slash + authorizationCode: "auth-code-123", + }); + + expect(result).toBe("AUTHORIZED"); + + // Find the token exchange call and verify resource parameter format + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); // No trailing slash added + expect(body.get("code")).toBe("auth-code-123"); + }); + }); + beforeEach(() => { jest.clearAllMocks(); }); @@ -1829,7 +2006,7 @@ describe("OAuth Authorization", () => { // Verify custom validation method was called expect(mockValidateResourceURL).toHaveBeenCalledWith( - new URL("https://api.example.com/mcp-server"), + "https://api.example.com/mcp-server", "https://different-resource.example.com/mcp-server" ); }); diff --git a/src/client/auth.ts b/src/client/auth.ts index 56826045a..ccf7a37f0 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -119,7 +119,7 @@ export interface OAuthClientProvider { * * Implementations must verify the returned resource matches the MCP server. */ - validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; + validateResourceURL?(serverUrl: string, resource?: string): Promise; /** * If implemented, provides a way for the client to invalidate (e.g. delete) the specified @@ -281,7 +281,7 @@ export async function parseErrorResponse(input: Response | string): Promise { +export async function selectResourceURL(serverUrl: string, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise { const defaultResource = resourceUrlFromServerUrl(serverUrl); // If provider has custom validation, delegate to it @@ -445,7 +445,7 @@ export async function selectResourceURL(serverUrl: string | URL, provider: OAuth throw new Error(`Protected resource ${resourceMetadata.resource} does not match expected ${defaultResource} (or origin)`); } // Prefer the resource from metadata since it's what the server is telling us to request - return new URL(resourceMetadata.resource); + return resourceMetadata.resource; } /** @@ -807,7 +807,7 @@ export async function startAuthorization( redirectUrl: string | URL; scope?: string; state?: string; - resource?: URL; + resource?: string; }, ): Promise<{ authorizationUrl: URL; codeVerifier: string }> { const responseType = "code"; @@ -865,7 +865,7 @@ export async function startAuthorization( } if (resource) { - authorizationUrl.searchParams.set("resource", resource.href); + authorizationUrl.searchParams.set("resource", resource); } return { authorizationUrl, codeVerifier }; @@ -900,7 +900,7 @@ export async function exchangeAuthorization( authorizationCode: string; codeVerifier: string; redirectUri: string | URL; - resource?: URL; + resource?: string; addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; fetchFn?: FetchLike; }, @@ -943,7 +943,7 @@ export async function exchangeAuthorization( } if (resource) { - params.set("resource", resource.href); + params.set("resource", resource); } const response = await (fetchFn ?? fetch)(tokenUrl, { @@ -984,7 +984,7 @@ export async function refreshAuthorization( metadata?: AuthorizationServerMetadata; clientInformation: OAuthClientInformation; refreshToken: string; - resource?: URL; + resource?: string; addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; fetchFn?: FetchLike; } @@ -1027,7 +1027,7 @@ export async function refreshAuthorization( } if (resource) { - params.set("resource", resource.href); + params.set("resource", resource); } const response = await (fetchFn ?? fetch)(tokenUrl, { diff --git a/src/client/sse.ts b/src/client/sse.ts index e1c86ccdb..98dc3cbe8 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -93,7 +93,7 @@ export class SSEClientTransport implements Transport { let result: AuthResult; try { - result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); + result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); } catch (error) { this.onerror?.(error as Error); throw error; @@ -218,7 +218,7 @@ export class SSEClientTransport implements Transport { throw new UnauthorizedError("No auth provider"); } - const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); + const result = await auth(this._authProvider, { serverUrl: this._url.href, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError("Failed to authorize"); } @@ -252,7 +252,7 @@ export class SSEClientTransport implements Transport { this._resourceMetadataUrl = extractResourceMetadataUrl(response); - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); + const result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError(); } diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 12714ea44..7d005231e 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -156,7 +156,7 @@ export class StreamableHTTPClientTransport implements Transport { let result: AuthResult; try { - result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); + result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); } catch (error) { this.onerror?.(error as Error); throw error; @@ -392,7 +392,7 @@ export class StreamableHTTPClientTransport implements Transport { throw new UnauthorizedError("No auth provider"); } - const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); + const result = await auth(this._authProvider, { serverUrl: this._url.href, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError("Failed to authorize"); } @@ -440,7 +440,7 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = extractResourceMetadataUrl(response); - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); + const result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError(); } diff --git a/src/examples/server/demoInMemoryOAuthProvider.ts b/src/examples/server/demoInMemoryOAuthProvider.ts index c83748d35..c6ab2e729 100644 --- a/src/examples/server/demoInMemoryOAuthProvider.ts +++ b/src/examples/server/demoInMemoryOAuthProvider.ts @@ -153,7 +153,7 @@ export const setupAuthServer = ({authServerUrl, mcpServerUrl, strictResource}: { const validateResource = strictResource ? (resource?: URL) => { if (!resource) return false; - const expectedResource = resourceUrlFromServerUrl(mcpServerUrl); + const expectedResource = resourceUrlFromServerUrl(mcpServerUrl.href); return resource.toString() === expectedResource.toString(); } : undefined; diff --git a/src/shared/auth-utils.test.ts b/src/shared/auth-utils.test.ts index c1fa7bdf1..8dc7f6ab3 100644 --- a/src/shared/auth-utils.test.ts +++ b/src/shared/auth-utils.test.ts @@ -3,28 +3,98 @@ import { resourceUrlFromServerUrl, checkResourceAllowed } from './auth-utils.js' describe('auth-utils', () => { describe('resourceUrlFromServerUrl', () => { it('should remove fragments', () => { - expect(resourceUrlFromServerUrl(new URL('https://example.com/path#fragment')).href).toBe('https://example.com/path'); - expect(resourceUrlFromServerUrl(new URL('https://example.com#fragment')).href).toBe('https://example.com/'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1#fragment')).href).toBe('https://example.com/path?query=1'); + expect(resourceUrlFromServerUrl('https://example.com/path#fragment')).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl('https://example.com#fragment')).toBe('https://example.com'); + expect(resourceUrlFromServerUrl('https://example.com/path?query=1#fragment')).toBe('https://example.com/path?query=1'); + }); + + it('should preserve URLs without trailing slash (avoiding URL.href auto-addition)', () => { + expect(resourceUrlFromServerUrl('https://example.com')).toBe('https://example.com'); + expect(resourceUrlFromServerUrl('https://example.com/api')).toBe('https://example.com/api'); + expect(resourceUrlFromServerUrl('https://example.com/api/v1')).toBe('https://example.com/api/v1'); + + // Verify that URLs with fragments but no trailing slash also preserve this behavior + expect(resourceUrlFromServerUrl('https://example.com/api#fragment')).toBe('https://example.com/api'); + expect(resourceUrlFromServerUrl('https://example.com/api/v1#fragment')).toBe('https://example.com/api/v1'); + }); + + it('should preserve URLs with trailing slash exactly as-is', () => { + // URLs that already have trailing slash should keep it + expect(resourceUrlFromServerUrl('https://example.com/')).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl('https://example.com/api/')).toBe('https://example.com/api/'); + expect(resourceUrlFromServerUrl('https://example.com/api/v1/')).toBe('https://example.com/api/v1/'); + + // With fragments + expect(resourceUrlFromServerUrl('https://example.com/api/#fragment')).toBe('https://example.com/api/'); + expect(resourceUrlFromServerUrl('https://example.com/api/v1/#fragment')).toBe('https://example.com/api/v1/'); }); it('should return URL unchanged if no fragment', () => { - expect(resourceUrlFromServerUrl(new URL('https://example.com')).href).toBe('https://example.com/'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path')).href).toBe('https://example.com/path'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1')).href).toBe('https://example.com/path?query=1'); + expect(resourceUrlFromServerUrl('https://example.com')).toBe('https://example.com'); + expect(resourceUrlFromServerUrl('https://example.com/path')).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl('https://example.com/path?query=1')).toBe('https://example.com/path?query=1'); }); it('should keep everything else unchanged', () => { - // Case sensitivity preserved - expect(resourceUrlFromServerUrl(new URL('https://EXAMPLE.COM/PATH')).href).toBe('https://example.com/PATH'); + // Case sensitivity preserved - URLs are NOT normalized, kept as-is + expect(resourceUrlFromServerUrl('https://EXAMPLE.COM/PATH')).toBe('https://EXAMPLE.COM/PATH'); // Ports preserved - expect(resourceUrlFromServerUrl(new URL('https://example.com:443/path')).href).toBe('https://example.com/path'); - expect(resourceUrlFromServerUrl(new URL('https://example.com:8080/path')).href).toBe('https://example.com:8080/path'); + expect(resourceUrlFromServerUrl('https://example.com:443/path')).toBe('https://example.com:443/path'); + expect(resourceUrlFromServerUrl('https://example.com:8080/path')).toBe('https://example.com:8080/path'); // Query parameters preserved - expect(resourceUrlFromServerUrl(new URL('https://example.com?foo=bar&baz=qux')).href).toBe('https://example.com/?foo=bar&baz=qux'); + expect(resourceUrlFromServerUrl('https://example.com?foo=bar&baz=qux')).toBe('https://example.com?foo=bar&baz=qux'); // Trailing slashes preserved - expect(resourceUrlFromServerUrl(new URL('https://example.com/')).href).toBe('https://example.com/'); - expect(resourceUrlFromServerUrl(new URL('https://example.com/path/')).href).toBe('https://example.com/path/'); + expect(resourceUrlFromServerUrl('https://example.com/')).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl('https://example.com/path/')).toBe('https://example.com/path/'); + }); + + it('should demonstrate the difference from URL.href behavior', () => { + // Demonstrate that using URL.href would incorrectly add trailing slashes + const testUrls = [ + 'https://example.com', + 'https://example.com/api', + 'https://example.com/api/v1' + ]; + + testUrls.forEach(url => { + const urlObj = new URL(url); + // URL.href would add a trailing slash for domain-only URLs + if (url === 'https://example.com') { + expect(urlObj.href).toBe('https://example.com/'); // URL.href adds trailing slash + expect(resourceUrlFromServerUrl(url)).toBe('https://example.com'); // Our implementation preserves original + } else { + expect(urlObj.href).toBe(url); // URL.href keeps path URLs as-is + expect(resourceUrlFromServerUrl(url)).toBe(url); // Our implementation also preserves original + } + }); + }); + + it('should handle edge cases correctly', () => { + // Domain with port but no path + expect(resourceUrlFromServerUrl('https://example.com:8080')).toBe('https://example.com:8080'); + + // Domain with query parameters but no path + expect(resourceUrlFromServerUrl('https://example.com?param=value')).toBe('https://example.com?param=value'); + + // Complex URL with all components + expect(resourceUrlFromServerUrl('https://user:pass@example.com:8080/path?query=value#fragment')) + .toBe('https://user:pass@example.com:8080/path?query=value'); + + // IPv6 address + expect(resourceUrlFromServerUrl('https://[::1]:8080/path#fragment')) + .toBe('https://[::1]:8080/path'); + + // Empty fragment (just #) + expect(resourceUrlFromServerUrl('https://example.com/path#')) + .toBe('https://example.com/path'); + + // Fragment with query-like content + expect(resourceUrlFromServerUrl('https://example.com/path#section?param=1')) + .toBe('https://example.com/path'); + + // Schema with uppercase (should preserve case) + expect(resourceUrlFromServerUrl('HTTPS://EXAMPLE.COM/PATH#fragment')) + .toBe('HTTPS://EXAMPLE.COM/PATH'); }); }); diff --git a/src/shared/auth-utils.ts b/src/shared/auth-utils.ts index 97a77c01d..f967296d2 100644 --- a/src/shared/auth-utils.ts +++ b/src/shared/auth-utils.ts @@ -7,10 +7,9 @@ * RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". * Keeps everything else unchanged (scheme, domain, port, path, query). */ -export function resourceUrlFromServerUrl(url: URL | string ): URL { - const resourceURL = typeof url === "string" ? new URL(url) : new URL(url.href); - resourceURL.hash = ''; // Remove fragment - return resourceURL; +export function resourceUrlFromServerUrl(url: string ): string { + const hashIndex = url.indexOf('#'); + return hashIndex >= 0 ? url.substring(0, hashIndex) : url; } /**