Skip to content

Commit ebf2535

Browse files
committed
Add protections for streamable HTTP too
1 parent adbacc6 commit ebf2535

File tree

3 files changed

+351
-2
lines changed

3 files changed

+351
-2
lines changed

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,11 @@ app.post('/mcp', async (req, res) => {
251251
onsessioninitialized: (sessionId) => {
252252
// Store the transport by session ID
253253
transports[sessionId] = transport;
254-
}
254+
},
255+
// DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server
256+
// locally, make sure to set:
257+
// disableDnsRebindingProtection: true,
258+
// allowedHosts: ['127.0.0.1'],
255259
});
256260

257261
// Clean up transport when closed
@@ -386,6 +390,22 @@ This stateless approach is useful for:
386390
- RESTful scenarios where each request is independent
387391
- Horizontally scaled deployments without shared session state
388392

393+
#### DNS Rebinding Protection
394+
395+
The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility.
396+
397+
**Important**: If you are running this server locally, enable DNS rebinding protection:
398+
399+
```typescript
400+
const transport = new StreamableHTTPServerTransport({
401+
sessionIdGenerator: () => randomUUID(),
402+
disableDnsRebindingProtection: false,
403+
404+
allowedHosts: ['127.0.0.1', ...],
405+
allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com']
406+
});
407+
```
408+
389409
### Testing and Debugging
390410

391411
To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information.

src/server/streamableHttp.test.ts

