diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index 9eba0e57c..9f7021938 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -4,29 +4,39 @@ package io.modelcontextprotocol; -import java.time.Duration; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; - import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; +import java.time.Duration; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + @Timeout(15) class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -88,4 +98,53 @@ public void after() { } } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallThrowMcpError(String clientType) { + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("toolThrowMcpError") + .description("toolThrowMcpError description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + throw new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError(50000, "test exception message", Map.of("a", "b"))); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThatThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("toolThrowMcpError", Map.of()))) + .isInstanceOf(McpError.class) + .hasMessage("test exception message") + .satisfies(ex -> { + McpError mcpError = (McpError) ex; + assertThat(mcpError.getJsonRpcError()).isNotNull(); + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(50000); + }); + + } + + mcpServer.close(); + } + } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index 16012e7d9..a0367ca27 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -3,32 +3,39 @@ */ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; - -import java.time.Duration; - +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; +import reactor.core.scheduler.Schedulers; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.time.Duration; +import java.util.Map; -import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SyncSpecification; -import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; -import reactor.core.scheduler.Schedulers; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; @Timeout(15) class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -139,4 +146,53 @@ protected void prepareClients(int port, String mcpEndpoint) { .requestTimeout(Duration.ofHours(10))); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallThrowMcpError(String clientType) { + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("toolThrowMcpError") + .description("toolThrowMcpError description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + throw new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError(50000, "test exception message", Map.of("a", "b"))); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThatThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("toolThrowMcpError", Map.of()))) + .isInstanceOf(McpError.class) + .hasMessage("test exception message") + .satisfies(ex -> { + McpError mcpError = (McpError) ex; + assertThat(mcpError.getJsonRpcError()).isNotNull(); + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(50000); + }); + + } + + mcpServer.close(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index ef7967c1e..098906c0d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -177,9 +177,16 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) .onErrorResume(e -> { + McpSchema.JSONRPCResponse.JSONRPCError error; + if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) { + error = mcpError.getJsonRpcError(); + } + else { + error = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), null); + } var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - e.getMessage(), null)); + null, error); return Mono.just(errorResponse); }) .flatMap(transport::sendMessage)