Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -102,6 +105,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
/** Map of active client sessions, keyed by session ID */
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();

private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

/** Flag indicating if the transport is in the process of shutting down */
private final AtomicBoolean isClosing = new AtomicBoolean(false);

Expand Down Expand Up @@ -144,7 +149,7 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m
@Deprecated
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, (serverRequest, context) -> context);
}

/**
Expand All @@ -163,11 +168,38 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b
@Deprecated
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval,
(serverRequest, context) -> context);
}

/**
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
* endpoint.
* @param objectMapper The JSON object mapper to use for message
* serialization/deserialization
* @param baseUrl The base URL for the server transport
* @param messageEndpoint The endpoint path where clients will send their messages
* @param sseEndpoint The endpoint path where clients will establish SSE connections
* @param keepAliveInterval The interval for keep-alive pings, or null to disable
* keep-alive functionality
* @param contextExtractor The extractor for transport context from the request.
* @deprecated Use the builder {@link #builder()} instead for better configuration
* options.
*/
private HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval,
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {

Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(messageEndpoint, "messageEndpoint must not be null");
Assert.notNull(sseEndpoint, "sseEndpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.contextExtractor = contextExtractor;

if (keepAliveInterval != null) {

Expand Down Expand Up @@ -339,10 +371,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
body.append(line);
}

final McpTransportContext transportContext = this.contextExtractor.extract(request,
new DefaultMcpTransportContext());
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());

// Process the message through the session's handle method
session.handle(message).block(); // Block for Servlet compatibility
// Block for Servlet compatibility
session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();

response.setStatus(HttpServletResponse.SC_OK);
}
Expand Down Expand Up @@ -534,6 +569,8 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (serverRequest, context) -> context;

private Duration keepAliveInterval;

/**
Expand Down Expand Up @@ -583,6 +620,19 @@ public Builder sseEndpoint(String sseEndpoint) {
return this;
}

/**
* Sets the context extractor for extracting transport context from the request.
* @param contextExtractor The context extractor to use. Must not be null.
* @return this builder instance
* @throws IllegalArgumentException if contextExtractor is null
*/
public HttpServletSseServerTransportProvider.Builder contextExtractor(
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
Assert.notNull(contextExtractor, "Context extractor must not be null");
this.contextExtractor = contextExtractor;
return this;
}

/**
* Sets the interval for keep-alive pings.
* <p>
Expand All @@ -609,7 +659,7 @@ public HttpServletSseServerTransportProvider build() {
throw new IllegalStateException("MessageEndpoint must be set");
}
return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
keepAliveInterval);
keepAliveInterval, contextExtractor);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ public Mono<Void> sendNotification(String method, Object params) {
* @return a Mono that completes when the message is processed
*/
public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
return Mono.defer(() -> {
return Mono.deferContextual(ctx -> {
McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY);

// TODO handle errors for communication to without initialization happening
// first
if (message instanceof McpSchema.JSONRPCResponse response) {
Expand All @@ -214,7 +216,7 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
return handleIncomingRequest(request).onErrorResume(error -> {
return handleIncomingRequest(request, transportContext).onErrorResume(error -> {
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
error.getMessage(), null));
Expand All @@ -227,7 +229,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) {
// happening first
logger.debug("Received notification: {}", notification);
// TODO: in case of error, should the POST request be signalled?
return handleIncomingNotification(notification)
return handleIncomingNotification(notification, transportContext)
.doOnError(error -> logger.error("Error handling notification: {}", error.getMessage()));
}
else {
Expand All @@ -240,9 +242,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) {
/**
* Handles an incoming JSON-RPC request by routing it to the appropriate handler.
* @param request The incoming JSON-RPC request
* @param transportContext
* @return A Mono containing the JSON-RPC response
*/
private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request) {
private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request,
McpTransportContext transportContext) {
return Mono.defer(() -> {
Mono<?> resultMono;
if (McpSchema.METHOD_INITIALIZE.equals(request.method())) {
Expand All @@ -266,7 +270,17 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
error.message(), error.data())));
}

resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
resultMono = this.exchangeSink.asMono().flatMap(exchange -> {
// This legacy implementation assumes an exchange is established upon
// the initialization phase see: exchangeSink.tryEmitValue(...),
// which creates a cached immutable exchange.
// Here, we create a new exchange and copy over everything from that
// cached exchange, and use it for a single HTTP request, with the
// transport context passed in.
McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this,
exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext);
return handler.handle(newExchange, request.params());
});
}
return resultMono
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
Expand All @@ -280,24 +294,36 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
/**
* Handles an incoming JSON-RPC notification by routing it to the appropriate handler.
* @param notification The incoming JSON-RPC notification
* @param transportContext
* @return A Mono that completes when the notification is processed
*/
private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification) {
private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification,
McpTransportContext transportContext) {
return Mono.defer(() -> {
if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) {
this.state.lazySet(STATE_INITIALIZED);
// FIXME: The session ID passed here is not the same as the one in the
// legacy SSE transport.
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(),
clientInfo.get(), McpTransportContext.EMPTY));
clientInfo.get(), transportContext));
}

