Skip to content

Commit adbacc6

Browse files
committed
Add DNS rebinding protection for SSE transport
1 parent 590d484 commit adbacc6

File tree

3 files changed

+306
-4
lines changed

3 files changed

+306
-4
lines changed

package-lock.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/server/sse.test.ts

Lines changed: 239 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
import http from 'http';
22
import { jest } from '@jest/globals';
3-
import { SSEServerTransport } from './sse.js';
3+
import { SSEServerTransport } from './sse.js';
4+
import { AuthInfo } from './auth/types.js';
45

56
const createMockResponse = () => {
67
const res = {
78
writeHead: jest.fn<http.ServerResponse['writeHead']>(),
89
write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true),
910
on: jest.fn<http.ServerResponse['on']>(),
11+
end: jest.fn<http.ServerResponse['end']>().mockReturnThis(),
1012
};
1113
res.writeHead.mockReturnThis();
1214
res.on.mockReturnThis();
1315

1416
return res as unknown as http.ServerResponse;
1517
};
1618

19+
const createMockRequest = (headers: Record<string, string> = {}) => {
20+
return {
21+
headers,
22+
} as unknown as http.IncomingMessage & { auth?: AuthInfo };
23+
};
24+
1725
describe('SSEServerTransport', () => {
1826
describe('start method', () => {
1927
it('should correctly append sessionId to a simple relative endpoint', async () => {
@@ -106,4 +114,234 @@ describe('SSEServerTransport', () => {
106114
);
107115
});
108116
});
117+
118+
describe('DNS rebinding protection', () => {
119+
beforeEach(() => {
120+
jest.clearAllMocks();
121+
});
122+
123+
describe('Host header validation', () => {
124+
it('should accept requests with allowed host headers', async () => {
125+
const mockRes = createMockResponse();
126+
const transport = new SSEServerTransport('/messages', mockRes, {
127+
allowedHosts: ['localhost:3000', 'example.com'],
128+
});
129+
await transport.start();
130+
131+
const mockReq = createMockRequest({
132+
host: 'localhost:3000',
133+
'content-type': 'application/json',
134+
});
135+
const mockHandleRes = createMockResponse();
136+
137+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
138+
139+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
140+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
141+
});
142+
143+
it('should reject requests with disallowed host headers', async () => {
144+
const mockRes = createMockResponse();
145+
const transport = new SSEServerTransport('/messages', mockRes, {
146+
allowedHosts: ['localhost:3000'],
147+
});
148+
await transport.start();
149+
150+
const mockReq = createMockRequest({
151+
host: 'evil.com',
152+
'content-type': 'application/json',
153+
});
154+
const mockHandleRes = createMockResponse();
155+
156+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
157+
158+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
159+
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
160+
});
161+
162+
it('should reject requests without host header when allowedHosts is configured', async () => {
163+
const mockRes = createMockResponse();
164+
const transport = new SSEServerTransport('/messages', mockRes, {
165+
allowedHosts: ['localhost:3000'],
166+
});
167+
await transport.start();
168+
169+
const mockReq = createMockRequest({
170+
'content-type': 'application/json',
171+
});
172+
const mockHandleRes = createMockResponse();
173+
174+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
175+
176+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
177+
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined');
178+
});
179+
});
180+
181+
describe('Origin header validation', () => {
182+
it('should accept requests with allowed origin headers', async () => {
183+
const mockRes = createMockResponse();
184+
const transport = new SSEServerTransport('/messages', mockRes, {
185+
allowedOrigins: ['http://localhost:3000', 'https://example.com'],
186+
});
187+
await transport.start();
188+
189+
const mockReq = createMockRequest({
190+
origin: 'http://localhost:3000',
191+
'content-type': 'application/json',
192+
});
193+
const mockHandleRes = createMockResponse();
194+
195+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
196+
197+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
198+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
199+
});
200+
201+
it('should reject requests with disallowed origin headers', async () => {
202+
const mockRes = createMockResponse();
203+
const transport = new SSEServerTransport('/messages', mockRes, {
204+
allowedOrigins: ['http://localhost:3000'],
205+
});
206+
await transport.start();
207+
208+
const mockReq = createMockRequest({
209+
origin: 'http://evil.com',
210+
'content-type': 'application/json',
211+
});
212+
const mockHandleRes = createMockResponse();
213+
214+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
215+
216+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
217+
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
218+
});
219+
});
220+
221+
describe('Content-Type validation', () => {
222+
it('should accept requests with application/json content-type', async () => {
223+
const mockRes = createMockResponse();
224+
const transport = new SSEServerTransport('/messages', mockRes);
225+
await transport.start();
226+
227+
const mockReq = createMockRequest({
228+
'content-type': 'application/json',
229+
});
230+
const mockHandleRes = createMockResponse();
231+
232+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
233+
234+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
235+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
236+
});
237+
238+
it('should accept requests with application/json with charset', async () => {
239+
const mockRes = createMockResponse();
240+
const transport = new SSEServerTransport('/messages', mockRes);
241+
await transport.start();
242+
243+
const mockReq = createMockRequest({
244+
'content-type': 'application/json; charset=utf-8',
245+
});
246+
const mockHandleRes = createMockResponse();
247+
248+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
249+
250+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
251+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
252+
});
253+
254+
it('should reject requests with non-application/json content-type when protection is enabled', async () => {
255+
const mockRes = createMockResponse();
256+
const transport = new SSEServerTransport('/messages', mockRes);
257+
await transport.start();
258+
259+
const mockReq = createMockRequest({
260+
'content-type': 'text/plain',
261+
});
262+
const mockHandleRes = createMockResponse();
263+
264+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
265+
266+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
267+
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Content-Type must start with application/json, got: text/plain');
268+
});
269+
});
270+
271+
describe('disableDnsRebindingProtection option', () => {
272+
it('should skip all validations when disableDnsRebindingProtection is true', async () => {
273+
const mockRes = createMockResponse();
274+
const transport = new SSEServerTransport('/messages', mockRes, {
275+
allowedHosts: ['localhost:3000'],
276+
allowedOrigins: ['http://localhost:3000'],
277+
disableDnsRebindingProtection: true,
278+
});
279+
await transport.start();
280+
281+
const mockReq = createMockRequest({
282+
host: 'evil.com',
283+
origin: 'http://evil.com',
284+
'content-type': 'text/plain',
285+
});
286+
const mockHandleRes = createMockResponse();
287+
288+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
289+
290+
// Should pass even with invalid headers because protection is disabled
291+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
292+
// The error should be from content-type parsing, not DNS rebinding protection
293+
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
294+
});
295+
});
296+
297+
describe('Combined validations', () => {
298+
it('should validate both host and origin when both are configured', async () => {
299+
const mockRes = createMockResponse();
300+
const transport = new SSEServerTransport('/messages', mockRes, {
301+
allowedHosts: ['localhost:3000'],
302+
allowedOrigins: ['http://localhost:3000'],
303+
});
304+
await transport.start();
305+
306+
// Valid host, invalid origin
307+
const mockReq1 = createMockRequest({
308+
host: 'localhost:3000',
309+
origin: 'http://evil.com',
310+
'content-type': 'application/json',
311+
});
312+
const mockHandleRes1 = createMockResponse();
313+
314+
await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' });
315+
316+
expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403);
317+
expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
318+
319+
// Invalid host, valid origin
320+
const mockReq2 = createMockRequest({
321+
host: 'evil.com',
322+
origin: 'http://localhost:3000',
323+
'content-type': 'application/json',
324+
});
325+
const mockHandleRes2 = createMockResponse();
326+
327+
await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' });
328+
329+
expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403);
330+
expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
331+
332+
// Both valid
333+
const mockReq3 = createMockRequest({
334+
host: 'localhost:3000',
335+
origin: 'http://localhost:3000',
336+
'content-type': 'application/json',
337+
});
338+
const mockHandleRes3 = createMockResponse();
339+
340+
await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' });
341+
342+
expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202);
343+
expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted');
344+
});
345+
});
346+
});
109347
});

