Skip to content

Commit 9758c23

Browse files
committed
Add support for DNS rebinding protections
1 parent 07e7b8f commit 9758c23

File tree

6 files changed

+702
-5
lines changed

6 files changed

+702
-5
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.modelcontextprotocol.server.transport;
22

33
import java.io.IOException;
4+
import java.util.List;
45
import java.util.Map;
56
import java.util.concurrent.ConcurrentHashMap;
67

@@ -110,6 +111,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
110111
*/
111112
private volatile boolean isClosing = false;
112113

114+
/**
115+
* DNS rebinding protection configuration.
116+
*/
117+
private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig;
118+
113119
/**
114120
* Constructs a new WebFlux SSE server transport provider instance with the default
115121
* SSE endpoint.
@@ -134,7 +140,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
134140
* @throws IllegalArgumentException if either parameter is null
135141
*/
136142
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
137-
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
143+
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null);
138144
}
139145

140146
/**
@@ -149,6 +155,24 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
149155
*/
150156
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
151157
String sseEndpoint) {
158+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
159+
}
160+
161+
/**
162+
* Constructs a new WebFlux SSE server transport provider instance with optional DNS
163+
* rebinding protection.
164+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
165+
* of MCP messages. Must not be null.
166+
* @param baseUrl webflux message base path
167+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
168+
* messages. This endpoint will be communicated to clients during SSE connection
169+
* setup. Must not be null.
170+
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
171+
* @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null).
172+
* @throws IllegalArgumentException if required parameters are null
173+
*/
174+
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
175+
String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) {
152176
Assert.notNull(objectMapper, "ObjectMapper must not be null");
153177
Assert.notNull(baseUrl, "Message base path must not be null");
154178
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
@@ -158,6 +182,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
158182
this.baseUrl = baseUrl;
159183
this.messageEndpoint = messageEndpoint;
160184
this.sseEndpoint = sseEndpoint;
185+
this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig;
161186
this.routerFunction = RouterFunctions.route()
162187
.GET(this.sseEndpoint, this::handleSseConnection)
163188
.POST(this.messageEndpoint, this::handleMessage)
@@ -256,6 +281,16 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
256281
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
257282
}
258283

284+
// Validate headers
285+
if (dnsRebindingProtectionConfig != null) {
286+
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
287+
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
288+
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
289+
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
290+
return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed");
291+
}
292+
}
293+
259294
return ServerResponse.ok()
260295
.contentType(MediaType.TEXT_EVENT_STREAM)
261296
.body(Flux.<ServerSentEvent<?>>create(sink -> {
@@ -300,6 +335,25 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
300335
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
301336
}
302337

338+
// Always validate Content-Type for POST requests
339+
String contentType = request.headers().contentType()
340+
.map(MediaType::toString)
341+
.orElse(null);
342+
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) {
343+
logger.warn("Invalid Content-Type header: '{}'", contentType);
344+
return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json"));
345+
}
346+
347+
// Validate headers for POST requests if DNS rebinding protection is configured
348+
if (dnsRebindingProtectionConfig != null) {
349+
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
350+
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
351+
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
352+
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
353+
return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed");
354+
}
355+
}
356+
303357
if (request.queryParam("sessionId").isEmpty()) {
304358
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
305359
}
@@ -397,6 +451,8 @@ public static class Builder {
397451

398452
private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
399453

454+
private DnsRebindingProtectionConfig dnsRebindingProtectionConfig;
455+
400456
/**
401457
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
402458
* messages.
@@ -447,6 +503,23 @@ public Builder sseEndpoint(String sseEndpoint) {
447503
return this;
448504
}
449505

506+
507+
/**
508+
* Sets the DNS rebinding protection configuration.
509+
* <p>
510+
* When set, this configuration will be used to create a header validator that
511+
* enforces DNS rebinding protection rules. This will override any previously set
512+
* header validator.
513+
* @param config The DNS rebinding protection configuration
514+
* @return this builder instance
515+
* @throws IllegalArgumentException if config is null
516+
*/
517+
public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) {
518+
Assert.notNull(config, "DNS rebinding protection config must not be null");
519+
this.dnsRebindingProtectionConfig = config;
520+
return this;
521+
}
522+
450523
/**
451524
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
452525
* configured settings.
@@ -457,7 +530,8 @@ public WebFluxSseServerTransportProvider build() {
457530
Assert.notNull(objectMapper, "ObjectMapper must be set");
458531
Assert.notNull(messageEndpoint, "Message endpoint must be set");
459532

460-
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
533+
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
534+
dnsRebindingProtectionConfig);
461535
}
462536

463537
}

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import java.io.IOException;
88
import java.time.Duration;
9+
import java.util.List;
910
import java.util.Map;
1011
import java.util.UUID;
1112
import java.util.concurrent.ConcurrentHashMap;
@@ -107,6 +108,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
107108
*/
108109
private volatile boolean isClosing = false;
109110

111+
/**
112+
* DNS rebinding protection configuration.
113+
*/
114+
private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig;
115+
110116
/**
111117
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
112118
* endpoint.
@@ -132,7 +138,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
132138
* @throws IllegalArgumentException if any parameter is null
133139
*/
134140
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
135-
this(objectMapper, "", messageEndpoint, sseEndpoint);
141+
this(objectMapper, "", messageEndpoint, sseEndpoint, null);
136142
}
137143

