diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index f28163d14..909406459 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -1635,6 +1635,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1705,6 +1706,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier"); (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); @@ -1773,6 +1775,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.tokens as jest.Mock).mockResolvedValue({ access_token: "old-access", @@ -1841,6 +1844,7 @@ describe("OAuth Authorization", () => { (providerWithCustomValidation.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (providerWithCustomValidation.tokens as jest.Mock).mockResolvedValue(undefined); (providerWithCustomValidation.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1896,6 +1900,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1954,6 +1959,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -2021,6 +2027,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier"); (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); @@ -2086,6 +2093,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.tokens as jest.Mock).mockResolvedValue({ access_token: "old-access", @@ -2149,6 +2157,7 @@ describe("OAuth Authorization", () => { (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: "test-client", client_secret: "test-secret", + redirect_uris: ["http://localhost:3000/callback"], }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -2209,6 +2218,7 @@ describe("OAuth Authorization", () => { clientInformation: jest.fn().mockResolvedValue({ client_id: "client123", client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], }), tokens: jest.fn().mockResolvedValue(undefined), saveTokens: jest.fn(), diff --git a/src/client/auth.ts b/src/client/auth.ts index fcc320f17..af863de97 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -2,7 +2,6 @@ import pkceChallenge from "pkce-challenge"; import { LATEST_PROTOCOL_VERSION } from "../types.js"; import { OAuthClientMetadata, - OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, @@ -51,7 +50,7 @@ export interface OAuthClientProvider { * server, or returns `undefined` if the client is not registered with the * server. */ - clientInformation(): OAuthClientInformation | undefined | Promise; + clientInformation(): OAuthClientInformationFull | undefined | Promise; /** * If implemented, this permits the OAuth client to dynamically register with @@ -139,6 +138,10 @@ export class UnauthorizedError extends Error { type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; +function isClientAuthMethod(method: string): method is ClientAuthMethod { + return ["client_secret_basic", "client_secret_post", "none"].includes(method); +} + /** * Determines the best client authentication method to use based on server support and client configuration. * @@ -152,7 +155,7 @@ type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; * @returns The selected authentication method */ function selectClientAuthMethod( - clientInformation: OAuthClientInformation, + clientInformation: OAuthClientInformationFull, supportedMethods: string[] ): ClientAuthMethod { const hasClientSecret = clientInformation.client_secret !== undefined; @@ -162,6 +165,15 @@ function selectClientAuthMethod( return hasClientSecret ? "client_secret_post" : "none"; } + // Prefer the method returned by the server during client registration if valid and supported + if ( + clientInformation.token_endpoint_auth_method && + isClientAuthMethod(clientInformation.token_endpoint_auth_method) && + supportedMethods.includes(clientInformation.token_endpoint_auth_method) + ) { + return clientInformation.token_endpoint_auth_method; + } + // Try methods in priority order (most secure first) if (hasClientSecret && supportedMethods.includes("client_secret_basic")) { return "client_secret_basic"; @@ -195,7 +207,7 @@ function selectClientAuthMethod( */ function applyClientAuthentication( method: ClientAuthMethod, - clientInformation: OAuthClientInformation, + clientInformation: OAuthClientInformationFull, headers: Headers, params: URLSearchParams ): void { @@ -809,7 +821,7 @@ export async function startAuthorization( resource, }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationFull; redirectUrl: string | URL; scope?: string; state?: string; @@ -902,7 +914,7 @@ export async function exchangeAuthorization( fetchFn, }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationFull; authorizationCode: string; codeVerifier: string; redirectUri: string | URL; @@ -988,7 +1000,7 @@ export async function refreshAuthorization( fetchFn, }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationFull; refreshToken: string; resource?: URL; addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 4fce9976f..3d5a2565f 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -363,7 +363,7 @@ describe("SSEClientTransport", () => { mockAuthProvider = { get redirectUrl() { return "http://localhost/callback"; }, get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, - clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), + clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret", redirect_uris: ["http://localhost/callback"] })), tokens: jest.fn(), saveTokens: jest.fn(), redirectToAuthorization: jest.fn(), @@ -1140,7 +1140,8 @@ describe("SSEClientTransport", () => { const clientInfo = config.clientRegistered ? { client_id: "test-client-id", - client_secret: "test-client-secret" + client_secret: "test-client-secret", + redirect_uris: ["http://localhost/callback"], } : undefined; return { diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index af5cd8f16..de5700444 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -12,7 +12,7 @@ describe("StreamableHTTPClientTransport", () => { mockAuthProvider = { get redirectUrl() { return "http://localhost/callback"; }, get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, - clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), + clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret", redirect_uris: ["http://localhost/callback"] })), tokens: jest.fn(), saveTokens: jest.fn(), redirectToAuthorization: jest.fn(), diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index b7388384a..8fbec9c5b 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -6,7 +6,7 @@ import { URL } from 'node:url'; import { exec } from 'node:child_process'; import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; -import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; +import { OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; import { CallToolRequest, ListToolsRequest, @@ -49,7 +49,7 @@ class InMemoryOAuthClientProvider implements OAuthClientProvider { return this._clientMetadata; } - clientInformation(): OAuthClientInformation | undefined { + clientInformation(): OAuthClientInformationFull | undefined { return this._clientInformation; }