src/server/sse.ts

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,29 @@ import { URL } from 'url';
99

1010
const MAXIMUM_MESSAGE_SIZE = "4mb";
1111

12+
/**
13+
* Configuration options for SSEServerTransport.
14+
*/
15+
export interface SSEServerTransportOptions {
16+
/**
17+
* List of allowed host header values for DNS rebinding protection.
18+
* If not specified, host validation is disabled.
19+
*/
20+
allowedHosts?: string[];
21+
22+
/**
23+
* List of allowed origin header values for DNS rebinding protection.
24+
* If not specified, origin validation is disabled.
25+
*/
26+
allowedOrigins?: string[];
27+
28+
/**
29+
* Disable DNS rebinding protection entirely (overrides allowedHosts and allowedOrigins).
30+
* Default is false.
31+
*/
32+
disableDnsRebindingProtection?: boolean;
33+
}
34+
1235
/**
1336
* Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests.
1437
*
@@ -17,6 +40,7 @@ const MAXIMUM_MESSAGE_SIZE = "4mb";
1740
export class SSEServerTransport implements Transport {
1841
private _sseResponse?: ServerResponse;
1942
private _sessionId: string;
43+
private _options: SSEServerTransportOptions;
2044

2145
onclose?: () => void;
2246
onerror?: (error: Error) => void;
@@ -28,8 +52,39 @@ export class SSEServerTransport implements Transport {
2852
constructor(
2953
private _endpoint: string,
3054
private res: ServerResponse,
55+
options?: SSEServerTransportOptions,
3156
) {
3257
this._sessionId = randomUUID();
58+
this._options = options || {disableDnsRebindingProtection: true};
59+
}
60+
61+
/**
62+
* Validates request headers for DNS rebinding protection.
63+
* @returns Error message if validation fails, undefined if validation passes.
64+
*/
65+
private validateRequestHeaders(req: IncomingMessage): string | undefined {
66+
// Skip validation if protection is disabled
67+
if (this._options.disableDnsRebindingProtection) {
68+
return undefined;
69+
}
70+
71+
// Validate Host header if allowedHosts is configured
72+
if (this._options.allowedHosts && this._options.allowedHosts.length > 0) {
73+
const hostHeader = req.headers.host;
74+
if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) {
75+
return `Invalid Host header: ${hostHeader}`;
76+
}
77+
}
78+
79+
// Validate Origin header if allowedOrigins is configured
80+
if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) {
81+
const originHeader = req.headers.origin;
82+
if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) {
83+
return `Invalid Origin header: ${originHeader}`;
84+
}
85+
}
86+
87+
return undefined;
3388
}
3489

3590
/**
@@ -86,13 +141,22 @@ export class SSEServerTransport implements Transport {
86141
res.writeHead(500).end(message);
87142
throw new Error(message);
88143
}
144+
145+
// Validate request headers for DNS rebinding protection
146+
const validationError = this.validateRequestHeaders(req);
147+
if (validationError) {
148+
res.writeHead(403).end(validationError);
149+
this.onerror?.(new Error(validationError));
150+
return;
151+
}
152+
89153
const authInfo: AuthInfo | undefined = req.auth;
90154

91155
let body: string | unknown;
92156
try {
93157
const ct = contentType.parse(req.headers["content-type"] ?? "");
94158
if (ct.type !== "application/json") {
95-
throw new Error(`Unsupported content-type: ${ct}`);
159+
throw new Error(`Unsupported content-type: ${ct.type}`);
96160
}
97161

98162
body = parsedBody ?? await getRawBody(req, {

0 commit comments

Comments
 (0)