Skip to content

fix: preserve canonical URL format in OAuth resource parameter per MCP auth spec #829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 182 additions & 5 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ describe("OAuth Authorization", () => {
metadata: undefined,
clientInformation: validClientInfo,
redirectUrl: "http://localhost:3000/callback",
resource: new URL("https://api.example.com/mcp-server"),
resource: "https://api.example.com/mcp-server",
}
);

Expand Down Expand Up @@ -1088,7 +1088,7 @@ describe("OAuth Authorization", () => {
authorizationCode: "code123",
codeVerifier: "verifier123",
redirectUri: "http://localhost:3000/callback",
resource: new URL("https://api.example.com/mcp-server"),
resource: "https://api.example.com/mcp-server",
});

expect(tokens).toEqual(validTokens);
Expand Down Expand Up @@ -1210,7 +1210,7 @@ describe("OAuth Authorization", () => {
authorizationCode: "code123",
codeVerifier: "verifier123",
redirectUri: "http://localhost:3000/callback",
resource: new URL("https://api.example.com/mcp-server"),
resource: "https://api.example.com/mcp-server",
fetchFn: customFetch,
});

Expand Down Expand Up @@ -1274,7 +1274,7 @@ describe("OAuth Authorization", () => {
const tokens = await refreshAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
refreshToken: "refresh123",
resource: new URL("https://api.example.com/mcp-server"),
resource: "https://api.example.com/mcp-server",
});

expect(tokens).toEqual(validTokensWithNewRefreshToken);
Expand Down Expand Up @@ -1497,6 +1497,183 @@ describe("OAuth Authorization", () => {
codeVerifier: jest.fn(),
};

describe("resource URL handling (trailing slash preservation)", () => {
beforeEach(() => {
jest.clearAllMocks();
});

it("preserves server URLs without trailing slash in resource parameter", async () => {
// Mock successful metadata discovery
mockFetch.mockImplementation((url) => {
const urlString = url.toString();
if (urlString.includes("/.well-known/oauth-protected-resource")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
resource: "https://api.example.com/mcp-server", // No trailing slash
authorization_servers: ["https://auth.example.com"],
}),
});
} else if (urlString.includes("/.well-known/oauth-authorization-server")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
response_types_supported: ["code"],
code_challenge_methods_supported: ["S256"],
}),
});
}
return Promise.resolve({ ok: false, status: 404 });
});

// Mock provider methods for authorization flow
(mockProvider.clientInformation as jest.Mock).mockResolvedValue({
client_id: "test-client",
client_secret: "test-secret",
});
(mockProvider.tokens as jest.Mock).mockResolvedValue(undefined);
(mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined);
(mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined);

// Call auth with URL that has no trailing slash
const result = await auth(mockProvider, {
serverUrl: "https://api.example.com/mcp-server", // No trailing slash
});

expect(result).toBe("REDIRECT");

// Verify the authorization URL includes the resource parameter WITHOUT trailing slash
const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0];
const authUrl: URL = redirectCall[0];
expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server"); // No trailing slash
});

it("preserves server URLs with trailing slash in resource parameter", async () => {
// Mock successful metadata discovery
mockFetch.mockImplementation((url) => {
const urlString = url.toString();
if (urlString.includes("/.well-known/oauth-protected-resource")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
resource: "https://api.example.com/mcp-server/", // With trailing slash
authorization_servers: ["https://auth.example.com"],
}),
});
} else if (urlString.includes("/.well-known/oauth-authorization-server")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
response_types_supported: ["code"],
code_challenge_methods_supported: ["S256"],
}),
});
}
return Promise.resolve({ ok: false, status: 404 });
});

// Mock provider methods
(mockProvider.clientInformation as jest.Mock).mockResolvedValue({
client_id: "test-client",
client_secret: "test-secret",
});
(mockProvider.tokens as jest.Mock).mockResolvedValue(undefined);
(mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined);
(mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined);

// Call auth with URL that has trailing slash
const result = await auth(mockProvider, {
serverUrl: "https://api.example.com/mcp-server/", // With trailing slash
});

