Skip to content

Commit 2ef82cf

Browse files
committed
feat: add configurable keep-alive support to streamable server transports
- Add KeepAliveScheduler integration to WebFlux and WebMVC transport providers - Support configurable keep-alive intervals through builder pattern - Automatically shutdown keep-alive scheduler during graceful shutdown Signed-off-by: Christian Tzolov <[email protected]>
1 parent 0465f65 commit 2ef82cf

File tree

3 files changed

+92
-15
lines changed

3 files changed

+92
-15
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
1313
import io.modelcontextprotocol.server.McpTransportContext;
1414
import io.modelcontextprotocol.util.Assert;
15+
import io.modelcontextprotocol.util.KeepAliveScheduler;
16+
1517
import org.slf4j.Logger;
1618
import org.slf4j.LoggerFactory;
1719
import org.springframework.http.HttpStatus;
@@ -28,6 +30,7 @@
2830
import reactor.core.publisher.Mono;
2931

3032
import java.io.IOException;
33+
import java.time.Duration;
3134
import java.util.List;
3235
import java.util.concurrent.ConcurrentHashMap;
3336

@@ -58,8 +61,11 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
5861

5962
private volatile boolean isClosing = false;
6063

64+
private KeepAliveScheduler keepAliveScheduler;
65+
6166
private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
62-
McpTransportContextExtractor<ServerRequest> contextExtractor, boolean disallowDelete) {
67+
McpTransportContextExtractor<ServerRequest> contextExtractor, boolean disallowDelete,
68+
Duration keepAliveInterval) {
6369
Assert.notNull(objectMapper, "ObjectMapper must not be null");
6470
Assert.notNull(mcpEndpoint, "Message endpoint must not be null");
6571
Assert.notNull(contextExtractor, "Context extractor must not be null");
@@ -73,6 +79,20 @@ private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Stri
7379
.POST(this.mcpEndpoint, this::handlePost)
7480
.DELETE(this.mcpEndpoint, this::handleDelete)
7581
.build();
82+
83+
if (keepAliveInterval != null) {
84+
this.keepAliveScheduler = KeepAliveScheduler
85+
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values()))
86+
.initialDelay(keepAliveInterval)
87+
.interval(keepAliveInterval)
88+
.build();
89+
90+
this.keepAliveScheduler.start();
91+
}
92+
else {
93+
logger.warn("Keep-alive interval is not set or invalid. No keep-alive will be scheduled.");
94+
}
95+
7696
}
7797

7898
@Override
@@ -105,6 +125,11 @@ public Mono<Void> closeGracefully() {
105125
.doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()))
106126
.flatMap(McpStreamableServerSession::closeGracefully)
107127
.then();
128+
}).then().doOnSuccess(v -> {
129+
sessions.clear();
130+
if (this.keepAliveScheduler != null) {
131+
this.keepAliveScheduler.shutdown();
132+
}
108133
});
109134
}
110135

@@ -368,6 +393,8 @@ public static class Builder {
368393

369394
private boolean disallowDelete;
370395

396+
private Duration keepAliveInterval;
397+
371398
private Builder() {
372399
// used by a static method
373400
}
@@ -424,6 +451,17 @@ public Builder disallowDelete(boolean disallowDelete) {
424451
return this;
425452
}
426453

454+
/**
455+
* Sets the keep-alive interval for the server transport.
456+
* @param keepAliveInterval The interval for sending keep-alive messages. If null,
457+
* no keep-alive will be scheduled.
458+
* @return this builder instance
459+
*/
460+
public Builder keepAliveInterval(Duration keepAliveInterval) {
461+
this.keepAliveInterval = keepAliveInterval;
462+
return this;
463+
}
464+
427465
/**
428466
* Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with
429467
* the configured settings.
@@ -435,7 +473,7 @@ public WebFluxStreamableServerTransportProvider build() {
435473
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
436474

437475
return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor,
438-
disallowDelete);
476+
disallowDelete, keepAliveInterval);
439477
}
440478

441479
}

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
3434
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
3535
import io.modelcontextprotocol.util.Assert;
36+
import io.modelcontextprotocol.util.KeepAliveScheduler;
37+
import reactor.core.publisher.Flux;
3638
import reactor.core.publisher.Mono;
3739

3840
/**
@@ -101,6 +103,8 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer
101103
*/
102104
private volatile boolean isClosing = false;
103105

106+
private KeepAliveScheduler keepAliveScheduler;
107+
104108
/**
105109
* Constructs a new WebMvcStreamableServerTransportProvider instance.
106110
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
@@ -113,7 +117,8 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer
113117
* @throws IllegalArgumentException if any parameter is null
114118
*/
115119
private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
116-
boolean disallowDelete, McpTransportContextExtractor<ServerRequest> contextExtractor) {
120+
boolean disallowDelete, McpTransportContextExtractor<ServerRequest> contextExtractor,
121+
Duration keepAliveInterval) {
117122
Assert.notNull(objectMapper, "ObjectMapper must not be null");
118123
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
119124
Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null");
@@ -127,6 +132,19 @@ private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, Strin
127132
.POST(this.mcpEndpoint, this::handlePost)
128133
.DELETE(this.mcpEndpoint, this::handleDelete)
129134
.build();
135+
136+
if (keepAliveInterval != null) {
137+
this.keepAliveScheduler = KeepAliveScheduler
138+
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values()))
139+
.initialDelay(keepAliveInterval)
140+
.interval(keepAliveInterval)
141+
.build();
142+
143+
this.keepAliveScheduler.start();
144+
}
145+
else {
146+
logger.warn("Keep-alive interval is not set or invalid. No keep-alive will be scheduled.");
147+
}
130148
}
131149