var handler = notificationHandlers.get(notification.method());
if (handler == null) {
logger.warn("No handler registered for notification method: {}", notification);
return Mono.empty();
}
return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params()));
return this.exchangeSink.asMono().flatMap(exchange -> {
// This legacy implementation assumes an exchange is established upon
// the initialization phase see: exchangeSink.tryEmitValue(...),
// which creates a cached immutable exchange.
// Here, we create a new exchange and copy over everything from that
// cached exchange, and use it for a single HTTP request, with the
// transport context passed in.
McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this,
exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext);
Comment on lines +317 to +324
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be extracted to a utility method to avoid the copy of the comment.

return handler.handle(newExchange, notification.params());
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertWith;
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.Mockito.mock;

import java.net.URI;
Expand All @@ -23,11 +24,14 @@
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import jakarta.servlet.http.HttpServletRequest;
import org.assertj.core.util.Strings;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

Expand Down Expand Up @@ -746,6 +750,7 @@ void testToolCallSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

var responseBodyIsNullOrBlank = new AtomicBoolean(false);
var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
.tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
Expand All @@ -759,7 +764,7 @@ void testToolCallSuccess(String clientType) {
.GET()
.build(), HttpResponse.BodyHandlers.ofString());
String responseBody = response.body();
assertThat(responseBody).isNotBlank();
responseBodyIsNullOrBlank.set(Strings.isNullOrEmpty(responseBody));
}
catch (Exception e) {
e.printStackTrace();
Expand All @@ -782,6 +787,7 @@ void testToolCallSuccess(String clientType) {

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertFalse(responseBodyIsNullOrBlank.get(), "Response body should not be blank");
assertThat(response).isNotNull().isEqualTo(callResponse);
}

Expand Down Expand Up @@ -825,6 +831,68 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) {
mcpServer.close();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testToolCallSuccessWithTranportContextExtraction(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

var transportContextIsNull = new AtomicBoolean(false);
var transportContextIsEmpty = new AtomicBoolean(false);
var responseBodyIsNullOrBlank = new AtomicBoolean(false);

var expectedCallResponse = new McpSchema.CallToolResult(
List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null);
McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
.tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
.callHandler((exchange, request) -> {

McpTransportContext transportContext = exchange.transportContext();
transportContextIsNull.set(transportContext == null);
transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY));
String ctxValue = (String) transportContext.get("important");

try {
HttpResponse<String> response = HttpClient.newHttpClient()
.send(HttpRequest.newBuilder()
.uri(URI.create(
"https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose there is no need to call an external API in this test. I'd say we don't need to call anything, just return the contents of the context from the tool.

.GET()
.build(), HttpResponse.BodyHandlers.ofString());
String responseBody = response.body();
responseBodyIsNullOrBlank.set(Strings.isNullOrEmpty(responseBody));
}
catch (Exception e) {
e.printStackTrace();
}

return new McpSchema.CallToolResult(
List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null);
})
.build();

var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool1)
.build();

try (var mcpClient = clientBuilder.build()) {

InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

assertThat(mcpClient.listTools().tools()).contains(tool1.tool());

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertFalse(transportContextIsNull.get(), "transportContext should not be null");
assertFalse(transportContextIsEmpty.get(), "transportContext should not be empty");
assertFalse(responseBodyIsNullOrBlank.get(), "Response body should not be blank");
assertThat(response).isNotNull().isEqualTo(expectedCallResponse);
}

mcpServer.close();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testToolListChangeHandlingSuccess(String clientType) {
Expand Down Expand Up @@ -1531,4 +1599,9 @@ private double evaluateExpression(String expression) {
};
}

protected static McpTransportContextExtractor<HttpServletRequest> TEST_CONTEXT_EXTRACTOR = (r, tc) -> {
tc.put("important", "value");
return tc;
};
Comment on lines +1602 to +1605
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a copy of this in both HttpServletSseIntegrationTests and HttpServletStreamableIntegrationTests to avoid a reference to a servlet type in the base class.

BTW, I thought this abstract test can be used outside of the servlet context, but it seems the variant in mcp-test module is not really a copy.. @tzolov do we have a strategy for aligning these? There's no comment like we have in other files of this sort (// KEEP IN SYNC with the class in mcp-test module)...


}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public void before() {
// Create and configure the transport provider
mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public void before() {
// Create and configure the transport provider
mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
.mcpEndpoint(MESSAGE_ENDPOINT)
.keepAliveInterval(Duration.ofSeconds(1))
.build();
Expand Down
Loading