expect(result).toBe("REDIRECT");

// Verify the authorization URL includes the resource parameter WITH trailing slash
const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0];
const authUrl: URL = redirectCall[0];
expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server/"); // With trailing slash
});

it("handles token exchange with preserved resource URL format", async () => {
// Mock successful metadata discovery and token exchange
mockFetch.mockImplementation((url) => {
const urlString = url.toString();

if (urlString.includes("/.well-known/oauth-protected-resource")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
resource: "https://api.example.com/mcp-server", // No trailing slash
authorization_servers: ["https://auth.example.com"],
}),
});
} else if (urlString.includes("/.well-known/oauth-authorization-server")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
response_types_supported: ["code"],
code_challenge_methods_supported: ["S256"],
}),
});
} else if (urlString.includes("/token")) {
return Promise.resolve({
ok: true,
status: 200,
json: async () => ({
access_token: "access123",
token_type: "Bearer",
expires_in: 3600,
refresh_token: "refresh123",
}),
});
}

return Promise.resolve({ ok: false, status: 404 });
});

// Mock provider methods for token exchange
(mockProvider.clientInformation as jest.Mock).mockResolvedValue({
client_id: "test-client",
client_secret: "test-secret",
});
(mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier");
(mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined);

// Call auth with authorization code and URL without trailing slash
const result = await auth(mockProvider, {
serverUrl: "https://api.example.com/mcp-server", // No trailing slash
authorizationCode: "auth-code-123",
});

expect(result).toBe("AUTHORIZED");

// Find the token exchange call and verify resource parameter format
const tokenCall = mockFetch.mock.calls.find(call =>
call[0].toString().includes("/token")
);
expect(tokenCall).toBeDefined();

const body = tokenCall![1].body as URLSearchParams;
expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); // No trailing slash added
expect(body.get("code")).toBe("auth-code-123");
});
});

