Skip to content

Commit e9caa5a

Browse files
gylove1994ihrpr
authored andcommitted
✨ feat: add customHeaders options
1 parent e3a6109 commit e9caa5a

File tree

2 files changed

+155
-7
lines changed

2 files changed

+155
-7
lines changed

src/server/streamable-http.test.ts

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ describe("StreamableHTTPServerTransport", () => {
9898

9999
await transport.handleRequest(req, mockResponse);
100100

101-
expect(mockResponse.writeHead).toHaveBeenCalledWith(404);
101+
expect(mockResponse.writeHead).toHaveBeenCalledWith(404, {});
102102
// check if the error response is a valid JSON-RPC error format
103103
expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"'));
104104
expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"'));
@@ -115,7 +115,7 @@ describe("StreamableHTTPServerTransport", () => {
115115

116116
await transport.handleRequest(req, mockResponse);
117117

118-
expect(mockResponse.writeHead).toHaveBeenCalledWith(400);
118+
expect(mockResponse.writeHead).toHaveBeenCalledWith(400, {});
119119
expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"'));
120120
expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"'));
121121
});
@@ -342,7 +342,7 @@ describe("StreamableHTTPServerTransport", () => {
342342

343343
await transport.handleRequest(req, mockResponse);
344344

345-
expect(mockResponse.writeHead).toHaveBeenCalledWith(406);
345+
expect(mockResponse.writeHead).toHaveBeenCalledWith(406, {});
346346
expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"'));
347347
});
348348

@@ -788,4 +788,141 @@ describe("StreamableHTTPServerTransport", () => {
788788
expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage);
789789
});
790790
});
791+
792+
describe("Custom Headers", () => {
793+
const customHeaders = {
794+
"X-Custom-Header": "custom-value",
795+
"X-API-Version": "1.0",
796+
"Access-Control-Allow-Origin": "*"
797+
};
798+
799+
let transportWithHeaders: StreamableHTTPServerTransport;
800+
let mockResponse: jest.Mocked<ServerResponse>;
801+
802+
beforeEach(() => {
803+
transportWithHeaders = new StreamableHTTPServerTransport(endpoint, { customHeaders });
804+
mockResponse = createMockResponse();
805+
});
806+
807+
it("should include custom headers in SSE response", async () => {
808+
const req = createMockRequest({
809+
method: "GET",
810+
headers: {
811+
accept: "text/event-stream",
812+
"mcp-session-id": transportWithHeaders.sessionId
813+
},
814+
});
815+
816+
await transportWithHeaders.handleRequest(req, mockResponse);
817+
818+
expect(mockResponse.writeHead).toHaveBeenCalledWith(
819+
200,
820+
expect.objectContaining({
821+
...customHeaders,
822+
"Content-Type": "text/event-stream",
823+
"Cache-Control": "no-cache",
824+
"Connection": "keep-alive",
825+
"mcp-session-id": transportWithHeaders.sessionId
826+
})
827+
);
828+
});
829+
830+
it("should include custom headers in JSON response", async () => {
831+
const message: JSONRPCMessage = {
832+
jsonrpc: "2.0",
833+
method: "test",
834+
params: {},
835+
id: 1,
836+
};
837+
838+
const req = createMockRequest({
839+
method: "POST",
840+
headers: {
841+
"content-type": "application/json",
842+
"accept": "application/json",
843+
"mcp-session-id": transportWithHeaders.sessionId
844+
},
845+
body: JSON.stringify(message),
846+
});
847+
848+
await transportWithHeaders.handleRequest(req, mockResponse);
849+
850+
expect(mockResponse.writeHead).toHaveBeenCalledWith(
851+
200,
852+
expect.objectContaining({
853+
...customHeaders,
854+
"Content-Type": "application/json",
855+
"mcp-session-id": transportWithHeaders.sessionId
856+
})
857+
);
858+
});
859+
860+
it("should include custom headers in error responses", async () => {
861+
const req = createMockRequest({
862+
method: "GET",
863+
headers: {
864+
accept: "text/event-stream",
865+
"mcp-session-id": "invalid-session-id"
866+
},
867+
});
868+
869+
await transportWithHeaders.handleRequest(req, mockResponse);
870+
871+
expect(mockResponse.writeHead).toHaveBeenCalledWith(
872+
404,
873+
expect.objectContaining(customHeaders)
874+
);
875+
});
876+
877+
it("should not override essential headers with custom headers", async () => {
878+
const transportWithConflictingHeaders = new StreamableHTTPServerTransport(endpoint, {
879+
customHeaders: {
880+
"Content-Type": "text/plain", // 尝试覆盖必要的 Content-Type 头
881+
"X-Custom-Header": "custom-value"
882+
}
883+
});
884+
885+
const req = createMockRequest({
886+
method: "GET",
887+
headers: {
888+
accept: "text/event-stream",
889+
"mcp-session-id": transportWithConflictingHeaders.sessionId
890+
},
891+
});
892+
893+
await transportWithConflictingHeaders.handleRequest(req, mockResponse);
894+
895+
expect(mockResponse.writeHead).toHaveBeenCalledWith(
896+
200,
897+
expect.objectContaining({
898+
"Content-Type": "text/event-stream", // 应该保持原有的 Content-Type
899+
"X-Custom-Header": "custom-value"
900+
})
901+
);
902+
});
903+
904+
it("should work with empty custom headers", async () => {
905+
const transportWithoutHeaders = new StreamableHTTPServerTransport(endpoint);
906+
907+
const req = createMockRequest({
908+
method: "GET",
909+
headers: {
910+
accept: "text/event-stream",
911+
"mcp-session-id": transportWithoutHeaders.sessionId
912+
},
913+
});
914+
915+
await transportWithoutHeaders.handleRequest(req, mockResponse);
916+
917+
expect(mockResponse.writeHead).toHaveBeenCalledWith(
918+
200,
919+
expect.objectContaining({
920+
"Content-Type": "text/event-stream",
921+
"Cache-Control": "no-cache",
922+
"Connection": "keep-alive",
923+
"mcp-session-id": transportWithoutHeaders.sessionId
924+
})
925+
);
926+
});
927+
});
791928
});

