Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ describe("OAuth Authorization", () => {
expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path?param=value");
});

it("falls back to root discovery when path-aware discovery returns 404", async () => {
// First call (path-aware) returns 404
it.each([400, 401, 403, 404, 410, 422, 429])("falls back to root discovery when path-aware discovery returns %d", async (statusCode) => {
// First call (path-aware) returns 4xx
mockFetch.mockResolvedValueOnce({
ok: false,
status: 404,
status: statusCode,
});

// Second call (root fallback) succeeds
Expand Down Expand Up @@ -267,6 +267,20 @@ describe("OAuth Authorization", () => {
expect(calls.length).toBe(2);
});

it("throws error on 500 status and does not fallback", async () => {
// First call (path-aware) returns 500
mockFetch.mockResolvedValueOnce({
ok: false,
status: 500,
});

await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name"))
.rejects.toThrow();

const calls = mockFetch.mock.calls;
expect(calls.length).toBe(1); // Should not attempt fallback
});

it("does not fallback when the original URL is already at root path", async () => {
// First call (path-aware for root) returns 404
mockFetch.mockResolvedValueOnce({
Expand Down Expand Up @@ -907,7 +921,7 @@ describe("OAuth Authorization", () => {
const metadata = await discoverAuthorizationServerMetadata("https://auth.example.com/tenant1");

expect(metadata).toBeUndefined();

// Verify that all discovery URLs were attempted
expect(mockFetch).toHaveBeenCalledTimes(8); // 4 URLs × 2 attempts each (with and without headers)
});
Expand Down
2 changes: 1 addition & 1 deletion src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ async function tryMetadataDiscovery(
* Determines if fallback to root discovery should be attempted
*/
function shouldAttemptFallback(response: Response | undefined, pathname: string): boolean {
return !response || response.status === 404 && pathname !== '/';
return !response || (response.status >= 400 && response.status < 500) && pathname !== '/';
}

/**
Expand Down
70 changes: 38 additions & 32 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ describe("StreamableHTTPClientTransport", () => {

// Verify custom fetch was used
expect(customFetch).toHaveBeenCalled();

// Global fetch should never have been called
expect(global.fetch).not.toHaveBeenCalled();
});
Expand Down Expand Up @@ -589,32 +589,32 @@ describe("StreamableHTTPClientTransport", () => {
await expect(transport.send(message)).rejects.toThrow(UnauthorizedError);
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
});

describe('Reconnection Logic', () => {
let transport: StreamableHTTPClientTransport;

// Use fake timers to control setTimeout and make the test instant.
beforeEach(() => jest.useFakeTimers());
afterEach(() => jest.useRealTimers());

it('should reconnect a GET-initiated notification stream that fails', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
}
});

const errorSpy = jest.fn();
transport.onerror = errorSpy;

const failingStream = new ReadableStream({
start(controller) { controller.error(new Error("Network failure")); }
});

const fetchMock = global.fetch as jest.Mock;
// Mock the initial GET request, which will fail.
fetchMock.mockResolvedValueOnce({
Expand All @@ -628,13 +628,13 @@ describe("StreamableHTTPClientTransport", () => {
headers: new Headers({ "content-type": "text/event-stream" }),
body: new ReadableStream(),
});

// ACT
await transport.start();
// Trigger the GET stream directly using the internal method for a clean test.
await transport["_startOrAuthSse"]({});
await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout

// ASSERT
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
Expand All @@ -644,47 +644,47 @@ describe("StreamableHTTPClientTransport", () => {
expect(fetchMock.mock.calls[0][1]?.method).toBe('GET');
expect(fetchMock.mock.calls[1][1]?.method).toBe('GET');
});

it('should NOT reconnect a POST-initiated stream that fails', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
}
});

const errorSpy = jest.fn();
transport.onerror = errorSpy;

const failingStream = new ReadableStream({
start(controller) { controller.error(new Error("Network failure")); }
});

const fetchMock = global.fetch as jest.Mock;
// Mock the POST request. It returns a streaming content-type but a failing body.
fetchMock.mockResolvedValueOnce({
ok: true, status: 200,
headers: new Headers({ "content-type": "text/event-stream" }),
body: failingStream,
});

// A dummy request message to trigger the `send` logic.
const requestMessage: JSONRPCRequest = {
jsonrpc: '2.0',
method: 'long_running_tool',
id: 'request-1',
params: {},
};

// ACT
await transport.start();
// Use the public `send` method to initiate a POST that gets a stream response.
await transport.send(requestMessage);
await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections

// ASSERT
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
Expand Down Expand Up @@ -718,7 +718,9 @@ describe("StreamableHTTPClientTransport", () => {
(global.fetch as jest.Mock)
// Initial connection
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery
// Resource discovery, path aware
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery, root
.mockResolvedValueOnce(unauthedResponse)
// OAuth metadata discovery
.mockResolvedValueOnce({
Expand Down Expand Up @@ -770,7 +772,9 @@ describe("StreamableHTTPClientTransport", () => {
(global.fetch as jest.Mock)
// Initial connection
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery
// Resource discovery, path aware
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery, root
.mockResolvedValueOnce(unauthedResponse)
// OAuth metadata discovery
.mockResolvedValueOnce({
Expand Down Expand Up @@ -822,7 +826,9 @@ describe("StreamableHTTPClientTransport", () => {
(global.fetch as jest.Mock)
// Initial connection
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery
// Resource discovery, path aware
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery, root
.mockResolvedValueOnce(unauthedResponse)
// OAuth metadata discovery
.mockResolvedValueOnce({
Expand Down Expand Up @@ -888,7 +894,7 @@ describe("StreamableHTTPClientTransport", () => {
ok: false,
status: 404
});

// Create transport instance
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
authProvider: mockAuthProvider,
Expand All @@ -901,14 +907,14 @@ describe("StreamableHTTPClientTransport", () => {

// Verify custom fetch was used
expect(customFetch).toHaveBeenCalled();

// Verify specific OAuth endpoints were called with custom fetch
const customFetchCalls = customFetch.mock.calls;
const callUrls = customFetchCalls.map(([url]) => url.toString());

// Should have called resource metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);

// Should have called OAuth authorization server metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);

Expand Down Expand Up @@ -966,19 +972,19 @@ describe("StreamableHTTPClientTransport", () => {

// Verify custom fetch was used
expect(customFetch).toHaveBeenCalled();

// Verify specific OAuth endpoints were called with custom fetch
const customFetchCalls = customFetch.mock.calls;
const callUrls = customFetchCalls.map(([url]) => url.toString());

// Should have called resource metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);

// Should have called OAuth authorization server metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);

// Should have called token endpoint for authorization code exchange
const tokenCalls = customFetchCalls.filter(([url, options]) =>
const tokenCalls = customFetchCalls.filter(([url, options]) =>
url.toString().includes('/token') && options?.method === "POST"
);
expect(tokenCalls.length).toBeGreaterThan(0);
Expand Down