beforeEach(() => {
jest.clearAllMocks();
});
Expand Down Expand Up @@ -1829,7 +2006,7 @@ describe("OAuth Authorization", () => {

// Verify custom validation method was called
expect(mockValidateResourceURL).toHaveBeenCalledWith(
new URL("https://api.example.com/mcp-server"),
"https://api.example.com/mcp-server",
"https://different-resource.example.com/mcp-server"
);
});
Expand Down
24 changes: 12 additions & 12 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export interface OAuthClientProvider {
*
* Implementations must verify the returned resource matches the MCP server.
*/
validateResourceURL?(serverUrl: string | URL, resource?: string): Promise<URL | undefined>;
validateResourceURL?(serverUrl: string, resource?: string): Promise<string | undefined>;

/**
* If implemented, provides a way for the client to invalidate (e.g. delete) the specified
Expand Down Expand Up @@ -281,7 +281,7 @@ export async function parseErrorResponse(input: Response | string): Promise<OAut
export async function auth(
provider: OAuthClientProvider,
options: {
serverUrl: string | URL;
serverUrl: string;
authorizationCode?: string;
scope?: string;
resourceMetadataUrl?: URL;
Expand Down Expand Up @@ -312,7 +312,7 @@ async function authInternal(
resourceMetadataUrl,
fetchFn,
}: {
serverUrl: string | URL;
serverUrl: string;
authorizationCode?: string;
scope?: string;
resourceMetadataUrl?: URL;
Expand All @@ -339,7 +339,7 @@ async function authInternal(
authorizationServerUrl = serverUrl;
}

const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata);
const resource: string | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata);

const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, {
fetchFn,
Expand Down Expand Up @@ -427,7 +427,7 @@ async function authInternal(
return "REDIRECT"
}

export async function selectResourceURL(serverUrl: string | URL, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise<URL | undefined> {
export async function selectResourceURL(serverUrl: string, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise<string | undefined> {
const defaultResource = resourceUrlFromServerUrl(serverUrl);

// If provider has custom validation, delegate to it
Expand All @@ -445,7 +445,7 @@ export async function selectResourceURL(serverUrl: string | URL, provider: OAuth
throw new Error(`Protected resource ${resourceMetadata.resource} does not match expected ${defaultResource} (or origin)`);
}
// Prefer the resource from metadata since it's what the server is telling us to request
return new URL(resourceMetadata.resource);
return resourceMetadata.resource;
}

/**
Expand Down Expand Up @@ -807,7 +807,7 @@ export async function startAuthorization(
redirectUrl: string | URL;
scope?: string;
state?: string;
resource?: URL;
resource?: string;
},
): Promise<{ authorizationUrl: URL; codeVerifier: string }> {
const responseType = "code";
Expand Down Expand Up @@ -865,7 +865,7 @@ export async function startAuthorization(
}

if (resource) {
authorizationUrl.searchParams.set("resource", resource.href);
authorizationUrl.searchParams.set("resource", resource);
}

return { authorizationUrl, codeVerifier };
Expand Down Expand Up @@ -900,7 +900,7 @@ export async function exchangeAuthorization(
authorizationCode: string;
codeVerifier: string;
redirectUri: string | URL;
resource?: URL;
resource?: string;
addClientAuthentication?: OAuthClientProvider["addClientAuthentication"];
fetchFn?: FetchLike;
},
Expand Down Expand Up @@ -943,7 +943,7 @@ export async function exchangeAuthorization(
}

if (resource) {
params.set("resource", resource.href);
params.set("resource", resource);
}

const response = await (fetchFn ?? fetch)(tokenUrl, {
Expand Down Expand Up @@ -984,7 +984,7 @@ export async function refreshAuthorization(
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformation;
refreshToken: string;
resource?: URL;
resource?: string;
addClientAuthentication?: OAuthClientProvider["addClientAuthentication"];
fetchFn?: FetchLike;
}
Expand Down Expand Up @@ -1027,7 +1027,7 @@ export async function refreshAuthorization(
}

if (resource) {
params.set("resource", resource.href);
params.set("resource", resource);
}

const response = await (fetchFn ?? fetch)(tokenUrl, {
Expand Down
6 changes: 3 additions & 3 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export class SSEClientTransport implements Transport {

let result: AuthResult;
try {
result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
} catch (error) {
this.onerror?.(error as Error);
throw error;
Expand Down Expand Up @@ -218,7 +218,7 @@ export class SSEClientTransport implements Transport {
throw new UnauthorizedError("No auth provider");
}

const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
const result = await auth(this._authProvider, { serverUrl: this._url.href, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError("Failed to authorize");
}
Expand Down Expand Up @@ -252,7 +252,7 @@ export class SSEClientTransport implements Transport {

this._resourceMetadataUrl = extractResourceMetadataUrl(response);

const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
const result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError();
}
Expand Down
6 changes: 3 additions & 3 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ export class StreamableHTTPClientTransport implements Transport {

let result: AuthResult;
try {
result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
} catch (error) {
this.onerror?.(error as Error);
throw error;
Expand Down Expand Up @@ -392,7 +392,7 @@ export class StreamableHTTPClientTransport implements Transport {
throw new UnauthorizedError("No auth provider");
}

const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
const result = await auth(this._authProvider, { serverUrl: this._url.href, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError("Failed to authorize");
}
Expand Down Expand Up @@ -440,7 +440,7 @@ export class StreamableHTTPClientTransport implements Transport {

this._resourceMetadataUrl = extractResourceMetadataUrl(response);

const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
const result = await auth(this._authProvider, { serverUrl: this._url.href, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError();
}
Expand Down
2 changes: 1 addition & 1 deletion src/examples/server/demoInMemoryOAuthProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ export const setupAuthServer = ({authServerUrl, mcpServerUrl, strictResource}: {

const validateResource = strictResource ? (resource?: URL) => {
if (!resource) return false;
const expectedResource = resourceUrlFromServerUrl(mcpServerUrl);
const expectedResource = resourceUrlFromServerUrl(mcpServerUrl.href);
return resource.toString() === expectedResource.toString();
} : undefined;

Expand Down
Loading
Loading