Skip to content

Commit 5d58116

Browse files
committed
fix: normalize headers in sse transport
1 parent a1608a6 commit 5d58116

File tree

4 files changed

+53
-32
lines changed

4 files changed

+53
-32
lines changed

src/client/sse.test.ts

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -295,15 +295,34 @@ describe("SSEClientTransport", () => {
295295
expect(lastServerRequest.headers.authorization).toBe(authToken);
296296
});
297297

298-
it("passes custom headers to fetch requests", async () => {
299-
const customHeaders = {
300-
Authorization: "Bearer test-token",
301-
"X-Custom-Header": "custom-value",
302-
};
303-
298+
it.each([
299+
{
300+
description: "plain object headers",
301+
headers: {
302+
Authorization: "Bearer test-token",
303+
"X-Custom-Header": "custom-value",
304+
},
305+
},
306+
{
307+
description: "Headers object",
308+
headers: ((): HeadersInit => {
309+
const h = new Headers();
310+
h.set("Authorization", "Bearer test-token");
311+
h.set("X-Custom-Header", "custom-value");
312+
return h;
313+
})(),
314+
},
315+
{
316+
description: "array of tuples",
317+
headers: ((): HeadersInit => ([
318+
["Authorization", "Bearer test-token"],
319+
["X-Custom-Header", "custom-value"],
320+
]))(),
321+
},
322+
])("passes custom headers to fetch requests ($description)", async ({ headers }) => {
304323
transport = new SSEClientTransport(resourceBaseUrl, {
305324
requestInit: {
306-
headers: customHeaders,
325+
headers,
307326
},
308327
});
309328

@@ -337,12 +356,8 @@ describe("SSEClientTransport", () => {
337356

338357
const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1]
339358
.headers;
340-
expect(calledHeaders.get("Authorization")).toBe(
341-
customHeaders.Authorization,
342-
);
343-
expect(calledHeaders.get("X-Custom-Header")).toBe(
344-
customHeaders["X-Custom-Header"],
345-
);
359+
expect(calledHeaders.get("Authorization")).toBe("Bearer test-token");
360+
expect(calledHeaders.get("X-Custom-Header")).toBe("custom-value");
346361
expect(calledHeaders.get("content-type")).toBe("application/json");
347362
} finally {
348363
// Restore original fetch

src/client/sse.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"
22
import { Transport, FetchLike } from "../shared/transport.js";
33
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
44
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
5+
import { normalizeHeaders } from "../shared/headers.js";
56

67
export class SseError extends Error {
78
constructor(
@@ -107,7 +108,7 @@ export class SSEClientTransport implements Transport {
107108
}
108109

109110
private async _commonHeaders(): Promise<Headers> {
110-
const headers: HeadersInit = {};
111+
const headers: HeadersInit & Record<string, string> = {};
111112
if (this._authProvider) {
112113
const tokens = await this._authProvider.tokens();
113114
if (tokens) {
@@ -118,9 +119,12 @@ export class SSEClientTransport implements Transport {
118119
headers["mcp-protocol-version"] = this._protocolVersion;
119120
}
120121

121-
return new Headers(
122-
{ ...headers, ...this._requestInit?.headers }
123-
);
122+
const extraHeaders = normalizeHeaders(this._requestInit?.headers);
123+
124+
return new Headers({
125+
...headers,
126+
...extraHeaders,
127+
});
124128
}
125129

126130
private _startOrAuth(): Promise<void> {

src/client/streamableHttp.ts

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { Transport, FetchLike } from "../shared/transport.js";
22
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
33
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
44
import { EventSourceParserStream } from "eventsource-parser/stream";
5+
import { normalizeHeaders } from "../shared/headers.js";
56

67
// Default reconnection options for StreamableHTTP connections
78
const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = {
@@ -185,7 +186,7 @@ export class StreamableHTTPClientTransport implements Transport {
185186
headers["mcp-protocol-version"] = this._protocolVersion;
186187
}
187188

188-
const extraHeaders = this._normalizeHeaders(this._requestInit?.headers);
189+
const extraHeaders = normalizeHeaders(this._requestInit?.headers);
189190

190191
return new Headers({
191192
...headers,
@@ -256,20 +257,6 @@ export class StreamableHTTPClientTransport implements Transport {
256257

257258
}
258259

259-
private _normalizeHeaders(headers: HeadersInit | undefined): Record<string, string> {
260-
if (!headers) return {};
261-
262-
if (headers instanceof Headers) {
263-
return Object.fromEntries(headers.entries());
264-
}
265-
266-
if (Array.isArray(headers)) {
267-
return Object.fromEntries(headers);
268-
}
269-
270-
return { ...headers as Record<string, string> };
271-
}
272-
273260
/**
274261
* Schedule a reconnection attempt with exponential backoff
275262
*

src/shared/headers.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
export function normalizeHeaders(
2+
headers: HeadersInit | undefined
3+
): Record<string, string> {
4+
if (!headers) return {};
5+
6+
if (headers instanceof Headers) {
7+
return Object.fromEntries(headers.entries());
8+
}
9+
10+
if (Array.isArray(headers)) {
11+
return Object.fromEntries(headers);
12+
}
13+
14+
return { ...headers };
15+
}

0 commit comments

Comments
 (0)