From d7efe414bca9b9087a266b911987577c1dd4db2f Mon Sep 17 00:00:00 2001 From: Jose Flores Date: Thu, 20 Feb 2025 19:17:20 -0600 Subject: [PATCH] Process JSONRPCRequest with default param (#42) --- .../kotlin/sdk/shared/Protocol.kt | 2 +- .../modelcontextprotocol/kotlin/sdk/types.kt | 2 +- src/jvmTest/kotlin/client/ClientTest.kt | 76 +++++++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index a067d533..70889531 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -440,7 +440,7 @@ public abstract class Protocol( val serializer = McpJson.serializersModule.serializer(requestType) requestHandlers[method.value] = { request, extraHandler -> - val result = request.params?.let { McpJson.decodeFromJsonElement(serializer, it) } + val result = McpJson.decodeFromJsonElement(serializer, request.params) val response = if (result != null) { @Suppress("UNCHECKED_CAST") block(result as T, extraHandler) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index 5343bef8..d2fa2cb2 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -215,7 +215,7 @@ public sealed interface JSONRPCMessage public data class JSONRPCRequest( val id: RequestId = RequestId.NumberId(REQUEST_MESSAGE_ID.incrementAndGet()), val method: String, - val params: JsonElement? = null, + val params: JsonElement = EmptyJsonObject, val jsonrpc: String = JSONRPC_VERSION, ) : JSONRPCMessage diff --git a/src/jvmTest/kotlin/client/ClientTest.kt b/src/jvmTest/kotlin/client/ClientTest.kt index 5ffc173a..ef60fb9f 100644 --- a/src/jvmTest/kotlin/client/ClientTest.kt +++ b/src/jvmTest/kotlin/client/ClientTest.kt @@ -24,6 +24,7 @@ import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.cancel @@ -38,6 +39,7 @@ import org.junit.jupiter.api.Test import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import org.junit.jupiter.api.assertInstanceOf import kotlin.coroutines.cancellation.CancellationException import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -494,5 +496,79 @@ class ClientTest { } } + @Test + fun `JSONRPCRequest with ToolsList method and default params returns list of tools`() = runTest { + val serverOptions = ServerOptions( + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(null) + ) + ) + val server = Server( + Implementation(name = "test server", version = "1.0"), + serverOptions + ) + + server.setRequestHandler(Method.Defined.Initialize) { request, _ -> + InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(null, null), + tools = ServerCapabilities.Tools(null) + ), + serverInfo = Implementation(name = "test", version = "1.0") + ) + } + val serverListToolsResult = ListToolsResult( + tools = listOf( + Tool( + name = "testTool", + description = "testTool description", + inputSchema = Tool.Input() + ) + ), nextCursor = null + ) + + server.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + serverListToolsResult + } + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(sampling = EmptyJsonObject), + ) + ) + + var receivedMessage: JSONRPCMessage? = null + clientTransport.onMessage { msg -> + receivedMessage = msg + } + + listOf( + launch { + client.connect(clientTransport) + }, + launch { + server.connect(serverTransport) + } + ).joinAll() + + val serverCapabilities = client.serverCapabilities + assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) + + val request = JSONRPCRequest( + method = Method.Defined.ToolsList.value + ) + clientTransport.send(request) + + assertInstanceOf(receivedMessage) + val receivedAsResponse = receivedMessage as JSONRPCResponse + assertEquals(request.id, receivedAsResponse.id) + assertEquals(request.jsonrpc, receivedAsResponse.jsonrpc) + assertEquals(serverListToolsResult, receivedAsResponse.result) + assertEquals(null, receivedAsResponse.error) + } }