diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 4fce9976..0439518a 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -295,15 +295,34 @@ describe("SSEClientTransport", () => { expect(lastServerRequest.headers.authorization).toBe(authToken); }); - it("passes custom headers to fetch requests", async () => { - const customHeaders = { - Authorization: "Bearer test-token", - "X-Custom-Header": "custom-value", - }; - + it.each([ + { + description: "plain object headers", + headers: { + Authorization: "Bearer test-token", + "X-Custom-Header": "custom-value", + }, + }, + { + description: "Headers object", + headers: ((): HeadersInit => { + const h = new Headers(); + h.set("Authorization", "Bearer test-token"); + h.set("X-Custom-Header", "custom-value"); + return h; + })(), + }, + { + description: "array of tuples", + headers: ((): HeadersInit => ([ + ["Authorization", "Bearer test-token"], + ["X-Custom-Header", "custom-value"], + ]))(), + }, + ])("passes custom headers to fetch requests ($description)", async ({ headers }) => { transport = new SSEClientTransport(resourceBaseUrl, { requestInit: { - headers: customHeaders, + headers, }, }); @@ -337,12 +356,8 @@ describe("SSEClientTransport", () => { const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] .headers; - expect(calledHeaders.get("Authorization")).toBe( - customHeaders.Authorization, - ); - expect(calledHeaders.get("X-Custom-Header")).toBe( - customHeaders["X-Custom-Header"], - ); + expect(calledHeaders.get("Authorization")).toBe("Bearer test-token"); + expect(calledHeaders.get("X-Custom-Header")).toBe("custom-value"); expect(calledHeaders.get("content-type")).toBe("application/json"); } finally { // Restore original fetch diff --git a/src/client/sse.ts b/src/client/sse.ts index e1c86ccd..b7a3a91d 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -2,6 +2,7 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource" import { Transport, FetchLike } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; +import { normalizeHeaders } from "../shared/headers.js"; export class SseError extends Error { constructor( @@ -107,7 +108,7 @@ export class SSEClientTransport implements Transport { } private async _commonHeaders(): Promise { - const headers: HeadersInit = {}; + const headers: HeadersInit & Record = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { @@ -118,9 +119,12 @@ export class SSEClientTransport implements Transport { headers["mcp-protocol-version"] = this._protocolVersion; } - return new Headers( - { ...headers, ...this._requestInit?.headers } - ); + const extraHeaders = normalizeHeaders(this._requestInit?.headers); + + return new Headers({ + ...headers, + ...extraHeaders, + }); } private _startOrAuth(): Promise { diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 12714ea4..561cef49 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -2,6 +2,7 @@ import { Transport, FetchLike } from "../shared/transport.js"; import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { EventSourceParserStream } from "eventsource-parser/stream"; +import { normalizeHeaders } from "../shared/headers.js"; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -185,7 +186,7 @@ export class StreamableHTTPClientTransport implements Transport { headers["mcp-protocol-version"] = this._protocolVersion; } - const extraHeaders = this._normalizeHeaders(this._requestInit?.headers); + const extraHeaders = normalizeHeaders(this._requestInit?.headers); return new Headers({ ...headers, @@ -256,20 +257,6 @@ export class StreamableHTTPClientTransport implements Transport { } - private _normalizeHeaders(headers: HeadersInit | undefined): Record { - if (!headers) return {}; - - if (headers instanceof Headers) { - return Object.fromEntries(headers.entries()); - } - - if (Array.isArray(headers)) { - return Object.fromEntries(headers); - } - - return { ...headers as Record }; - } - /** * Schedule a reconnection attempt with exponential backoff * diff --git a/src/shared/headers.ts b/src/shared/headers.ts new file mode 100644 index 00000000..ed0df02d --- /dev/null +++ b/src/shared/headers.ts @@ -0,0 +1,15 @@ +export function normalizeHeaders( + headers: HeadersInit | undefined +): Record { + if (!headers) return {}; + + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } + + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } + + return { ...headers }; +}