From a9f524a7cb79f6ea78bd5a1678022f05674fc574 Mon Sep 17 00:00:00 2001 From: Alexander Sysoev Date: Mon, 20 Jan 2025 18:51:06 +0100 Subject: [PATCH 1/3] Fixed compilation --- .../sdk/client/{sse.ktor.kt => KtorClient.kt} | 2 +- .../kotlin/sdk/client/SSEClientTransport.kt | 37 +++---------------- .../kotlin/sdk/client/StdioClientTransport.kt | 20 ++++------ .../kotlin/sdk/server/SSEServerTransport.kt | 25 ++++++------- .../kotlin/sdk/server/StdioServerTransport.kt | 19 ++++------ .../kotlin/sdk/shared/Protocol.kt | 8 ++-- .../kotlin/sdk/shared/Transport.kt | 35 ++++++++++++++++++ .../sdk/shared/WebSocketMcpTransport.kt | 16 +++----- src/jvmTest/kotlin/InMemoryTransport.kt | 17 +++------ .../kotlin/client/BaseTransportTest.kt | 8 ++-- src/jvmTest/kotlin/client/ClientTest.kt | 26 ++++--------- .../kotlin/client/InMemoryTransportTest.kt | 12 +++--- src/jvmTest/kotlin/client/SseTransportTest.kt | 31 +++++++++++++--- .../kotlin/client/WebSocketTransportTest.kt | 2 +- .../kotlin/server/StdioServerTransportTest.kt | 14 +++---- 15 files changed, 135 insertions(+), 137 deletions(-) rename src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/{sse.ktor.kt => KtorClient.kt} (96%) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/sse.ktor.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt similarity index 96% rename from src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/sse.ktor.kt rename to src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt index 4c36309f..2ccc223d 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/sse.ktor.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt @@ -19,7 +19,7 @@ public fun HttpClient.mcpSseTransport( urlString: String? = null, reconnectionTime: Duration? = null, requestBuilder: HttpRequestBuilder.() -> Unit = {}, -): SSEClientTransport = SSEClientTransport(this, urlString, reconnectionTime, requestBuilder) +): SseClientTransport = SseClientTransport(this, urlString, reconnectionTime, requestBuilder) /** * Creates and connects an MCP client over SSE using the provided HttpClient. diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index 5a7a4f2f..9e213477 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -6,8 +6,8 @@ import io.ktor.client.request.* import io.ktor.client.statement.* import io.ktor.http.* import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import io.modelcontextprotocol.kotlin.sdk.shared.Transport import kotlinx.atomicfu.AtomicBoolean import kotlinx.atomicfu.atomic import kotlinx.coroutines.* @@ -15,16 +15,19 @@ import kotlinx.serialization.encodeToString import kotlin.properties.Delegates import kotlin.time.Duration +@Deprecated("Use SseClientTransport instead", ReplaceWith("SseClientTransport"), DeprecationLevel.WARNING) +public typealias SSEClientTransport = SseClientTransport + /** * Client transport for SSE: this will connect to a server using Server-Sent Events for receiving * messages and make separate POST requests for sending messages. */ -public class SSEClientTransport( +public class SseClientTransport( private val client: HttpClient, private val urlString: String?, private val reconnectionTime: Duration? = null, private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, -) : Transport { +) : AbstractTransport() { private val scope by lazy { CoroutineScope(session.coroutineContext + SupervisorJob()) } @@ -33,10 +36,6 @@ public class SSEClientTransport( private var session: ClientSSESession by Delegates.notNull() private val endpoint = CompletableDeferred() - private var _onClose: (() -> Unit) = {} - private var _onError: ((Throwable) -> Unit) = {} - private var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {} - private var job: Job? = null private val baseUrl by lazy { @@ -136,28 +135,4 @@ public class SSEClientTransport( _onClose() job?.cancelAndJoin() } - - override fun onClose(block: () -> Unit) { - val old = _onClose - _onClose = { - old() - block() - } - } - - override fun onError(block: (Throwable) -> Unit) { - val old = _onError - _onError = { e -> - old(e) - block(e) - } - } - - override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - val old = _onMessage - _onMessage = { message -> - old(message) - block(message) - } - } } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 788f4e3e..f579f92a 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -2,8 +2,8 @@ package io.modelcontextprotocol.kotlin.sdk.client import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer -import io.modelcontextprotocol.kotlin.sdk.shared.Transport import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import kotlinx.atomicfu.AtomicBoolean import kotlinx.atomicfu.atomic @@ -30,7 +30,7 @@ import kotlin.coroutines.CoroutineContext public class StdioClientTransport( private val input: Source, private val output: Sink -) : Transport { +) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val ioCoroutineContext: CoroutineContext = Dispatchers.IO private val scope by lazy { @@ -41,10 +41,6 @@ public class StdioClientTransport( private val sendChannel = Channel(Channel.UNLIMITED) private val readBuffer = ReadBuffer() - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend ((JSONRPCMessage) -> Unit))? = null - override suspend fun start() { if (!initialized.compareAndSet(false, true)) { error("StdioClientTransport already started!") @@ -70,7 +66,7 @@ public class StdioClientTransport( } } } catch (e: Exception) { - onError?.invoke(e) + _onError.invoke(e) logger.error(e) { "Error reading from input stream" } } } @@ -85,7 +81,7 @@ public class StdioClientTransport( } } catch (e: Throwable) { if (isActive) { - onError?.invoke(e) + _onError.invoke(e) logger.error(e) { "Error writing to output stream" } } } finally { @@ -95,7 +91,7 @@ public class StdioClientTransport( readJob.join() writeJob.cancelAndJoin() - onClose?.invoke() + _onClose.invoke() } } @@ -116,16 +112,16 @@ public class StdioClientTransport( output.close() readBuffer.clear() sendChannel.close() - onClose?.invoke() + _onClose.invoke() } private suspend fun processReadBuffer() { while (true) { val msg = readBuffer.readMessage() ?: break try { - onMessage?.invoke(msg) + _onMessage.invoke(msg) } catch (e: Throwable) { - onError?.invoke(e) + _onError.invoke(e) logger.error(e) { "Error processing message." } } } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index 1902720e..32039cde 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -6,8 +6,8 @@ import io.ktor.server.request.* import io.ktor.server.response.* import io.ktor.server.sse.* import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import io.modelcontextprotocol.kotlin.sdk.shared.Transport import kotlinx.atomicfu.AtomicBoolean import kotlinx.atomicfu.atomic import kotlinx.coroutines.job @@ -17,24 +17,23 @@ import kotlin.uuid.Uuid internal const val SESSION_ID_PARAM = "sessionId" +@Deprecated("Use SseServerTransport instead", ReplaceWith("SseServerTransport"), DeprecationLevel.WARNING) +public typealias SSEServerTransport = SseServerTransport + /** * Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests. * * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. */ -public class SSEServerTransport( +public class SseServerTransport( private val endpoint: String, private val session: ServerSSESession, -) : Transport { +) : AbstractTransport() { private val initialized: AtomicBoolean = atomic(false) @OptIn(ExperimentalUuidApi::class) public val sessionId: String = Uuid.random().toString() - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend ((JSONRPCMessage) -> Unit))? = null - /** * Handles the initial SSE connection request. * @@ -54,7 +53,7 @@ public class SSEServerTransport( try { session.coroutineContext.job.join() } finally { - onClose?.invoke() + _onClose.invoke() } } @@ -67,7 +66,7 @@ public class SSEServerTransport( if (!initialized.value) { val message = "SSE connection not established" call.respondText(message, status = HttpStatusCode.InternalServerError) - onError?.invoke(IllegalStateException(message)) + _onError.invoke(IllegalStateException(message)) } val body = try { @@ -79,7 +78,7 @@ public class SSEServerTransport( call.receiveText() } catch (e: Exception) { call.respondText("Invalid message: ${e.message}", status = HttpStatusCode.BadRequest) - onError?.invoke(e) + _onError.invoke(e) return } @@ -100,16 +99,16 @@ public class SSEServerTransport( public suspend fun handleMessage(message: String) { try { val parsedMessage = McpJson.decodeFromString(message) - onMessage?.invoke(parsedMessage) + _onMessage.invoke(parsedMessage) } catch (e: Exception) { - onError?.invoke(e) + _onError.invoke(e) throw e } } override suspend fun close() { session.close() - onClose?.invoke() + _onClose.invoke() } override suspend fun send(message: JSONRPCMessage) { diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index b91b6cfe..d09d160d 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -2,8 +2,8 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer -import io.modelcontextprotocol.kotlin.sdk.shared.Transport import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import kotlinx.atomicfu.AtomicBoolean import kotlinx.atomicfu.atomic @@ -27,11 +27,8 @@ import kotlin.coroutines.CoroutineContext public class StdioServerTransport( private val inputStream: Source, //BufferedInputStream = BufferedInputStream(System.`in`), outputStream: Sink //PrintStream = System.out -) : Transport { +) : AbstractTransport() { private val logger = KotlinLogging.logger {} - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend (JSONRPCMessage) -> Unit)? = null private val readBuffer = ReadBuffer() private val initialized: AtomicBoolean = atomic(false) @@ -65,7 +62,7 @@ public class StdioServerTransport( } } catch (e: Throwable) { logger.error(e) { "Error reading from stdin" } - onError?.invoke(e) + _onError.invoke(e) } finally { // Reached EOF or error, close connection close() @@ -80,7 +77,7 @@ public class StdioServerTransport( processReadBuffer() } } catch (e: Throwable) { - onError?.invoke(e) + _onError.invoke(e) } } } @@ -90,16 +87,16 @@ public class StdioServerTransport( val message = try { readBuffer.readMessage() } catch (e: Throwable) { - onError?.invoke(e) + _onError.invoke(e) null } if (message == null) break // Async invocation broke delivery order try { - onMessage?.invoke(message) + _onMessage.invoke(message) } catch (e: Throwable) { - onError?.invoke(e) + _onError.invoke(e) } } } @@ -112,7 +109,7 @@ public class StdioServerTransport( readChannel.close() readBuffer.clear() - onClose?.invoke() + _onClose.invoke() } override suspend fun send(message: JSONRPCMessage) { diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 9716bc96..a067d533 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -152,15 +152,15 @@ public abstract class Protocol( */ public open suspend fun connect(transport: Transport) { this.transport = transport - transport.onClose = { + transport.onClose { doClose() } - transport.onError = { + transport.onError { onError(it) } - transport.onMessage = { message -> + transport.onMessage { message -> when (message) { is JSONRPCResponse -> onResponse(message, null) is JSONRPCRequest -> onRequest(message) @@ -477,4 +477,4 @@ public abstract class Protocol( public fun removeNotificationHandler(method: Method) { notificationHandlers.remove(method.value) } -} \ No newline at end of file +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt index b1b15a4e..96f531e7 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt @@ -46,3 +46,38 @@ public interface Transport { */ public fun onMessage(block: suspend (JSONRPCMessage) -> Unit) } + +/** + * Implements [onClose], [onError] and [onMessage] functions of [Transport] providing + * corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation. + */ +@Suppress("PropertyName") +public abstract class AbstractTransport : Transport { + protected var _onClose: (() -> Unit) = {} + protected var _onError: ((Throwable) -> Unit) = {} + protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {} + + override fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + override fun onError(block: (Throwable) -> Unit) { + val old = _onError + _onError = { e -> + old(e) + block(e) + } + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + val old = _onMessage + _onMessage = { message -> + old(message) + block(message) + } + } +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index f1cac920..cbba7286 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -22,7 +22,7 @@ internal const val MCP_SUBPROTOCOL = "mcp" * Abstract class representing a WebSocket transport for the Model Context Protocol (MCP). * Handles communication over a WebSocket session. */ -public abstract class WebSocketMcpTransport : Transport { +public abstract class WebSocketMcpTransport : AbstractTransport() { private val scope by lazy { CoroutineScope(session.coroutineContext + SupervisorJob()) } @@ -33,10 +33,6 @@ public abstract class WebSocketMcpTransport : Transport { */ protected abstract val session: WebSocketSession - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend ((JSONRPCMessage) -> Unit))? = null - /** * Initializes the WebSocket session */ @@ -62,15 +58,15 @@ public abstract class WebSocketMcpTransport : Transport { if (message !is Frame.Text) { val e = IllegalArgumentException("Expected text frame, got ${message::class.simpleName}: $message") - onError?.invoke(e) + _onError.invoke(e) throw e } try { val message = McpJson.decodeFromString(message.readText()) - onMessage?.invoke(message) + _onMessage.invoke(message) } catch (e: Exception) { - onError?.invoke(e) + _onError.invoke(e) throw e } } @@ -79,9 +75,9 @@ public abstract class WebSocketMcpTransport : Transport { @OptIn(InternalCoroutinesApi::class) session.coroutineContext.job.invokeOnCompletion { if (it != null) { - onError?.invoke(it) + _onError.invoke(it) } else { - onClose?.invoke() + _onClose.invoke() } } } diff --git a/src/jvmTest/kotlin/InMemoryTransport.kt b/src/jvmTest/kotlin/InMemoryTransport.kt index c68140d3..987d166e 100644 --- a/src/jvmTest/kotlin/InMemoryTransport.kt +++ b/src/jvmTest/kotlin/InMemoryTransport.kt @@ -1,17 +1,14 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.Transport /** * In-memory transport for creating clients and servers that talk to each other within the same process. */ -class InMemoryTransport : Transport { +class InMemoryTransport : AbstractTransport() { private var otherTransport: InMemoryTransport? = null private val messageQueue: MutableList = mutableListOf() - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend ((JSONRPCMessage) -> Unit))? = null - /** * Creates a pair of linked in-memory transports that can communicate with each other. * One should be passed to a Client and one to a Server. @@ -30,7 +27,7 @@ class InMemoryTransport : Transport { // Process any messages that were queued before start was called while (messageQueue.isNotEmpty()) { messageQueue.removeFirstOrNull()?.let { message -> - onMessage?.invoke(message) // todo? + _onMessage.invoke(message) // todo? } } } @@ -39,16 +36,12 @@ class InMemoryTransport : Transport { val other = otherTransport otherTransport = null other?.close() - onClose?.invoke() + _onClose.invoke() } override suspend fun send(message: JSONRPCMessage) { val other = otherTransport ?: throw IllegalStateException("Not connected") - if (other.onMessage != null) { - other.onMessage?.invoke(message) // todo? - } else { - other.messageQueue.add(message) - } + other._onMessage.invoke(message) } } diff --git a/src/jvmTest/kotlin/client/BaseTransportTest.kt b/src/jvmTest/kotlin/client/BaseTransportTest.kt index ce01b703..1bd27f63 100644 --- a/src/jvmTest/kotlin/client/BaseTransportTest.kt +++ b/src/jvmTest/kotlin/client/BaseTransportTest.kt @@ -13,12 +13,12 @@ import kotlin.test.fail abstract class BaseTransportTest { protected suspend fun testClientOpenClose(client: Transport) { - client.onError = { error -> + client.onError { error -> fail("Unexpected error: $error") } var didClose = false - client.onClose = { didClose = true } + client.onClose { didClose = true } client.start() assertFalse(didClose, "Transport should not be closed immediately after start") @@ -28,7 +28,7 @@ abstract class BaseTransportTest { } protected suspend fun testClientRead(client: Transport) { - client.onError = { error -> + client.onError { error -> error.printStackTrace() fail("Unexpected error: $error") } @@ -40,7 +40,7 @@ abstract class BaseTransportTest { val readMessages = mutableListOf() val finished = CompletableDeferred() - client.onMessage = { message -> + client.onMessage { message -> readMessages.add(message) if (message == messages.last()) { finished.complete(Unit) diff --git a/src/jvmTest/kotlin/client/ClientTest.kt b/src/jvmTest/kotlin/client/ClientTest.kt index 38c52ca8..5ffc173a 100644 --- a/src/jvmTest/kotlin/client/ClientTest.kt +++ b/src/jvmTest/kotlin/client/ClientTest.kt @@ -37,7 +37,7 @@ import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions import org.junit.jupiter.api.Test import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions -import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import kotlin.coroutines.cancellation.CancellationException import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -48,7 +48,7 @@ class ClientTest { @Test fun `should initialize with matching protocol version`() = runTest { var initialied = false - val clientTransport = object : Transport { + val clientTransport = object : AbstractTransport() { override suspend fun start() {} override suspend fun send(message: JSONRPCMessage) { @@ -68,15 +68,11 @@ class ClientTest { result = result ) - onMessage?.invoke(response) + _onMessage.invoke(response) } override suspend fun close() { } - - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend (JSONRPCMessage) -> Unit)? = null } val client = Client( @@ -98,7 +94,7 @@ class ClientTest { @Test fun `should initialize with supported older protocol version`() = runTest { val OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1] - val clientTransport = object : Transport { + val clientTransport = object : AbstractTransport() { override suspend fun start() {} override suspend fun send(message: JSONRPCMessage) { @@ -118,15 +114,11 @@ class ClientTest { id = message.id, result = result ) - onMessage?.invoke(response) + _onMessage.invoke(response) } override suspend fun close() { } - - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend (JSONRPCMessage) -> Unit)? = null } val client = Client( @@ -151,7 +143,7 @@ class ClientTest { @Test fun `should reject unsupported protocol version`() = runTest { var closed = false - val clientTransport = object : Transport { + val clientTransport = object : AbstractTransport() { override suspend fun start() {} override suspend fun send(message: JSONRPCMessage) { @@ -172,16 +164,12 @@ class ClientTest { result = result ) - onMessage?.invoke(response) + _onMessage.invoke(response) } override suspend fun close() { closed = true } - - override var onClose: (() -> Unit)? = null - override var onError: ((Throwable) -> Unit)? = null - override var onMessage: (suspend (JSONRPCMessage) -> Unit)? = null } val client = Client( diff --git a/src/jvmTest/kotlin/client/InMemoryTransportTest.kt b/src/jvmTest/kotlin/client/InMemoryTransportTest.kt index 41409521..90acd1ec 100644 --- a/src/jvmTest/kotlin/client/InMemoryTransportTest.kt +++ b/src/jvmTest/kotlin/client/InMemoryTransportTest.kt @@ -44,7 +44,7 @@ class InMemoryTransportTest { val message = InitializedNotification() var receivedMessage: JSONRPCMessage? = null - serverTransport.onMessage = { msg -> + serverTransport.onMessage { msg -> receivedMessage = msg } @@ -61,7 +61,7 @@ class InMemoryTransportTest { .toJSON() var receivedMessage: JSONRPCMessage? = null - clientTransport.onMessage = { msg -> + clientTransport.onMessage { msg -> receivedMessage = msg } @@ -76,11 +76,11 @@ class InMemoryTransportTest { var clientClosed = false var serverClosed = false - clientTransport.onClose = { + clientTransport.onClose { clientClosed = true } - serverTransport.onClose = { + serverTransport.onClose { serverClosed = true } @@ -112,7 +112,7 @@ class InMemoryTransportTest { .toJSON() var receivedMessage: JSONRPCMessage? = null - serverTransport.onMessage = { msg -> + serverTransport.onMessage { msg -> receivedMessage = msg } @@ -121,4 +121,4 @@ class InMemoryTransportTest { assertEquals(message, receivedMessage) } } -} \ No newline at end of file +} diff --git a/src/jvmTest/kotlin/client/SseTransportTest.kt b/src/jvmTest/kotlin/client/SseTransportTest.kt index 5099cbd2..b056893c 100644 --- a/src/jvmTest/kotlin/client/SseTransportTest.kt +++ b/src/jvmTest/kotlin/client/SseTransportTest.kt @@ -6,10 +6,13 @@ import io.ktor.server.application.* import io.ktor.server.cio.* import io.ktor.server.engine.* import io.ktor.server.routing.* +import io.ktor.server.sse.sse +import io.ktor.util.collections.ConcurrentMap import kotlinx.coroutines.test.runTest -import mcpSse -import mcpSseTransport import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport +import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport +import io.modelcontextprotocol.kotlin.sdk.server.mcpPostEndpoint +import io.modelcontextprotocol.kotlin.sdk.server.mcpSseTransport import org.junit.jupiter.api.Test private const val PORT = 8080 @@ -19,8 +22,15 @@ class SseTransportTest : BaseTransportTest() { fun `should start then close cleanly`() = runTest { val server = embeddedServer(CIO, port = PORT) { install(io.ktor.server.sse.SSE) + val transports = ConcurrentMap() routing { - mcpSse() + sse { + mcpSseTransport("", transports).start() + } + + post { + mcpPostEndpoint(transports) + } } }.start(wait = false) @@ -42,12 +52,21 @@ class SseTransportTest : BaseTransportTest() { fun `should read messages`() = runTest { val server = embeddedServer(CIO, port = PORT) { install(io.ktor.server.sse.SSE) + val transports = ConcurrentMap() routing { - mcpSseTransport { - onMessage = { - send(it) + sse { + mcpSseTransport("", transports).apply { + onMessage { + send(it) + } + + start() } } + + post { + mcpPostEndpoint(transports) + } } }.start(wait = false) diff --git a/src/jvmTest/kotlin/client/WebSocketTransportTest.kt b/src/jvmTest/kotlin/client/WebSocketTransportTest.kt index c680c3f0..93f39591 100644 --- a/src/jvmTest/kotlin/client/WebSocketTransportTest.kt +++ b/src/jvmTest/kotlin/client/WebSocketTransportTest.kt @@ -33,7 +33,7 @@ class WebSocketTransportTest : BaseTransportTest() { install(WebSockets) routing { mcpWebSocketTransport { - onMessage = { + onMessage { send(it) } diff --git a/src/jvmTest/kotlin/server/StdioServerTransportTest.kt b/src/jvmTest/kotlin/server/StdioServerTransportTest.kt index 724965b8..521e0509 100644 --- a/src/jvmTest/kotlin/server/StdioServerTransportTest.kt +++ b/src/jvmTest/kotlin/server/StdioServerTransportTest.kt @@ -54,12 +54,12 @@ class StdioServerTransportTest { fun `should start then close cleanly`() { runBlocking { val server = StdioServerTransport(bufferedInput, printOutput) - server.onError = { error -> + server.onError { error -> throw error } var didClose = false - server.onClose = { + server.onClose { didClose = true } @@ -75,14 +75,14 @@ class StdioServerTransportTest { fun `should not read until started`() { runBlocking { val server = StdioServerTransport(bufferedInput, printOutput) - server.onError = { error -> + server.onError { error -> throw error } var didRead = false val readMessage = CompletableDeferred() - server.onMessage = { message -> + server.onMessage { message -> didRead = true readMessage.complete(message) } @@ -106,7 +106,7 @@ class StdioServerTransportTest { fun `should read multiple messages`() { runBlocking { val server = StdioServerTransport(bufferedInput, printOutput) - server.onError = { error -> + server.onError { error -> throw error } @@ -118,7 +118,7 @@ class StdioServerTransportTest { val readMessages = mutableListOf() val finished = CompletableDeferred() - server.onMessage = { message -> + server.onMessage { message -> readMessages.add(message) if (message == messages[1]) { finished.complete(Unit) @@ -141,4 +141,4 @@ class StdioServerTransportTest { fun PipedOutputStream.write(s: String) { write(s.toByteArray()) -} \ No newline at end of file +} From 70a11d479049fb0900a1fe26059c7a108a507a65 Mon Sep 17 00:00:00 2001 From: Alexander Sysoev Date: Mon, 20 Jan 2025 18:56:00 +0100 Subject: [PATCH 2/3] Added `mcp` Ktor server route alongside Application extension --- README.md | 34 +++- build.gradle.kts | 1 + gradle/libs.versions.toml | 2 + .../kotlin/sdk/server/KtorServer.kt | 110 ++++++++++++ .../kotlin/sdk/server/McpKtorServerPlugin.kt | 54 ------ src/jvmTest/kotlin/sse.ktor.kt | 162 ------------------ 6 files changed, 145 insertions(+), 218 deletions(-) create mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt delete mode 100644 src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/McpKtorServerPlugin.kt delete mode 100644 src/jvmTest/kotlin/sse.ktor.kt diff --git a/README.md b/README.md index d4edfefd..55a17ad7 100644 --- a/README.md +++ b/README.md @@ -114,12 +114,13 @@ server.connect(transport) ### Using SSE Transport +Directly in Ktor's `Application`: ```kotlin import io.ktor.server.application.* -import io.modelcontextprotocol.kotlin.sdk.server.MCP +import io.modelcontextprotocol.kotlin.sdk.server.mcp fun Application.module() { - MCP { + mcp { Server( serverInfo = Implementation( name = "example-sse-server", @@ -136,6 +137,35 @@ fun Application.module() { } ``` +Inside a custom Ktor's `Route`: +```kotlin +import io.ktor.server.application.* +import io.ktor.server.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.server.mcp + +fun Application.module() { + install(SSE) + + routing { + route("myRoute") { + mcp { + Server( + serverInfo = Implementation( + name = "example-sse-server", + version = "1.0.0" + ), + options = ServerOptions( + capabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts(listChanged = null), + resources = ServerCapabilities.Resources(subscribe = null, listChanged = null) + ) + ) + ) + } + } + } +} +``` ## Contributing Please see the [contribution guide](CONTRIBUTING.md) and the [Code of conduct](CODE_OF_CONDUCT.md) before contributing. diff --git a/build.gradle.kts b/build.gradle.kts index 961e50b0..96fd3041 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -225,6 +225,7 @@ kotlin { jvmTest { dependencies { implementation(libs.mockk) + implementation(libs.slf4j.simple) } } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 9456d331..57d552cd 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,6 +12,7 @@ mockk = "1.13.13" logging = "7.0.0" jreleaser = "1.15.0" binaryCompatibilityValidatorPlugin = "0.17.0" +slf4j = "2.0.16" [libraries] # Kotlinx libraries @@ -30,6 +31,7 @@ kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-cor kotlinx-coroutines-debug = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-debug", version.ref = "coroutines" } ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" } mockk = { group = "io.mockk", name = "mockk", version.ref = "mockk" } +slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt new file mode 100644 index 00000000..572cd355 --- /dev/null +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -0,0 +1,110 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.* +import io.ktor.server.application.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.sse.* +import io.ktor.util.collections.* +import io.ktor.utils.io.KtorDsl + +private val logger = KotlinLogging.logger {} + +@KtorDsl +public fun Routing.mcp(path: String, block: () -> Server) { + route(path) { + mcp(block) + } +} + +/** + * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). + */ +@KtorDsl +public fun Routing.mcp(block: () -> Server) { + val transports = ConcurrentMap() + + sse { + mcpSseEndpoint("", transports, block) + } + + post { + mcpPostEndpoint(transports) + } +} + +@Suppress("FunctionName") +@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.WARNING) +public fun Application.MCP(block: () -> Server) { + mcp(block) +} + +@KtorDsl +public fun Application.mcp(block: () -> Server) { + val transports = ConcurrentMap() + + install(SSE) + + routing { + sse("/sse") { + mcpSseEndpoint("/message", transports, block) + } + + post("/message") { + mcpPostEndpoint(transports) + } + } +} + +private suspend fun ServerSSESession.mcpSseEndpoint( + postEndpoint: String, + transports: ConcurrentMap, + block: () -> Server, +) { + val transport = mcpSseTransport(postEndpoint, transports) + + val server = block() + + server.onClose { + logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } + transports.remove(transport.sessionId) + } + + server.connect(transport) + logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } +} + +internal fun ServerSSESession.mcpSseTransport( + postEndpoint: String, + transports: ConcurrentMap, +): SseServerTransport { + val transport = SseServerTransport(postEndpoint, this) + transports[transport.sessionId] = transport + + logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } + + return transport +} + +internal suspend fun RoutingContext.mcpPostEndpoint( + transports: ConcurrentMap, +) { + val sessionId: String = call.request.queryParameters["sessionId"] + ?: run { + call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") + return + } + + logger.debug { "Received message for sessionId: $sessionId" } + + val transport = transports[sessionId] + if (transport == null) { + logger.warn { "Session not found for sessionId: $sessionId" } + call.respond(HttpStatusCode.NotFound, "Session not found") + return + } + + transport.handlePostMessage(call) + logger.trace { "Message handled for sessionId: $sessionId" } +} diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/McpKtorServerPlugin.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/McpKtorServerPlugin.kt deleted file mode 100644 index e465f2fb..00000000 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/McpKtorServerPlugin.kt +++ /dev/null @@ -1,54 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.server - -import io.github.oshai.kotlinlogging.KotlinLogging -import io.ktor.http.* -import io.ktor.server.application.* -import io.ktor.server.response.* -import io.ktor.server.routing.* -import io.ktor.server.sse.* -import io.ktor.util.collections.* - -private val logger = KotlinLogging.logger {} - -/** - * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). - */ -public fun Application.MCP(block: () -> Server) { - val servers = ConcurrentMap() - - install(SSE) - routing { - sse("/sse") { - val transport = SSEServerTransport("/message", this) - logger.info { "New SSE connection established with sessionId: ${transport.sessionId}" } - - val server = block() - - servers[transport.sessionId] = server - logger.debug { "Server instance created and stored for sessionId: ${transport.sessionId}" } - - server.onCloseCallback = { - logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } - servers.remove(transport.sessionId) - } - - server.connect(transport) - logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } - } - - post("/message") { - val sessionId: String = call.request.queryParameters["sessionId"]!! - logger.debug { "Received message for sessionId: $sessionId" } - - val transport = servers[sessionId]?.transport as? SSEServerTransport - if (transport == null) { - logger.warn { "Session not found for sessionId: $sessionId" } - call.respond(HttpStatusCode.NotFound, "Session not found") - return@post - } - - transport.handlePostMessage(call) - logger.trace { "Message handled for sessionId: $sessionId" } - } - } -} diff --git a/src/jvmTest/kotlin/sse.ktor.kt b/src/jvmTest/kotlin/sse.ktor.kt deleted file mode 100644 index 07e2a15b..00000000 --- a/src/jvmTest/kotlin/sse.ktor.kt +++ /dev/null @@ -1,162 +0,0 @@ -import io.ktor.http.HttpStatusCode -import io.ktor.server.application.ApplicationCall -import io.ktor.server.response.respondText -import io.ktor.server.routing.Route -import io.ktor.server.routing.RoutingContext -import io.ktor.server.routing.application -import io.ktor.server.routing.post -import io.ktor.server.sse.ServerSSESession -import io.ktor.server.sse.sse -import io.ktor.util.AttributeKey -import io.ktor.util.Attributes -import io.modelcontextprotocol.kotlin.sdk.server.SSEServerTransport -import kotlinx.coroutines.CompletableDeferred -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.server.SESSION_ID_PARAM -import io.modelcontextprotocol.kotlin.sdk.server.Server -import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions -import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME - -typealias IncomingHandler = suspend RoutingContext.(forward: suspend () -> Unit) -> Unit - -fun Route.mcpSse( - options: ServerOptions? = null, - incomingPath: String = "", - incomingHandler: (IncomingHandler)? = null, - handler: suspend Server.() -> Unit = {}, -) { - sse { - createMcpServer(this, incomingPath, options, handler) - } - - setupPostRoute(incomingPath, incomingHandler) -} - -fun Route.mcpSse( - path: String, - incomingPath: String = path, - options: ServerOptions? = null, - incomingHandler: (IncomingHandler)? = null, - handler: suspend Server.() -> Unit = {}, -) { - sse(path) { - createMcpServer(this, incomingPath, options, handler) - } - - setupPostRoute(incomingPath, incomingHandler) -} - -fun Route.mcpSseTransport( - incomingPath: String = "", - incomingHandler: (IncomingHandler)? = null, - handler: suspend SSEServerTransport.() -> Unit = {}, -) { - sse { - val transport = createMcpTransport(this, incomingPath) - handler(transport) - transport.start() - transport.close() - } - - setupPostRoute(incomingPath, incomingHandler) -} - -fun Route.mcpSseTransport( - path: String, - incomingPath: String = path, - incomingHandler: (IncomingHandler)? = null, - handler: suspend SSEServerTransport.() -> Unit = {}, -) { - sse(path) { - val transport = createMcpTransport(this, incomingPath) - transport.start() - handler(transport) - transport.close() - } - - setupPostRoute(incomingPath, incomingHandler) -} - -internal val McpServersKey = AttributeKey("mcp-servers") - -private fun String.asAttributeKey() = AttributeKey(this) - -private suspend fun Route.forwardMcpMessage(call: ApplicationCall) { - val sessionId = call.request.queryParameters[SESSION_ID_PARAM] - ?.asAttributeKey() - ?: run { - call.sessionNotFound() - return - } - - application.attributes.getOrNull(McpServersKey) - ?.get(sessionId) - ?.handlePostMessage(call) - ?: call.sessionNotFound() -} - -private suspend fun ApplicationCall.sessionNotFound() { - respondText("Session not found", status = HttpStatusCode.NotFound) -} - -private fun Route.setupPostRoute(incomingPath: String, incomingHandler: IncomingHandler?) { - post(incomingPath) { - if (incomingHandler != null) { - incomingHandler { - forwardMcpMessage(call) - } - } else { - forwardMcpMessage(call) - } - } -} - -private suspend fun Route.createMcpServer( - session: ServerSSESession, - incomingPath: String, - options: ServerOptions?, - handler: suspend Server.() -> Unit = {}, -) { - val transport = createMcpTransport(session, incomingPath) - - val closed = CompletableDeferred() - - val server = Server( - serverInfo = Implementation( - name = IMPLEMENTATION_NAME, - version = LIB_VERSION, - ), - options = options ?: ServerOptions( - capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = null), - resources = ServerCapabilities.Resources(subscribe = null, listChanged = null), - tools = ServerCapabilities.Tools(listChanged = null), - ) - ), - onCloseCallback = { - closed.complete(Unit) - }, - ) - - server.connect(transport) - handler(server) - server.close() -} - -private fun Route.createMcpTransport( - session: ServerSSESession, - incomingPath: String, -): SSEServerTransport { - val transport = SSEServerTransport( - endpoint = incomingPath, - session = session, - ) - - application.attributes - .computeIfAbsent(McpServersKey) { Attributes(concurrent = true) } - .put(transport.sessionId.asAttributeKey(), transport) - - return transport -} From cc501db1bfb8f88c85bd5b5cb157b75029f99931 Mon Sep 17 00:00:00 2001 From: Alexander Sysoev Date: Mon, 20 Jan 2025 19:12:48 +0100 Subject: [PATCH 3/3] Fix AbstractTransport to not skip messages --- .../kotlin/sdk/shared/Transport.kt | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt index 96f531e7..514f3f26 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import kotlinx.coroutines.CompletableDeferred /** * Describes the minimal contract for a MCP transport that a client or server can communicate over. @@ -54,8 +55,17 @@ public interface Transport { @Suppress("PropertyName") public abstract class AbstractTransport : Transport { protected var _onClose: (() -> Unit) = {} + private set protected var _onError: ((Throwable) -> Unit) = {} - protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {} + private set + + // to not skip messages + private val _onMessageInitialized = CompletableDeferred() + protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { + _onMessageInitialized.await() + _onMessage.invoke(it) + } + private set override fun onClose(block: () -> Unit) { val old = _onClose @@ -74,10 +84,16 @@ public abstract class AbstractTransport : Transport { } override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - val old = _onMessage + val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) { + true -> _onMessage + false -> { _ -> } + } + _onMessage = { message -> old(message) block(message) } + + _onMessageInitialized.complete(Unit) } }