Skip to content

Commit 87a018f

Browse files
committed
WIP: context extractor usage, some javadocs, HttpHeaders constants
Signed-off-by: Dariusz Jędrzejczyk <[email protected]>
1 parent 0a4ecb0 commit 87a018f

23 files changed

+192
-261
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java

Lines changed: 54 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
import com.fasterxml.jackson.databind.ObjectMapper;
44
import io.modelcontextprotocol.server.McpStatelessServerHandler;
5-
import io.modelcontextprotocol.spec.DefaultMcpTransportContext;
5+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
6+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
67
import io.modelcontextprotocol.spec.McpError;
78
import io.modelcontextprotocol.spec.McpSchema;
89
import io.modelcontextprotocol.spec.McpStatelessServerTransport;
9-
import io.modelcontextprotocol.spec.McpTransportContext;
10+
import io.modelcontextprotocol.server.McpTransportContext;
1011
import io.modelcontextprotocol.util.Assert;
1112
import org.slf4j.Logger;
1213
import org.slf4j.LoggerFactory;
@@ -19,50 +20,39 @@
1920
import reactor.core.publisher.Mono;
2021

2122
import java.io.IOException;
23+
import java.util.List;
2224
import java.util.function.Function;
2325

26+
/**
27+
* Implementation of a WebFlux based {@link McpStatelessServerTransport}.
28+
*
29+
* @author Dariusz Jędrzejczyk
30+
*/
2431
public class WebFluxStatelessServerTransport implements McpStatelessServerTransport {
2532

2633
private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class);
2734

28-
public static final String DEFAULT_BASE_URL = "";
29-
3035
private final ObjectMapper objectMapper;
3136

32-
private final String baseUrl;
33-
3437
private final String mcpEndpoint;
3538

3639
private final RouterFunction<?> routerFunction;
3740

3841
private McpStatelessServerHandler mcpHandler;
3942

40-
// TODO: add means to specify this
41-
private Function<ServerRequest, McpTransportContext> contextExtractor = req -> new DefaultMcpTransportContext();
43+
private McpTransportContextExtractor<ServerRequest> contextExtractor;
4244

43-
/**
44-
* Flag indicating if the transport is shutting down.
45-
*/
4645
private volatile boolean isClosing = false;
4746