Lines changed: 262 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,4 +1293,265 @@ describe("StreamableHTTPServerTransport in stateless mode", () => {
12931293
});
12941294
expect(stream2.status).toBe(409); // Conflict - only one stream allowed
12951295
});
1296-
});
1296+
});
1297+
1298+
// Test DNS rebinding protection
1299+
describe("StreamableHTTPServerTransport DNS rebinding protection", () => {
1300+
let server: Server;
1301+
let transport: StreamableHTTPServerTransport;
1302+
let baseUrl: URL;
1303+
1304+
afterEach(async () => {
1305+
if (server && transport) {
1306+
await stopTestServer({ server, transport });
1307+
}
1308+
});
1309+
1310+
describe("Host header validation", () => {
1311+
it("should accept requests with allowed host headers", async () => {
1312+
const result = await createTestServerWithDnsProtection({
1313+
sessionIdGenerator: undefined,
1314+
allowedHosts: ['localhost:3001'],
1315+
disableDnsRebindingProtection: false,
1316+
});
1317+
server = result.server;
1318+
transport = result.transport;
1319+
baseUrl = result.baseUrl;
1320+
1321+
// Note: fetch() automatically sets Host header to match the URL
1322+
// Since we're connecting to localhost:3001 and that's in allowedHosts, this should work
1323+
const response = await fetch(baseUrl, {
1324+
method: "POST",
1325+
headers: {
1326+
"Content-Type": "application/json",
1327+
Accept: "application/json, text/event-stream",
1328+
},
1329+
body: JSON.stringify(TEST_MESSAGES.initialize),
1330+
});
1331+
1332+
expect(response.status).toBe(200);
1333+
});
1334+
1335+
it("should reject requests with disallowed host headers", async () => {
1336+
// Test DNS rebinding protection by creating a server that only allows example.com
1337+
// but we're connecting via localhost, so it should be rejected
1338+
const result = await createTestServerWithDnsProtection({
1339+
sessionIdGenerator: undefined,
1340+
allowedHosts: ['example.com:3001'],
1341+
disableDnsRebindingProtection: false,
1342+
});
1343+
server = result.server;
1344+
transport = result.transport;
1345+
baseUrl = result.baseUrl;
1346+
1347+
const response = await fetch(baseUrl, {
1348+
method: "POST",
1349+
headers: {
1350+
"Content-Type": "application/json",
1351+
Accept: "application/json, text/event-stream",
1352+
},
1353+
body: JSON.stringify(TEST_MESSAGES.initialize),
1354+
});
1355+
1356+
expect(response.status).toBe(403);
1357+
const body = await response.json();
1358+
expect(body.error.message).toContain("Invalid Host header:");
1359+
});
1360+
1361+
it("should reject GET requests with disallowed host headers", async () => {
1362+
const result = await createTestServerWithDnsProtection({
1363+
sessionIdGenerator: undefined,
1364+
allowedHosts: ['example.com:3001'],
1365+
disableDnsRebindingProtection: false,
1366+
});
1367+
server = result.server;
1368+
transport = result.transport;
1369+
baseUrl = result.baseUrl;
1370+
1371+
const response = await fetch(baseUrl, {
1372+
method: "GET",
1373+
headers: {
1374+
Accept: "text/event-stream",
1375+
},
1376+
});
1377+
1378+
expect(response.status).toBe(403);
1379+
});
1380+
});
1381+
1382+
describe("Origin header validation", () => {
1383+
it("should accept requests with allowed origin headers", async () => {
1384+
const result = await createTestServerWithDnsProtection({
1385+
sessionIdGenerator: undefined,
1386+
allowedOrigins: ['http://localhost:3000', 'https://example.com'],
1387+
disableDnsRebindingProtection: false,
1388+
});
1389+
server = result.server;
1390+
transport = result.transport;
1391+
baseUrl = result.baseUrl;
1392+
1393+
const response = await fetch(baseUrl, {
1394+
method: "POST",
1395+
headers: {
1396+
"Content-Type": "application/json",
1397+
Accept: "application/json, text/event-stream",
1398+
Origin: "http://localhost:3000",
1399+
},
1400+
body: JSON.stringify(TEST_MESSAGES.initialize),
1401+
});
1402+
1403+
expect(response.status).toBe(200);
1404+
});
1405+
1406+
it("should reject requests with disallowed origin headers", async () => {
1407+
const result = await createTestServerWithDnsProtection({
1408+
sessionIdGenerator: undefined,
1409+
allowedOrigins: ['http://localhost:3000'],
1410+
disableDnsRebindingProtection: false,
1411+
});
1412+
server = result.server;
1413+
transport = result.transport;
1414+
baseUrl = result.baseUrl;
1415+
1416+
const response = await fetch(baseUrl, {
1417+
method: "POST",
1418+
headers: {
1419+
"Content-Type": "application/json",
1420+
Accept: "application/json, text/event-stream",
1421+
Origin: "http://evil.com",
1422+
},
1423+
body: JSON.stringify(TEST_MESSAGES.initialize),
1424+
});
1425+
1426+
expect(response.status).toBe(403);
1427+
const body = await response.json();
1428+
expect(body.error.message).toBe("Invalid Origin header: http://evil.com");
1429+
});
1430+
});
1431+
1432+
describe("disableDnsRebindingProtection option", () => {
1433+
it("should skip all validations when disableDnsRebindingProtection is true", async () => {
1434+
const result = await createTestServerWithDnsProtection({
1435+
sessionIdGenerator: undefined,
1436+
allowedHosts: ['localhost:3001'],
1437+
allowedOrigins: ['http://localhost:3000'],
1438+
disableDnsRebindingProtection: true,
1439+
});
1440+
server = result.server;
1441+
transport = result.transport;
1442+
baseUrl = result.baseUrl;
1443+
1444+
const response = await fetch(baseUrl, {
1445+
method: "POST",
1446+
headers: {
1447+
"Content-Type": "application/json",
1448+
Accept: "application/json, text/event-stream",
1449+
Host: "evil.com",
1450+
Origin: "http://evil.com",
1451+
},
1452+
body: JSON.stringify(TEST_MESSAGES.initialize),
1453+
});
1454+
1455+
// Should pass even with invalid headers because protection is disabled
1456+
expect(response.status).toBe(200);
1457+
});
1458+
});
1459+
1460+
describe("Combined validations", () => {
1461+
it("should validate both host and origin when both are configured", async () => {
1462+
const result = await createTestServerWithDnsProtection({
1463+
sessionIdGenerator: undefined,
1464+
allowedHosts: ['localhost:3001'],
1465+
allowedOrigins: ['http://localhost:3001'],
1466+
disableDnsRebindingProtection: false,
1467+
});
1468+
server = result.server;
1469+
transport = result.transport;
1470+
baseUrl = result.baseUrl;
1471+
1472+
// Test with invalid origin (host will be automatically correct via fetch)
1473+
const response1 = await fetch(baseUrl, {
1474+
method: "POST",
1475+
headers: {
1476+
"Content-Type": "application/json",
1477+
Accept: "application/json, text/event-stream",
1478+
Origin: "http://evil.com",
1479+
},
1480+
body: JSON.stringify(TEST_MESSAGES.initialize),
1481+
});
1482+
1483+
expect(response1.status).toBe(403);
1484+
const body1 = await response1.json();
1485+
expect(body1.error.message).toBe("Invalid Origin header: http://evil.com");
1486+
1487+
// Test with valid origin
1488+
const response2 = await fetch(baseUrl, {
1489+
method: "POST",
1490+
headers: {
1491+
"Content-Type": "application/json",
1492+
Accept: "application/json, text/event-stream",
1493+
Origin: "http://localhost:3001",
1494+
},
1495+
body: JSON.stringify(TEST_MESSAGES.initialize),
1496+
});
1497+
1498+
expect(response2.status).toBe(200);
1499+
});
1500+
});
1501+
});
1502+
1503+
/**
1504+
* Helper to create test server with DNS rebinding protection options
1505+
*/
1506+
async function createTestServerWithDnsProtection(config: {
1507+
sessionIdGenerator: (() => string) | undefined;
1508+
allowedHosts?: string[];
1509+
allowedOrigins?: string[];
1510+
disableDnsRebindingProtection?: boolean;
1511+
}): Promise<{
1512+
server: Server;
1513+
transport: StreamableHTTPServerTransport;
1514+
mcpServer: McpServer;
1515+
baseUrl: URL;
1516+
}> {
1517+
const mcpServer = new McpServer(
1518+
{ name: "test-server", version: "1.0.0" },
1519+
{ capabilities: { logging: {} } }
1520+
);
1521+
1522+
const transport = new StreamableHTTPServerTransport({
1523+
sessionIdGenerator: config.sessionIdGenerator,
1524+
allowedHosts: config.allowedHosts,
1525+
allowedOrigins: config.allowedOrigins,
1526+
disableDnsRebindingProtection: config.disableDnsRebindingProtection,
1527+
});
1528+
1529+
await mcpServer.connect(transport);
1530+
1531+
const httpServer = createServer(async (req, res) => {
1532+
if (req.method === "POST") {
1533+
let body = "";
1534+
req.on("data", (chunk) => (body += chunk));
1535+
req.on("end", async () => {
1536+
const parsedBody = JSON.parse(body);
1537+
await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res, parsedBody);
1538+
});
1539+
} else {
1540+
await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res);
1541+
}
1542+
});
1543+
1544+
await new Promise<void>((resolve) => {
1545+
httpServer.listen(3001, () => resolve());
1546+
});
1547+
1548+
const port = (httpServer.address() as AddressInfo).port;
1549+
const serverUrl = new URL(`http://localhost:${port}/`);
1550+
1551+
return {
1552+
server: httpServer,
1553+
transport,
1554+
mcpServer,
1555+
baseUrl: serverUrl,
1556+
};
1557+
}

0 commit comments

Comments
 (0)