@@ -1293,4 +1293,265 @@ describe("StreamableHTTPServerTransport in stateless mode", () => {
1293
1293
} ) ;
1294
1294
expect ( stream2 . status ) . toBe ( 409 ) ; // Conflict - only one stream allowed
1295
1295
} ) ;
1296
- } ) ;
1296
+ } ) ;
1297
+
1298
+ // Test DNS rebinding protection
1299
+ describe ( "StreamableHTTPServerTransport DNS rebinding protection" , ( ) => {
1300
+ let server : Server ;
1301
+ let transport : StreamableHTTPServerTransport ;
1302
+ let baseUrl : URL ;
1303
+
1304
+ afterEach ( async ( ) => {
1305
+ if ( server && transport ) {
1306
+ await stopTestServer ( { server, transport } ) ;
1307
+ }
1308
+ } ) ;
1309
+
1310
+ describe ( "Host header validation" , ( ) => {
1311
+ it ( "should accept requests with allowed host headers" , async ( ) => {
1312
+ const result = await createTestServerWithDnsProtection ( {
1313
+ sessionIdGenerator : undefined ,
1314
+ allowedHosts : [ 'localhost:3001' ] ,
1315
+ disableDnsRebindingProtection : false ,
1316
+ } ) ;
1317
+ server = result . server ;
1318
+ transport = result . transport ;
1319
+ baseUrl = result . baseUrl ;
1320
+
1321
+ // Note: fetch() automatically sets Host header to match the URL
1322
+ // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work
1323
+ const response = await fetch ( baseUrl , {
1324
+ method : "POST" ,
1325
+ headers : {
1326
+ "Content-Type" : "application/json" ,
1327
+ Accept : "application/json, text/event-stream" ,
1328
+ } ,
1329
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1330
+ } ) ;
1331
+
1332
+ expect ( response . status ) . toBe ( 200 ) ;
1333
+ } ) ;
1334
+
1335
+ it ( "should reject requests with disallowed host headers" , async ( ) => {
1336
+ // Test DNS rebinding protection by creating a server that only allows example.com
1337
+ // but we're connecting via localhost, so it should be rejected
1338
+ const result = await createTestServerWithDnsProtection ( {
1339
+ sessionIdGenerator : undefined ,
1340
+ allowedHosts : [ 'example.com:3001' ] ,
1341
+ disableDnsRebindingProtection : false ,
1342
+ } ) ;
1343
+ server = result . server ;
1344
+ transport = result . transport ;
1345
+ baseUrl = result . baseUrl ;
1346
+
1347
+ const response = await fetch ( baseUrl , {
1348
+ method : "POST" ,
1349
+ headers : {
1350
+ "Content-Type" : "application/json" ,
1351
+ Accept : "application/json, text/event-stream" ,
1352
+ } ,
1353
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1354
+ } ) ;
1355
+
1356
+ expect ( response . status ) . toBe ( 403 ) ;
1357
+ const body = await response . json ( ) ;
1358
+ expect ( body . error . message ) . toContain ( "Invalid Host header:" ) ;
1359
+ } ) ;
1360
+
1361
+ it ( "should reject GET requests with disallowed host headers" , async ( ) => {
1362
+ const result = await createTestServerWithDnsProtection ( {
1363
+ sessionIdGenerator : undefined ,
1364
+ allowedHosts : [ 'example.com:3001' ] ,
1365
+ disableDnsRebindingProtection : false ,
1366
+ } ) ;
1367
+ server = result . server ;
1368
+ transport = result . transport ;
1369
+ baseUrl = result . baseUrl ;
1370
+
1371
+ const response = await fetch ( baseUrl , {
1372
+ method : "GET" ,
1373
+ headers : {
1374
+ Accept : "text/event-stream" ,
1375
+ } ,
1376
+ } ) ;
1377
+
1378
+ expect ( response . status ) . toBe ( 403 ) ;
1379
+ } ) ;
1380
+ } ) ;
1381
+
1382
+ describe ( "Origin header validation" , ( ) => {
1383
+ it ( "should accept requests with allowed origin headers" , async ( ) => {
1384
+ const result = await createTestServerWithDnsProtection ( {
1385
+ sessionIdGenerator : undefined ,
1386
+ allowedOrigins : [ 'http://localhost:3000' , 'https://example.com' ] ,
1387
+ disableDnsRebindingProtection : false ,
1388
+ } ) ;
1389
+ server = result . server ;
1390
+ transport = result . transport ;
1391
+ baseUrl = result . baseUrl ;
1392
+
1393
+ const response = await fetch ( baseUrl , {
1394
+ method : "POST" ,
1395
+ headers : {
1396
+ "Content-Type" : "application/json" ,
1397
+ Accept : "application/json, text/event-stream" ,
1398
+ Origin : "http://localhost:3000" ,
1399
+ } ,
1400
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1401
+ } ) ;
1402
+
1403
+ expect ( response . status ) . toBe ( 200 ) ;
1404
+ } ) ;
1405
+
1406
+ it ( "should reject requests with disallowed origin headers" , async ( ) => {
1407
+ const result = await createTestServerWithDnsProtection ( {
1408
+ sessionIdGenerator : undefined ,
1409
+ allowedOrigins : [ 'http://localhost:3000' ] ,
1410
+ disableDnsRebindingProtection : false ,
1411
+ } ) ;
1412
+ server = result . server ;
1413
+ transport = result . transport ;
1414
+ baseUrl = result . baseUrl ;
1415
+
1416
+ const response = await fetch ( baseUrl , {
1417
+ method : "POST" ,
1418
+ headers : {
1419
+ "Content-Type" : "application/json" ,
1420
+ Accept : "application/json, text/event-stream" ,
1421
+ Origin : "http://evil.com" ,
1422
+ } ,
1423
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1424
+ } ) ;
1425
+
1426
+ expect ( response . status ) . toBe ( 403 ) ;
1427
+ const body = await response . json ( ) ;
1428
+ expect ( body . error . message ) . toBe ( "Invalid Origin header: http://evil.com" ) ;
1429
+ } ) ;
1430
+ } ) ;
1431
+
1432
+ describe ( "disableDnsRebindingProtection option" , ( ) => {
1433
+ it ( "should skip all validations when disableDnsRebindingProtection is true" , async ( ) => {
1434
+ const result = await createTestServerWithDnsProtection ( {
1435
+ sessionIdGenerator : undefined ,
1436
+ allowedHosts : [ 'localhost:3001' ] ,
1437
+ allowedOrigins : [ 'http://localhost:3000' ] ,
1438
+ disableDnsRebindingProtection : true ,
1439
+ } ) ;
1440
+ server = result . server ;
1441
+ transport = result . transport ;
1442
+ baseUrl = result . baseUrl ;
1443
+
1444
+ const response = await fetch ( baseUrl , {
1445
+ method : "POST" ,
1446
+ headers : {
1447
+ "Content-Type" : "application/json" ,
1448
+ Accept : "application/json, text/event-stream" ,
1449
+ Host : "evil.com" ,
1450
+ Origin : "http://evil.com" ,
1451
+ } ,
1452
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1453
+ } ) ;
1454
+
1455
+ // Should pass even with invalid headers because protection is disabled
1456
+ expect ( response . status ) . toBe ( 200 ) ;
1457
+ } ) ;
1458
+ } ) ;
1459
+
1460
+ describe ( "Combined validations" , ( ) => {
1461
+ it ( "should validate both host and origin when both are configured" , async ( ) => {
1462
+ const result = await createTestServerWithDnsProtection ( {
1463
+ sessionIdGenerator : undefined ,
1464
+ allowedHosts : [ 'localhost:3001' ] ,
1465
+ allowedOrigins : [ 'http://localhost:3001' ] ,
1466
+ disableDnsRebindingProtection : false ,
1467
+ } ) ;
1468
+ server = result . server ;
1469
+ transport = result . transport ;
1470
+ baseUrl = result . baseUrl ;
1471
+
1472
+ // Test with invalid origin (host will be automatically correct via fetch)
1473
+ const response1 = await fetch ( baseUrl , {
1474
+ method : "POST" ,
1475
+ headers : {
1476
+ "Content-Type" : "application/json" ,
1477
+ Accept : "application/json, text/event-stream" ,
1478
+ Origin : "http://evil.com" ,
1479
+ } ,
1480
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1481
+ } ) ;
1482
+
1483
+ expect ( response1 . status ) . toBe ( 403 ) ;
1484
+ const body1 = await response1 . json ( ) ;
1485
+ expect ( body1 . error . message ) . toBe ( "Invalid Origin header: http://evil.com" ) ;
1486
+
1487
+ // Test with valid origin
1488
+ const response2 = await fetch ( baseUrl , {
1489
+ method : "POST" ,
1490
+ headers : {
1491
+ "Content-Type" : "application/json" ,
1492
+ Accept : "application/json, text/event-stream" ,
1493
+ Origin : "http://localhost:3001" ,
1494
+ } ,
1495
+ body : JSON . stringify ( TEST_MESSAGES . initialize ) ,
1496
+ } ) ;
1497
+
1498
+ expect ( response2 . status ) . toBe ( 200 ) ;
1499
+ } ) ;
1500
+ } ) ;
1501
+ } ) ;
1502
+
1503
+ /**
1504
+ * Helper to create test server with DNS rebinding protection options
1505
+ */
1506
+ async function createTestServerWithDnsProtection ( config : {
1507
+ sessionIdGenerator : ( ( ) => string ) | undefined ;
1508
+ allowedHosts ?: string [ ] ;
1509
+ allowedOrigins ?: string [ ] ;
1510
+ disableDnsRebindingProtection ?: boolean ;
1511
+ } ) : Promise < {
1512
+ server : Server ;
1513
+ transport : StreamableHTTPServerTransport ;
1514
+ mcpServer : McpServer ;
1515
+ baseUrl : URL ;
1516
+ } > {
1517
+ const mcpServer = new McpServer (
1518
+ { name : "test-server" , version : "1.0.0" } ,
1519
+ { capabilities : { logging : { } } }
1520
+ ) ;
1521
+
1522
+ const transport = new StreamableHTTPServerTransport ( {
1523
+ sessionIdGenerator : config . sessionIdGenerator ,
1524
+ allowedHosts : config . allowedHosts ,
1525
+ allowedOrigins : config . allowedOrigins ,
1526
+ disableDnsRebindingProtection : config . disableDnsRebindingProtection ,
1527
+ } ) ;
1528
+
1529
+ await mcpServer . connect ( transport ) ;
1530
+
1531
+ const httpServer = createServer ( async ( req , res ) => {
1532
+ if ( req . method === "POST" ) {
1533
+ let body = "" ;
1534
+ req . on ( "data" , ( chunk ) => ( body += chunk ) ) ;
1535
+ req . on ( "end" , async ( ) => {
1536
+ const parsedBody = JSON . parse ( body ) ;
1537
+ await transport . handleRequest ( req as IncomingMessage & { auth ?: AuthInfo } , res , parsedBody ) ;
1538
+ } ) ;
1539
+ } else {
1540
+ await transport . handleRequest ( req as IncomingMessage & { auth ?: AuthInfo } , res ) ;
1541
+ }
1542
+ } ) ;
1543
+
1544
+ await new Promise < void > ( ( resolve ) => {
1545
+ httpServer . listen ( 3001 , ( ) => resolve ( ) ) ;
1546
+ } ) ;
1547
+
1548
+ const port = ( httpServer . address ( ) as AddressInfo ) . port ;
1549
+ const serverUrl = new URL ( `http://localhost:${ port } /` ) ;
1550
+
1551
+ return {
1552
+ server : httpServer ,
1553
+ transport,
1554
+ mcpServer,
1555
+ baseUrl : serverUrl ,
1556
+ } ;
1557
+ }
0 commit comments