48-
/**
49-
* Constructs a new WebFlux SSE server transport provider instance.
50-
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
51-
* of MCP messages. Must not be null.
52-
* @param baseUrl webflux message base path
53-
* @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
54-
* messages. This endpoint will be communicated to clients during SSE connection
55-
* setup. Must not be null.
56-
* @throws IllegalArgumentException if either parameter is null
57-
*/
58-
public WebFluxStatelessServerTransport(ObjectMapper objectMapper, String baseUrl, String mcpEndpoint) {
59-
Assert.notNull(objectMapper, "ObjectMapper must not be null");
60-
Assert.notNull(baseUrl, "Message base path must not be null");
61-
Assert.notNull(mcpEndpoint, "Message endpoint must not be null");
47+
private WebFluxStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint,
48+
McpTransportContextExtractor<ServerRequest> contextExtractor) {
49+
Assert.notNull(objectMapper, "objectMapper must not be null");
50+
Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null");
51+
Assert.notNull(contextExtractor, "contextExtractor must not be null");
6252

6353
this.objectMapper = objectMapper;
64-
this.baseUrl = baseUrl;
6554
this.mcpEndpoint = mcpEndpoint;
55+
this.contextExtractor = contextExtractor;
6656
this.routerFunction = RouterFunctions.route()
6757
.GET(this.mcpEndpoint, this::handleGet)
6858
.POST(this.mcpEndpoint, this::handlePost)
@@ -74,75 +64,43 @@ public void setMcpHandler(McpStatelessServerHandler mcpHandler) {
7464
this.mcpHandler = mcpHandler;
7565
}
7666

77-
// FIXME: This javadoc makes claims about using isClosing flag but it's not
78-
// actually
79-
// doing that.
80-
/**
81-
* Initiates a graceful shutdown of all the sessions. This method ensures all active
82-
* sessions are properly closed and cleaned up.
83-
*
84-
* <p>
85-
* The shutdown process:
86-
* <ul>
87-
* <li>Marks the transport as closing to prevent new connections</li>
88-
* <li>Closes each active session</li>
89-
* <li>Removes closed sessions from the sessions map</li>
90-
* <li>Times out after 5 seconds if shutdown takes too long</li>
91-
* </ul>
92-
* @return A Mono that completes when all sessions have been closed
93-
*/
9467
@Override
9568
public Mono<Void> closeGracefully() {
96-
return Mono.empty();
69+
return Mono.fromRunnable(() -> this.isClosing = true);
9770
}
9871

9972
/**
10073
* Returns the WebFlux router function that defines the transport's HTTP endpoints.
10174
* This router function should be integrated into the application's web configuration.
10275
*
10376
* <p>
104-
* The router function defines two endpoints:
77+
* The router function defines one endpoint handling two HTTP methods:
10578
* <ul>
106-
* <li>GET {sseEndpoint} - For establishing SSE connections</li>
107-
* <li>POST {messageEndpoint} - For receiving client messages</li>
79+
* <li>GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED</li>
80+
* <li>POST {messageEndpoint} - For handling client requests and notifications</li>
10881
* </ul>
10982
* @return The configured {@link RouterFunction} for handling HTTP requests
11083
*/
11184
public RouterFunction<?> getRouterFunction() {
11285
return this.routerFunction;
11386
}
11487

115-
/**
116-
* Handles GET requests from clients.
117-
* @param request The incoming server request
118-
* @return A Mono which emits a response informing the client that listening stream is
119-
* unavailable
120-
*/
12188
private Mono<ServerResponse> handleGet(ServerRequest request) {
12289
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
12390
}
12491

125-
/**
126-
* Handles incoming JSON-RPC messages from clients. Deserializes the message and
127-
* processes it through the configured message handler.
128-
*
129-
* <p>
130-
* The handler:
131-
* <ul>
132-
* <li>Deserializes the incoming JSON-RPC message</li>
133-
* <li>Passes it through the message handler chain</li>
134-
* <li>Returns appropriate HTTP responses based on processing results</li>
135-
* <li>Handles various error conditions with appropriate error responses</li>
136-
* </ul>
137-
* @param request The incoming server request containing the JSON-RPC message
138-
* @return A Mono emitting the response indicating the message processing result
139-
*/
14092
private Mono<ServerResponse> handlePost(ServerRequest request) {
14193
if (isClosing) {
14294
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
14395
}
14496

145-
McpTransportContext transportContext = this.contextExtractor.apply(request);
97+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
98+
99+
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
100+
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
101+
&& acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) {
102+
return ServerResponse.badRequest().build();
103+
}
146104

147105
return request.bodyToMono(String.class).<ServerResponse>flatMap(body -> {
148106
try {
@@ -170,6 +128,10 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
170128
}).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
171129
}
172130

131+
/**
132+
* Create a builder for the server.
133+
* @return a fresh {@link Builder} instance.
134+
*/
173135
public static Builder builder() {
174136
return new Builder();
175137
}
@@ -184,10 +146,14 @@ public static class Builder {
184146

185147
private ObjectMapper objectMapper;
186148

187-
private String baseUrl = DEFAULT_BASE_URL;
188-
189149
private String mcpEndpoint = "/mcp";
190150

151+
private McpTransportContextExtractor<ServerRequest> contextExtractor = (serverRequest, context) -> context;
152+
153+
private Builder() {
154+
// used by a static method
155+
}
156+
191157
/**
192158
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
193159
* messages.
@@ -201,19 +167,6 @@ public Builder objectMapper(ObjectMapper objectMapper) {
201167
return this;
202168
}
203169

204-
/**
205-
* Sets the project basePath as endpoint prefix where clients should send their
206-
* JSON-RPC messages
207-
* @param baseUrl the message basePath . Must not be null.
208-
* @return this builder instance
209-
* @throws IllegalArgumentException if basePath is null
210-
*/
211-
public Builder basePath(String baseUrl) {
212-
Assert.notNull(baseUrl, "basePath must not be null");
213-
this.baseUrl = baseUrl;
214-
return this;
215-
}
216-
217170
/**
218171
* Sets the endpoint URI where clients should send their JSON-RPC messages.
219172
* @param messageEndpoint The message endpoint URI. Must not be null.
@@ -226,6 +179,22 @@ public Builder messageEndpoint(String messageEndpoint) {
226179
return this;
227180
}
228181

182+
/**
183+
* Sets the context extractor that allows providing the MCP feature
184+
* implementations to inspect HTTP transport level metadata that was present at
185+
* HTTP request processing time. This allows to extract custom headers and other
186+
* useful data for use during execution later on in the process.
187+
* @param contextExtractor The contextExtractor to fill in a
188+
* {@link McpTransportContext}.
189+
* @return this builder instance
190+
* @throws IllegalArgumentException if contextExtractor is null
191+
*/
192+
public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
193+
Assert.notNull(contextExtractor, "Context extractor must not be null");
194+
this.contextExtractor = contextExtractor;
195+
return this;
196+
}
197+
229198
/**
230199
* Builds a new instance of {@link WebFluxStatelessServerTransport} with the
231200
* configured settings.
@@ -236,7 +205,7 @@ public WebFluxStatelessServerTransport build() {
236205
Assert.notNull(objectMapper, "ObjectMapper must be set");
237206
Assert.notNull(mcpEndpoint, "Message endpoint must be set");
238207

239-
return new WebFluxStatelessServerTransport(objectMapper, baseUrl, mcpEndpoint);
208+
return new WebFluxStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor);
240209
}
241210

242211
}

0 commit comments

Comments
 (0)