132150
@Override
@@ -184,6 +202,12 @@ public Mono<Void> closeGracefully() {
184202

185203
this.sessions.clear();
186204
logger.debug("Graceful shutdown completed");
205+
}).then().doOnSuccess(v -> {
206+
logger.debug("Graceful shutdown completed");
207+
sessions.clear();
208+
if (this.keepAliveScheduler != null) {
209+
this.keepAliveScheduler.shutdown();
210+
}
187211
});
188212
}
189213

@@ -584,6 +608,8 @@ public static class Builder {
584608

585609
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
586610

611+
private Duration keepAliveInterval;
612+
587613
/**
588614
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
589615
* messages.
@@ -635,6 +661,18 @@ public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> cont
635661
return this;
636662
}
637663

664+
/**
665+
* Sets the keep-alive interval for the transport. If set, a keep-alive scheduler
666+
* will be created to periodically check and send keep-alive messages to clients.
667+
* @param keepAliveInterval The interval duration for keep-alive messages, or null
668+
* to disable keep-alive
669+
* @return this builder instance
670+
*/
671+
public Builder keepAliveInterval(Duration keepAliveInterval) {
672+
this.keepAliveInterval = keepAliveInterval;
673+
return this;
674+
}
675+
638676
/**
639677
* Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with
640678
* the configured settings.
@@ -646,7 +684,7 @@ public WebMvcStreamableServerTransportProvider build() {
646684
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
647685

648686
return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete,
649-
this.contextExtractor);
687+
this.contextExtractor, this.keepAliveInterval);
650688
}
651689

652690
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
package io.modelcontextprotocol.spec;
22

3+
import java.time.Duration;
4+
import java.util.Map;
5+
import java.util.UUID;
6+
import java.util.concurrent.ConcurrentHashMap;
7+
import java.util.concurrent.atomic.AtomicLong;
8+
import java.util.concurrent.atomic.AtomicReference;
9+
import java.util.function.Supplier;
10+
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
314
import com.fasterxml.jackson.core.type.TypeReference;
15+
416
import io.modelcontextprotocol.server.McpAsyncServerExchange;
517
import io.modelcontextprotocol.server.McpNotificationHandler;
618
import io.modelcontextprotocol.server.McpRequestHandler;
719
import io.modelcontextprotocol.server.McpTransportContext;
820
import io.modelcontextprotocol.util.Assert;
9-
import org.slf4j.Logger;
10-
import org.slf4j.LoggerFactory;
1121
import reactor.core.publisher.Flux;
1222
import reactor.core.publisher.Mono;
1323
import reactor.core.publisher.MonoSink;
1424

15-
import java.time.Duration;
16-
import java.util.Map;
17-
import java.util.UUID;
18-
import java.util.concurrent.ConcurrentHashMap;
19-
import java.util.concurrent.atomic.AtomicInteger;
20-
import java.util.concurrent.atomic.AtomicLong;
21-
import java.util.concurrent.atomic.AtomicReference;
22-
import java.util.function.Supplier;
23-
2425
/**
2526
* Representation of a Streamable HTTP server session that keeps track of mapping
2627
* server-initiated requests to the client and mapping arriving responses. It also allows

0 commit comments

Comments
 (0)