138144
/**
@@ -149,6 +155,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
149155
*/
150156
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
151157
String sseEndpoint) {
158+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
159+
}
160+
161+
/**
162+
* Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding protection.
163+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
164+
* of messages.
165+
* @param baseUrl The base URL for the message endpoint, used to construct the full
166+
* endpoint URL for clients.
167+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
168+
* messages via HTTP POST. This endpoint will be communicated to clients through the
169+
* SSE connection's initial endpoint event.
170+
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
171+
* @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null).
172+
* @throws IllegalArgumentException if any required parameter is null
173+
*/
174+
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
175+
String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) {
152176
Assert.notNull(objectMapper, "ObjectMapper must not be null");
153177
Assert.notNull(baseUrl, "Message base URL must not be null");
154178
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
@@ -158,6 +182,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
158182
this.baseUrl = baseUrl;
159183
this.messageEndpoint = messageEndpoint;
160184
this.sseEndpoint = sseEndpoint;
185+
this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig;
161186
this.routerFunction = RouterFunctions.route()
162187
.GET(this.sseEndpoint, this::handleSseConnection)
163188
.POST(this.messageEndpoint, this::handleMessage)
@@ -247,6 +272,16 @@ private ServerResponse handleSseConnection(ServerRequest request) {
247272
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
248273
}
249274

275+
// Validate headers
276+
if (dnsRebindingProtectionConfig != null) {
277+
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
278+
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
279+
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
280+
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
281+
return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed");
282+
}
283+
}
284+
250285
String sessionId = UUID.randomUUID().toString();
251286
logger.debug("Creating new SSE connection for session: {}", sessionId);
252287

@@ -300,6 +335,23 @@ private ServerResponse handleMessage(ServerRequest request) {
300335
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
301336
}
302337

338+
// Always validate Content-Type for POST requests
339+
String contentType = request.headers().asHttpHeaders().getFirst("Content-Type");
340+
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) {
341+
logger.warn("Invalid Content-Type header: '{}'", contentType);
342+
return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json"));
343+
}
344+
345+
// Validate headers for POST requests if DNS rebinding protection is configured
346+
if (dnsRebindingProtectionConfig != null) {
347+
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
348+
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
349+
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
350+
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
351+
return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed");
352+
}
353+
}
354+
303355
if (request.param("sessionId").isEmpty()) {
304356
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
305357
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package io.modelcontextprotocol.server.transport;
2+
3+
import java.util.Collections;
4+
import java.util.HashSet;
5+
import java.util.Set;
6+
7+
/**
8+
* Configuration for DNS rebinding protection in SSE server transports. Provides
9+
* validation for Host and Origin headers to prevent DNS rebinding attacks.
10+
*/
11+
public class DnsRebindingProtectionConfig {
12+
13+
private final Set<String> allowedHosts;
14+
15+
private final Set<String> allowedOrigins;
16+
17+
private final boolean enableDnsRebindingProtection;
18+
19+
private DnsRebindingProtectionConfig(Builder builder) {
20+
this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(builder.allowedHosts));
21+
this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(builder.allowedOrigins));
22+
this.enableDnsRebindingProtection = builder.enableDnsRebindingProtection;
23+
}
24+
25+
/**
26+
* Validates Host and Origin headers for DNS rebinding protection. Returns true if the
27+
* headers are valid, false otherwise.
28+
* @param hostHeader The value of the Host header (may be null)
29+
* @param originHeader The value of the Origin header (may be null)
30+
* @return true if the headers are valid, false otherwise
31+
*/
32+
public boolean validate(String hostHeader, String originHeader) {
33+
// Skip validation if protection is not enabled
34+
if (!enableDnsRebindingProtection) {
35+
return true;
36+
}
37+
38+
// Validate Host header
39+
if (hostHeader != null) {
40+
String lowerHost = hostHeader.toLowerCase();
41+
if (!allowedHosts.contains(lowerHost)) {
42+
return false;
43+
}
44+
}
45+
46+
// Validate Origin header
47+
if (originHeader != null) {
48+
String lowerOrigin = originHeader.toLowerCase();
49+
if (!allowedOrigins.contains(lowerOrigin)) {
50+
return false;
51+
}
52+
}
53+
54+
return true;
55+
}
56+
57+
public static Builder builder() {
58+
return new Builder();
59+
}
60+
61+
public static class Builder {
62+
63+
private final Set<String> allowedHosts = new HashSet<>();
64+
65+
private final Set<String> allowedOrigins = new HashSet<>();
66+
67+
private boolean enableDnsRebindingProtection = true;
68+
69+
public Builder allowedHost(String host) {
70+
if (host != null) {
71+
this.allowedHosts.add(host.toLowerCase());
72+
}
73+
return this;
74+
}
75+
76+
public Builder allowedHosts(Set<String> hosts) {
77+
if (hosts != null) {
78+
hosts.forEach(this::allowedHost);
79+
}
80+
return this;
81+
}
82+
83+
public Builder allowedOrigin(String origin) {
84+
if (origin != null) {
85+
this.allowedOrigins.add(origin.toLowerCase());
86+
}
87+
return this;
88+
}
89+
90+
public Builder allowedOrigins(Set<String> origins) {
91+
if (origins != null) {
92+
origins.forEach(this::allowedOrigin);
93+
}
94+
return this;
95+
}
96+
97+
public Builder enableDnsRebindingProtection(boolean enable) {
98+
this.enableDnsRebindingProtection = enable;
99+
return this;
100+
}
101+
102+
public DnsRebindingProtectionConfig build() {
103+
return new DnsRebindingProtectionConfig(this);
104+
}
105+
106+
}
107+
108+
}

0 commit comments

Comments
 (0)