src/server/streamable-http.ts

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ export interface StreamableHTTPServerTransportOptions {
2828
* @default true
2929
*/
3030
enableSessionManagement?: boolean;
31+
32+
/**
33+
* Custom headers to be included in all responses
34+
* These headers will be added to both SSE and regular HTTP responses
35+
*/
36+
customHeaders?: Record<string, string>;
3137
}
3238

3339
/**
@@ -72,6 +78,7 @@ export class StreamableHTTPServerTransport implements Transport {
7278
private _started: boolean = false;
7379
private _requestConnections: Map<string, string> = new Map(); // request ID to connection ID mapping
7480
private _enableSessionManagement: boolean;
81+
private _customHeaders: Record<string, string>;
7582

7683
onclose?: () => void;
7784
onerror?: (error: Error) => void;
@@ -80,6 +87,7 @@ export class StreamableHTTPServerTransport implements Transport {
8087
constructor(private _endpoint: string, options?: StreamableHTTPServerTransportOptions) {
8188
this._sessionId = randomUUID();
8289
this._enableSessionManagement = options?.enableSessionManagement !== false;
90+
this._customHeaders = options?.customHeaders || {};
8391
}
8492

8593
/**
@@ -111,7 +119,7 @@ export class StreamableHTTPServerTransport implements Transport {
111119
// Continue processing normally
112120
} else if (!sessionId) {
113121
// Non-initialization requests without a session ID should return 400 Bad Request
114-
res.writeHead(400).end(JSON.stringify({
122+
res.writeHead(400, this._customHeaders).end(JSON.stringify({
115123
jsonrpc: "2.0",
116124
error: {
117125
code: -32000,
@@ -122,7 +130,7 @@ export class StreamableHTTPServerTransport implements Transport {
122130
return;
123131
} else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) {
124132
// Reject requests with invalid session ID with 404 Not Found
125-
res.writeHead(404).end(JSON.stringify({
133+
res.writeHead(404, this._customHeaders).end(JSON.stringify({
126134
jsonrpc: "2.0",
127135
error: {
128136
code: -32001,
@@ -141,7 +149,7 @@ export class StreamableHTTPServerTransport implements Transport {
141149
} else if (req.method === "DELETE") {
142150
await this.handleDeleteRequest(req, res);
143151
} else {
144-
res.writeHead(405).end(JSON.stringify({
152+
res.writeHead(405, this._customHeaders).end(JSON.stringify({
145153
jsonrpc: "2.0",
146154
error: {
147155
code: -32000,
@@ -159,7 +167,7 @@ export class StreamableHTTPServerTransport implements Transport {
159167
// validate the Accept header
160168
const acceptHeader = req.headers.accept;
161169
if (!acceptHeader || !acceptHeader.includes("text/event-stream")) {
162-
res.writeHead(406).end(JSON.stringify({
170+
res.writeHead(406, this._customHeaders).end(JSON.stringify({
163171
jsonrpc: "2.0",
164172
error: {
165173
code: -32000,
@@ -176,6 +184,7 @@ export class StreamableHTTPServerTransport implements Transport {
176184

177185
// Prepare response headers
178186
const headers: Record<string, string> = {
187+
...this._customHeaders,
179188
"Content-Type": "text/event-stream",
180189
"Cache-Control": "no-cache",
181190
Connection: "keep-alive",
@@ -294,6 +303,7 @@ export class StreamableHTTPServerTransport implements Transport {
294303

295304
if (useSSE) {
296305
const headers: Record<string, string> = {
306+
...this._customHeaders,
297307
"Content-Type": "text/event-stream",
298308
"Cache-Control": "no-cache",
299309
Connection: "keep-alive",
@@ -338,6 +348,7 @@ export class StreamableHTTPServerTransport implements Transport {
338348
} else {
339349
// use direct JSON response
340350
const headers: Record<string, string> = {
351+
...this._customHeaders,
341352
"Content-Type": "application/json",
342353
};
343354

0 commit comments

Comments
 (0)