Skip to content

Commit e5167cb

Browse files
committed
Update streamableHttp custom fetch test
1 parent 1b14bd7 commit e5167cb

File tree

4 files changed

+82
-5
lines changed

4 files changed

+82
-5
lines changed

src/client/sse.test.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,36 @@ describe("SSEClientTransport", () => {
262262
expect(lastServerRequest.headers.authorization).toBe(authToken);
263263
});
264264

265+
it("uses custom fetch implementation from options", async () => {
266+
const authToken = "Bearer custom-token";
267+
268+
const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
269+
const headers = new Headers(init?.headers);
270+
headers.set("Authorization", authToken);
271+
return fetch(url.toString(), { ...init, headers });
272+
});
273+
274+
transport = new SSEClientTransport(resourceBaseUrl, {
275+
fetch: fetchWithAuth,
276+
});
277+
278+
await transport.start();
279+
280+
expect(lastServerRequest.headers.authorization).toBe(authToken);
281+
282+
// Send a message to verify fetchWithAuth used for POST as well
283+
const message: JSONRPCMessage = {
284+
jsonrpc: "2.0",
285+
id: "1",
286+
method: "test",
287+
params: {},
288+
};
289+
290+
await transport.send(message);
291+
292+
expect(fetchWithAuth).toHaveBeenCalled();
293+
});
294+
265295
it("passes custom headers to fetch requests", async () => {
266296
const customHeaders = {
267297
Authorization: "Bearer test-token",

src/client/sse.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import { Transport } from "../shared/transport.js";
33
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
44
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
55

6+
export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
7+
68
export class SseError extends Error {
79
constructor(
810
public readonly code: number | undefined,
@@ -47,6 +49,11 @@ export type SSEClientTransportOptions = {
4749
* Customizes recurring POST requests to the server.
4850
*/
4951
requestInit?: RequestInit;
52+
53+
/**
54+
* Custom fetch implementation used for all network requests.
55+
*/
56+
fetch?: FetchLike;
5057
};
5158

5259
/**
@@ -62,6 +69,7 @@ export class SSEClientTransport implements Transport {
6269
private _eventSourceInit?: EventSourceInit;
6370
private _requestInit?: RequestInit;
6471
private _authProvider?: OAuthClientProvider;
72+
private _fetch?: FetchLike;
6573
private _protocolVersion?: string;
6674

6775
onclose?: () => void;
@@ -77,6 +85,7 @@ export class SSEClientTransport implements Transport {
7785
this._eventSourceInit = opts?.eventSourceInit;
7886
this._requestInit = opts?.requestInit;
7987
this._authProvider = opts?.authProvider;
88+
this._fetch = opts?.fetch;
8089
}
8190

8291
private async _authThenStart(): Promise<void> {
@@ -117,7 +126,7 @@ export class SSEClientTransport implements Transport {
117126
}
118127

119128
private _startOrAuth(): Promise<void> {
120-
const fetchImpl = (this?._eventSourceInit?.fetch || fetch) as typeof fetch
129+
const fetchImpl = (this?._eventSourceInit?.fetch || this._fetch || fetch) as typeof fetch
121130
return new Promise((resolve, reject) => {
122131
this._eventSource = new EventSource(
123132
this._url.href,
@@ -242,7 +251,7 @@ export class SSEClientTransport implements Transport {
242251
signal: this._abortController?.signal,
243252
};
244253

245-
const response = await fetch(this._endpoint, init);
254+
const response = await (this._fetch ?? fetch)(this._endpoint, init);
246255
if (!response.ok) {
247256
if (response.status === 401 && this._authProvider) {
248257

src/client/streamableHttp.test.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,35 @@ describe("StreamableHTTPClientTransport", () => {
443443
expect(errorSpy).toHaveBeenCalled();
444444
});
445445

446+
it("uses custom fetch implementation", async () => {
447+
const authToken = "Bearer custom-token";
448+
449+
const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
450+
const headers = new Headers(init?.headers);
451+
headers.set("Authorization", authToken);
452+
return (global.fetch as jest.Mock)(url, { ...init, headers });
453+
});
454+
455+
(global.fetch as jest.Mock)
456+
.mockResolvedValueOnce(
457+
new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } })
458+
)
459+
.mockResolvedValueOnce(new Response(null, { status: 202 }));
460+
461+
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { fetch: fetchWithAuth });
462+
463+
await transport.start();
464+
await (transport as unknown as { _startOrAuthSse: (opts: any) => Promise<void> })._startOrAuthSse({});
465+
466+
await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage);
467+
468+
expect(fetchWithAuth).toHaveBeenCalled();
469+
for (const call of (global.fetch as jest.Mock).mock.calls) {
470+
const headers = call[1].headers as Headers;
471+
expect(headers.get("Authorization")).toBe(authToken);
472+
}
473+
});
474+
446475

447476
it("should always send specified custom headers", async () => {
448477
const requestInit = {

src/client/streamableHttp.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPC
33
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
44
import { EventSourceParserStream } from "eventsource-parser/stream";
55

6+
export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
7+
68
// Default reconnection options for StreamableHTTP connections
79
const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = {
810
initialReconnectionDelay: 1000,
@@ -99,6 +101,11 @@ export type StreamableHTTPClientTransportOptions = {
99101
*/
100102
requestInit?: RequestInit;
101103

104+
/**
105+
* Custom fetch implementation used for all network requests.
106+
*/
107+
fetch?: FetchLike;
108+
102109
/**
103110
* Options to configure the reconnection behavior.
104111
*/
@@ -122,6 +129,7 @@ export class StreamableHTTPClientTransport implements Transport {
122129
private _resourceMetadataUrl?: URL;
123130
private _requestInit?: RequestInit;
124131
private _authProvider?: OAuthClientProvider;
132+
private _fetch?: FetchLike;
125133
private _sessionId?: string;
126134
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
127135
private _protocolVersion?: string;
@@ -138,6 +146,7 @@ export class StreamableHTTPClientTransport implements Transport {
138146
this._resourceMetadataUrl = undefined;
139147
this._requestInit = opts?.requestInit;
140148
this._authProvider = opts?.authProvider;
149+
this._fetch = opts?.fetch;
141150
this._sessionId = opts?.sessionId;
142151
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
143152
}
@@ -200,7 +209,7 @@ export class StreamableHTTPClientTransport implements Transport {
200209
headers.set("last-event-id", resumptionToken);
201210
}
202211

203-
const response = await fetch(this._url, {
212+
const response = await (this._fetch ?? fetch)(this._url, {
204213
method: "GET",
205214
headers,
206215
signal: this._abortController?.signal,
@@ -414,7 +423,7 @@ export class StreamableHTTPClientTransport implements Transport {
414423
signal: this._abortController?.signal,
415424
};
416425

417-
const response = await fetch(this._url, init);
426+
const response = await (this._fetch ?? fetch)(this._url, init);
418427

419428
// Handle session ID received during initialization
420429
const sessionId = response.headers.get("mcp-session-id");
@@ -520,7 +529,7 @@ export class StreamableHTTPClientTransport implements Transport {
520529
signal: this._abortController?.signal,
521530
};
522531

523-
const response = await fetch(this._url, init);
532+
const response = await (this._fetch ?? fetch)(this._url, init);
524533

525534
// We specifically handle 405 as a valid response according to the spec,
526535
// meaning the server does not support explicit session termination

0 commit comments

Comments
 (0)