Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
refreshAuthorization,
registerClient,
discoverOAuthProtectedResourceMetadata,
extractResourceMetadataUrl,
extractWWWAuthenticateParams,
auth,
type OAuthClientProvider
} from './auth.js';
Expand All @@ -24,7 +24,7 @@ describe('OAuth Authorization', () => {
mockFetch.mockReset();
});

describe('extractResourceMetadataUrl', () => {
describe('extractWWWAuthenticateParams', () => {
it('returns resource metadata url when present', async () => {
const resourceUrl = 'https://resource.example.com/.well-known/oauth-protected-resource';
const mockResponse = {
Expand All @@ -33,39 +33,56 @@ describe('OAuth Authorization', () => {
}
} as unknown as Response;

expect(extractResourceMetadataUrl(mockResponse)).toEqual(new URL(resourceUrl));
expect(extractWWWAuthenticateParams(mockResponse)).toEqual({ resourceMetadataUrl: new URL(resourceUrl) });
});

it('returns undefined if not bearer', async () => {
it('returns scope when present', async () => {
const scope = 'read';
const mockResponse = {
headers: {
get: jest.fn(name => (name === 'WWW-Authenticate' ? `Bearer realm="mcp", scope="${scope}"` : null))
}
} as unknown as Response;

expect(extractWWWAuthenticateParams(mockResponse)).toEqual({ scope: scope });
});

it('returns empty object if not bearer', async () => {
const resourceUrl = 'https://resource.example.com/.well-known/oauth-protected-resource';
const scope = 'read';
const mockResponse = {
headers: {
get: jest.fn(name => (name === 'WWW-Authenticate' ? `Basic realm="mcp", resource_metadata="${resourceUrl}"` : null))
get: jest.fn(name =>
name === 'WWW-Authenticate' ? `Basic realm="mcp", resource_metadata="${resourceUrl}", scope="${scope}"` : null
)
}
} as unknown as Response;

expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined();
expect(extractWWWAuthenticateParams(mockResponse)).toEqual({});
});

it('returns undefined if resource_metadata not present', async () => {
it('returns empty object if resource_metadata and scope not present', async () => {
const mockResponse = {
headers: {
get: jest.fn(name => (name === 'WWW-Authenticate' ? `Basic realm="mcp"` : null))
get: jest.fn(name => (name === 'WWW-Authenticate' ? `Bearer realm="mcp"` : null))
}
} as unknown as Response;

expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined();
expect(extractWWWAuthenticateParams(mockResponse)).toEqual({});
});

it('returns undefined on invalid url', async () => {
it('returns undefined resourceMetadataUrl on invalid url', async () => {
const resourceUrl = 'invalid-url';
const scope = 'read';
const mockResponse = {
headers: {
get: jest.fn(name => (name === 'WWW-Authenticate' ? `Basic realm="mcp", resource_metadata="${resourceUrl}"` : null))
get: jest.fn(name =>
name === 'WWW-Authenticate' ? `Bearer realm="mcp", resource_metadata="${resourceUrl}", scope="${scope}"` : null
)
}
} as unknown as Response;

expect(extractResourceMetadataUrl(mockResponse)).toBeUndefined();
expect(extractWWWAuthenticateParams(mockResponse)).toEqual({ scope: scope });
});
});

Expand Down
36 changes: 23 additions & 13 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -463,30 +463,40 @@ export async function selectResourceURL(
}

/**
* Extract resource_metadata from response header.
* Extract resource_metadata and scope from WWW-Authenticate header.
*/
export function extractResourceMetadataUrl(res: Response): URL | undefined {
export function extractWWWAuthenticateParams(res: Response): { resourceMetadataUrl?: URL; scope?: string } {
const authenticateHeader = res.headers.get('WWW-Authenticate');
if (!authenticateHeader) {
return undefined;
return {};
}

const [type, scheme] = authenticateHeader.split(' ');
if (type.toLowerCase() !== 'bearer' || !scheme) {
return undefined;
return {};
}
const regex = /resource_metadata="([^"]*)"/;
const match = regex.exec(authenticateHeader);

if (!match) {
return undefined;
}
const resourceMetadataRegex = /resource_metadata="([^"]*)"/;
const resourceMetadataMatch = resourceMetadataRegex.exec(authenticateHeader);

try {
return new URL(match[1]);
} catch {
return undefined;
const scopeRegex = /scope="([^"]*)"/;
const scopeMatch = scopeRegex.exec(authenticateHeader);

let resourceMetadataUrl: URL | undefined;
if (resourceMetadataMatch) {
try {
resourceMetadataUrl = new URL(resourceMetadataMatch[1]);
} catch {
// Ignore invalid URL
}
}

const scope = scopeMatch?.[1] || undefined;

return {
resourceMetadataUrl,
scope
};
}

/**
Expand Down
40 changes: 26 additions & 14 deletions src/client/middleware.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ jest.mock('../client/auth.js', () => {
return {
...actual,
auth: jest.fn(),
extractResourceMetadataUrl: jest.fn()
extractWWWAuthenticateParams: jest.fn()
};
});

import { auth, extractResourceMetadataUrl } from './auth.js';
import { auth, extractWWWAuthenticateParams } from './auth.js';

const mockAuth = auth as jest.MockedFunction<typeof auth>;
const mockExtractResourceMetadataUrl = extractResourceMetadataUrl as jest.MockedFunction<typeof extractResourceMetadataUrl>;
const mockExtractWWWAuthenticateParams = extractWWWAuthenticateParams as jest.MockedFunction<typeof extractWWWAuthenticateParams>;

describe('withOAuth', () => {
let mockProvider: jest.Mocked<OAuthClientProvider>;
Expand Down Expand Up @@ -129,8 +129,11 @@ describe('withOAuth', () => {

mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse);

const mockResourceUrl = new URL('https://oauth.example.com/.well-known/oauth-protected-resource');
mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl);
const mockWWWAuthenticateParams = {
resourceMetadataUrl: new URL('https://oauth.example.com/.well-known/oauth-protected-resource'),
scope: 'read'
};
mockExtractWWWAuthenticateParams.mockReturnValue(mockWWWAuthenticateParams);
mockAuth.mockResolvedValue('AUTHORIZED');

const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch);
Expand All @@ -141,7 +144,8 @@ describe('withOAuth', () => {
expect(mockFetch).toHaveBeenCalledTimes(2);
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://api.example.com',
resourceMetadataUrl: mockResourceUrl,
resourceMetadataUrl: mockWWWAuthenticateParams.resourceMetadataUrl,
scope: mockWWWAuthenticateParams.scope,
fetchFn: mockFetch
});

Expand Down Expand Up @@ -172,8 +176,11 @@ describe('withOAuth', () => {

mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse);

const mockResourceUrl = new URL('https://oauth.example.com/.well-known/oauth-protected-resource');
mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl);
const mockWWWAuthenticateParams = {
resourceMetadataUrl: new URL('https://oauth.example.com/.well-known/oauth-protected-resource'),
scope: 'read'
};
mockExtractWWWAuthenticateParams.mockReturnValue(mockWWWAuthenticateParams);
mockAuth.mockResolvedValue('AUTHORIZED');

// Test without baseUrl - should extract from request URL
Expand All @@ -185,7 +192,8 @@ describe('withOAuth', () => {
expect(mockFetch).toHaveBeenCalledTimes(2);
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://api.example.com', // Should be extracted from request URL
resourceMetadataUrl: mockResourceUrl,
resourceMetadataUrl: mockWWWAuthenticateParams.resourceMetadataUrl,
scope: mockWWWAuthenticateParams.scope,
fetchFn: mockFetch
});

Expand All @@ -203,7 +211,7 @@ describe('withOAuth', () => {
});

mockFetch.mockResolvedValue(new Response('Unauthorized', { status: 401 }));
mockExtractResourceMetadataUrl.mockReturnValue(undefined);
mockExtractWWWAuthenticateParams.mockReturnValue({});
mockAuth.mockResolvedValue('REDIRECT');

// Test without baseUrl
Expand All @@ -222,7 +230,7 @@ describe('withOAuth', () => {
});

mockFetch.mockResolvedValue(new Response('Unauthorized', { status: 401 }));
mockExtractResourceMetadataUrl.mockReturnValue(undefined);
mockExtractWWWAuthenticateParams.mockReturnValue({});
mockAuth.mockRejectedValue(new Error('Network error'));

const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch);
Expand All @@ -239,7 +247,7 @@ describe('withOAuth', () => {

// Always return 401
mockFetch.mockResolvedValue(new Response('Unauthorized', { status: 401 }));
mockExtractResourceMetadataUrl.mockReturnValue(undefined);
mockExtractWWWAuthenticateParams.mockReturnValue({});
mockAuth.mockResolvedValue('AUTHORIZED');

const enhancedFetch = withOAuth(mockProvider, 'https://api.example.com')(mockFetch);
Expand Down Expand Up @@ -345,7 +353,7 @@ describe('withOAuth', () => {

mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse);

mockExtractResourceMetadataUrl.mockReturnValue(undefined);
mockExtractWWWAuthenticateParams.mockReturnValue({});
mockAuth.mockResolvedValue('AUTHORIZED');

const enhancedFetch = withOAuth(mockProvider)(mockFetch);
Expand Down Expand Up @@ -876,7 +884,10 @@ describe('Integration Tests', () => {

mockFetch.mockResolvedValueOnce(unauthorizedResponse).mockResolvedValueOnce(successResponse);

mockExtractResourceMetadataUrl.mockReturnValue(new URL('https://auth.example.com/.well-known/oauth-protected-resource'));
mockExtractWWWAuthenticateParams.mockReturnValue({
resourceMetadataUrl: new URL('https://auth.example.com/.well-known/oauth-protected-resource'),
scope: 'read'
});
mockAuth.mockResolvedValue('AUTHORIZED');

// Use custom logger to avoid console output
Expand All @@ -896,6 +907,7 @@ describe('Integration Tests', () => {
expect(mockAuth).toHaveBeenCalledWith(mockProvider, {
serverUrl: 'https://mcp-server.example.com',
resourceMetadataUrl: new URL('https://auth.example.com/.well-known/oauth-protected-resource'),
scope: 'read',
fetchFn: mockFetch
});
});
Expand Down
5 changes: 3 additions & 2 deletions src/client/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { auth, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from './auth.js';
import { auth, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';
import { FetchLike } from '../shared/transport.js';

/**
Expand Down Expand Up @@ -54,14 +54,15 @@ export const withOAuth =
// Handle 401 responses by attempting re-authentication
if (response.status === 401) {
try {
const resourceMetadataUrl = extractResourceMetadataUrl(response);
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);

// Use provided baseUrl or extract from request URL
const serverUrl = baseUrl || (typeof input === 'string' ? new URL(input).origin : input.origin);

const result = await auth(provider, {
serverUrl,
resourceMetadataUrl,
scope,
fetchFn: next
});

Expand Down
15 changes: 12 additions & 3 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { EventSource, type ErrorEvent, type EventSourceInit } from 'eventsource';
import { Transport, FetchLike } from '../shared/transport.js';
import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from './auth.js';
import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';

export class SseError extends Error {
constructor(
Expand Down Expand Up @@ -64,6 +64,7 @@ export class SSEClientTransport implements Transport {
private _abortController?: AbortController;
private _url: URL;
private _resourceMetadataUrl?: URL;
private _scope?: string;
private _eventSourceInit?: EventSourceInit;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
Expand All @@ -77,6 +78,7 @@ export class SSEClientTransport implements Transport {
constructor(url: URL, opts?: SSEClientTransportOptions) {
this._url = url;
this._resourceMetadataUrl = undefined;
this._scope = undefined;
this._eventSourceInit = opts?.eventSourceInit;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
Expand All @@ -93,6 +95,7 @@ export class SSEClientTransport implements Transport {
result = await auth(this._authProvider, {
serverUrl: this._url,
resourceMetadataUrl: this._resourceMetadataUrl,
scope: this._scope,
fetchFn: this._fetch
});
} catch (error) {
Expand Down Expand Up @@ -136,7 +139,9 @@ export class SSEClientTransport implements Transport {
});

if (response.status === 401 && response.headers.has('www-authenticate')) {
this._resourceMetadataUrl = extractResourceMetadataUrl(response);
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);
this._resourceMetadataUrl = resourceMetadataUrl;
this._scope = scope;
}

return response;
Expand Down Expand Up @@ -213,6 +218,7 @@ export class SSEClientTransport implements Transport {
serverUrl: this._url,
authorizationCode,
resourceMetadataUrl: this._resourceMetadataUrl,
scope: this._scope,
fetchFn: this._fetch
});
if (result !== 'AUTHORIZED') {
Expand Down Expand Up @@ -245,11 +251,14 @@ export class SSEClientTransport implements Transport {
const response = await (this._fetch ?? fetch)(this._endpoint, init);
if (!response.ok) {
if (response.status === 401 && this._authProvider) {
this._resourceMetadataUrl = extractResourceMetadataUrl(response);
const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response);
this._resourceMetadataUrl = resourceMetadataUrl;
this._scope = scope;

const result = await auth(this._authProvider, {
serverUrl: this._url,
resourceMetadataUrl: this._resourceMetadataUrl,
scope: this._scope,
fetchFn: this._fetch
});
if (result !== 'AUTHORIZED') {
Expand Down
Loading
Loading