Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ describe("OAuth Authorization", () => {
expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path?param=value");
});

it("falls back to root discovery when path-aware discovery returns 404", async () => {
// First call (path-aware) returns 404
it("falls back to root discovery when path-aware discovery fails", async () => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could make this a parameterised test (with it.each) to check across many 4xx statuses?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also be good to add "falls back when path-aware discovery encounters internal server error" (to check the 500s are left out of the logic)

// First call (path-aware) returns 4xx
mockFetch.mockResolvedValueOnce({
ok: false,
status: 404,
status: 401,
});

// Second call (root fallback) succeeds
Expand Down Expand Up @@ -907,7 +907,7 @@ describe("OAuth Authorization", () => {
const metadata = await discoverAuthorizationServerMetadata("https://auth.example.com/tenant1");

expect(metadata).toBeUndefined();

// Verify that all discovery URLs were attempted
expect(mockFetch).toHaveBeenCalledTimes(8); // 4 URLs × 2 attempts each (with and without headers)
});
Expand Down
2 changes: 1 addition & 1 deletion src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ async function tryMetadataDiscovery(
* Determines if fallback to root discovery should be attempted
*/
function shouldAttemptFallback(response: Response | undefined, pathname: string): boolean {
return !response || response.status === 404 && pathname !== '/';
return !response || (response.status >= 400 && response.status < 500 ) && pathname !== '/';
}

/**
Expand Down
70 changes: 38 additions & 32 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ describe("StreamableHTTPClientTransport", () => {

// Verify custom fetch was used
expect(customFetch).toHaveBeenCalled();

// Global fetch should never have been called
expect(global.fetch).not.toHaveBeenCalled();
});
Expand Down Expand Up @@ -589,32 +589,32 @@ describe("StreamableHTTPClientTransport", () => {
await expect(transport.send(message)).rejects.toThrow(UnauthorizedError);
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
});

describe('Reconnection Logic', () => {
let transport: StreamableHTTPClientTransport;

// Use fake timers to control setTimeout and make the test instant.
beforeEach(() => jest.useFakeTimers());
afterEach(() => jest.useRealTimers());

it('should reconnect a GET-initiated notification stream that fails', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
}
});

const errorSpy = jest.fn();
transport.onerror = errorSpy;

const failingStream = new ReadableStream({
start(controller) { controller.error(new Error("Network failure")); }
});

const fetchMock = global.fetch as jest.Mock;
// Mock the initial GET request, which will fail.
fetchMock.mockResolvedValueOnce({
Expand All @@ -628,13 +628,13 @@ describe("StreamableHTTPClientTransport", () => {
headers: new Headers({ "content-type": "text/event-stream" }),
body: new ReadableStream(),
});

// ACT
await transport.start();
// Trigger the GET stream directly using the internal method for a clean test.
await transport["_startOrAuthSse"]({});
await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout

// ASSERT
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
Expand All @@ -644,47 +644,47 @@ describe("StreamableHTTPClientTransport", () => {
expect(fetchMock.mock.calls[0][1]?.method).toBe('GET');
expect(fetchMock.mock.calls[1][1]?.method).toBe('GET');
});

it('should NOT reconnect a POST-initiated stream that fails', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely
reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity
}
});

const errorSpy = jest.fn();
transport.onerror = errorSpy;

const failingStream = new ReadableStream({
start(controller) { controller.error(new Error("Network failure")); }
});

const fetchMock = global.fetch as jest.Mock;
// Mock the POST request. It returns a streaming content-type but a failing body.
fetchMock.mockResolvedValueOnce({
ok: true, status: 200,
headers: new Headers({ "content-type": "text/event-stream" }),
body: failingStream,
});

// A dummy request message to trigger the `send` logic.
const requestMessage: JSONRPCRequest = {
jsonrpc: '2.0',
method: 'long_running_tool',
id: 'request-1',
params: {},
};

// ACT
await transport.start();
// Use the public `send` method to initiate a POST that gets a stream response.
await transport.send(requestMessage);
await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections

// ASSERT
expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({
message: expect.stringContaining('SSE stream disconnected: Error: Network failure'),
Expand Down Expand Up @@ -718,7 +718,9 @@ describe("StreamableHTTPClientTransport", () => {
(global.fetch as jest.Mock)
// Initial connection
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery
// Resource discovery, path aware
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery, root
.mockResolvedValueOnce(unauthedResponse)
// OAuth metadata discovery
.mockResolvedValueOnce({
Expand Down Expand Up @@ -770,7 +772,9 @@ describe("StreamableHTTPClientTransport", () => {
(global.fetch as jest.Mock)
// Initial connection
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery
// Resource discovery, path aware
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery, root
.mockResolvedValueOnce(unauthedResponse)
// OAuth metadata discovery
.mockResolvedValueOnce({
Expand Down Expand Up @@ -822,7 +826,9 @@ describe("StreamableHTTPClientTransport", () => {
(global.fetch as jest.Mock)
// Initial connection
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery
// Resource discovery, path aware
.mockResolvedValueOnce(unauthedResponse)
// Resource discovery, root
.mockResolvedValueOnce(unauthedResponse)
// OAuth metadata discovery
.mockResolvedValueOnce({
Expand Down Expand Up @@ -888,7 +894,7 @@ describe("StreamableHTTPClientTransport", () => {
ok: false,
status: 404
});

// Create transport instance
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
authProvider: mockAuthProvider,
Expand All @@ -901,14 +907,14 @@ describe("StreamableHTTPClientTransport", () => {

// Verify custom fetch was used
expect(customFetch).toHaveBeenCalled();

// Verify specific OAuth endpoints were called with custom fetch
const customFetchCalls = customFetch.mock.calls;
const callUrls = customFetchCalls.map(([url]) => url.toString());

// Should have called resource metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);

// Should have called OAuth authorization server metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);

Expand Down Expand Up @@ -966,19 +972,19 @@ describe("StreamableHTTPClientTransport", () => {

// Verify custom fetch was used
expect(customFetch).toHaveBeenCalled();

// Verify specific OAuth endpoints were called with custom fetch
const customFetchCalls = customFetch.mock.calls;
const callUrls = customFetchCalls.map(([url]) => url.toString());

// Should have called resource metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true);

// Should have called OAuth authorization server metadata discovery
expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true);

// Should have called token endpoint for authorization code exchange
const tokenCalls = customFetchCalls.filter(([url, options]) =>
const tokenCalls = customFetchCalls.filter(([url, options]) =>
url.toString().includes('/token') && options?.method === "POST"
);
expect(tokenCalls.length).toBeGreaterThan(0);
Expand Down