diff --git a/build.gradle.kts b/build.gradle.kts index ec9e8eee..18f8f946 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -11,7 +11,6 @@ plugins { alias(libs.plugins.kotlin.serialization) alias(libs.plugins.dokka) alias(libs.plugins.jreleaser) - alias(libs.plugins.atomicfu) `maven-publish` alias(libs.plugins.kotlinx.binary.compatibility.validator) } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 2265caea..bbb40987 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,8 +1,7 @@ [versions] # plugins version -kotlin = "2.0.21" +kotlin = "2.1.20" dokka = "2.0.0" -atomicfu = "0.26.1" # libraries version serialization = "1.7.3" @@ -40,5 +39,4 @@ kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" } jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"} -atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" } kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index 91a9f85b..19405c57 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -8,10 +8,10 @@ import io.ktor.http.* import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import kotlinx.atomicfu.AtomicBoolean -import kotlinx.atomicfu.atomic import kotlinx.coroutines.* import kotlinx.serialization.encodeToString +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.properties.Delegates import kotlin.time.Duration @@ -22,6 +22,7 @@ public typealias SSEClientTransport = SseClientTransport * Client transport for SSE: this will connect to a server using Server-Sent Events for receiving * messages and make separate POST requests for sending messages. */ +@OptIn(ExperimentalAtomicApi::class) public class SseClientTransport( private val client: HttpClient, private val urlString: String?, @@ -32,7 +33,7 @@ public class SseClientTransport( CoroutineScope(session.coroutineContext + SupervisorJob()) } - private val initialized: AtomicBoolean = atomic(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) private var session: ClientSSESession by Delegates.notNull() private val endpoint = CompletableDeferred() @@ -127,7 +128,7 @@ public class SseClientTransport( } override suspend fun close() { - if (!initialized.value) { + if (!initialized.load()) { error("SSEClientTransport is not initialized!") } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index f579f92a..77acc058 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -5,17 +5,12 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage -import kotlinx.atomicfu.AtomicBoolean -import kotlinx.atomicfu.atomic import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.consumeEach -import kotlinx.io.Buffer -import kotlinx.io.Sink -import kotlinx.io.Source -import kotlinx.io.buffered -import kotlinx.io.readByteArray -import kotlinx.io.writeString +import kotlinx.io.* +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.coroutines.CoroutineContext /** @@ -27,6 +22,7 @@ import kotlin.coroutines.CoroutineContext * @param input The input stream where messages are received. * @param output The output stream where messages are sent. */ +@OptIn(ExperimentalAtomicApi::class) public class StdioClientTransport( private val input: Source, private val output: Sink @@ -37,7 +33,7 @@ public class StdioClientTransport( CoroutineScope(ioCoroutineContext + SupervisorJob()) } private var job: Job? = null - private val initialized: AtomicBoolean = atomic(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) private val sendChannel = Channel(Channel.UNLIMITED) private val readBuffer = ReadBuffer() @@ -96,7 +92,7 @@ public class StdioClientTransport( } override suspend fun send(message: JSONRPCMessage) { - if (!initialized.value) { + if (!initialized.load()) { error("Transport not started") } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index 32039cde..67ac4344 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -8,10 +8,10 @@ import io.ktor.server.sse.* import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import kotlinx.atomicfu.AtomicBoolean -import kotlinx.atomicfu.atomic import kotlinx.coroutines.job import kotlinx.serialization.encodeToString +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -25,11 +25,12 @@ public typealias SSEServerTransport = SseServerTransport * * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. */ +@OptIn(ExperimentalAtomicApi::class) public class SseServerTransport( private val endpoint: String, private val session: ServerSSESession, ) : AbstractTransport() { - private val initialized: AtomicBoolean = atomic(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) @OptIn(ExperimentalUuidApi::class) public val sessionId: String = Uuid.random().toString() @@ -63,7 +64,7 @@ public class SseServerTransport( * This should be called when a POST request is made to send a message to the server. */ public suspend fun handlePostMessage(call: ApplicationCall) { - if (!initialized.value) { + if (!initialized.load()) { val message = "SSE connection not established" call.respondText(message, status = HttpStatusCode.InternalServerError) _onError.invoke(IllegalStateException(message)) @@ -112,7 +113,7 @@ public class SseServerTransport( } override suspend fun send(message: JSONRPCMessage) { - if (!initialized.value) { + if (!initialized.load()) { throw error("Not connected") } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index d09d160d..a7e371a2 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -5,18 +5,11 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage -import kotlinx.atomicfu.AtomicBoolean -import kotlinx.atomicfu.atomic -import kotlinx.atomicfu.locks.ReentrantLock -import kotlinx.atomicfu.locks.withLock import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel -import kotlinx.io.Buffer -import kotlinx.io.Sink -import kotlinx.io.Source -import kotlinx.io.buffered -import kotlinx.io.readByteArray -import kotlinx.io.writeString +import kotlinx.io.* +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.coroutines.CoroutineContext /** @@ -24,24 +17,26 @@ import kotlin.coroutines.CoroutineContext * * Reads from System.in and writes to System.out. */ +@OptIn(ExperimentalAtomicApi::class) public class StdioServerTransport( - private val inputStream: Source, //BufferedInputStream = BufferedInputStream(System.`in`), - outputStream: Sink //PrintStream = System.out + private val inputStream: Source, + outputStream: Sink ) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val readBuffer = ReadBuffer() - private val initialized: AtomicBoolean = atomic(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) private var readingJob: Job? = null + private var sendingJob: Job? = null private val coroutineContext: CoroutineContext = Dispatchers.IO + SupervisorJob() private val scope = CoroutineScope(coroutineContext) private val readChannel = Channel(Channel.UNLIMITED) + private val writeChannel = Channel(Channel.UNLIMITED) private val outputWriter = outputStream.buffered() - private val lock = ReentrantLock() override suspend fun start() { - if (!initialized.compareAndSet(false, true)) { + if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error("StdioServerTransport already started!") } @@ -80,6 +75,20 @@ public class StdioServerTransport( _onError.invoke(e) } } + + // Launch a coroutine to handle message sending + sendingJob = scope.launch { + try { + for (message in writeChannel) { + val json = serializeMessage(message) + outputWriter.writeString(json) + outputWriter.flush() + } + } catch (e: Throwable) { + logger.error(e) { "Error writing to stdout" } + _onError.invoke(e) + } + } } private suspend fun processReadBuffer() { @@ -102,22 +111,20 @@ public class StdioServerTransport( } override suspend fun close() { - if (!initialized.compareAndSet(true, false)) return + if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return // Cancel reading job and close channel readingJob?.cancel() // ToDO("was cancel and join") + sendingJob?.cancel() + readChannel.close() + writeChannel.close() readBuffer.clear() _onClose.invoke() } override suspend fun send(message: JSONRPCMessage) { - val json = serializeMessage(message) - lock.withLock { - // You may need to add Content-Length headers before the message if using the LSP framing protocol - outputWriter.writeString(json) - outputWriter.flush() - } + writeChannel.send(message) } } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index cbba7286..baf373a3 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -1,20 +1,12 @@ package io.modelcontextprotocol.kotlin.sdk.shared -import io.ktor.websocket.Frame -import io.ktor.websocket.WebSocketSession -import io.ktor.websocket.close -import io.ktor.websocket.readText +import io.ktor.websocket.* import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import kotlinx.atomicfu.AtomicBoolean -import kotlinx.atomicfu.atomic -import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.InternalCoroutinesApi -import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.* import kotlinx.coroutines.channels.ClosedReceiveChannelException -import kotlinx.coroutines.job -import kotlinx.coroutines.launch import kotlinx.serialization.encodeToString +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi internal const val MCP_SUBPROTOCOL = "mcp" @@ -22,12 +14,13 @@ internal const val MCP_SUBPROTOCOL = "mcp" * Abstract class representing a WebSocket transport for the Model Context Protocol (MCP). * Handles communication over a WebSocket session. */ +@OptIn(ExperimentalAtomicApi::class) public abstract class WebSocketMcpTransport : AbstractTransport() { private val scope by lazy { CoroutineScope(session.coroutineContext + SupervisorJob()) } - private val initialized: AtomicBoolean = atomic(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) /** * The WebSocket session used for communication. */ @@ -83,7 +76,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } override suspend fun send(message: JSONRPCMessage) { - if (!initialized.value) { + if (!initialized.load()) { error("Not connected") } @@ -91,7 +84,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } override suspend fun close() { - if (!initialized.value) { + if (!initialized.load()) { error("Not connected") } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index d2fa2cb2..c5e5456e 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -3,14 +3,14 @@ package io.modelcontextprotocol.kotlin.sdk import io.modelcontextprotocol.kotlin.sdk.shared.McpJson -import kotlinx.atomicfu.AtomicLong -import kotlinx.atomicfu.atomic import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.encodeToJsonElement -import kotlin.jvm.JvmInline +import kotlin.concurrent.atomics.AtomicLong +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.concurrent.atomics.incrementAndFetch public const val LATEST_PROTOCOL_VERSION: String = "2024-11-05" @@ -21,7 +21,8 @@ public val SUPPORTED_PROTOCOL_VERSIONS: Array = arrayOf( public const val JSONRPC_VERSION: String = "2.0" -private val REQUEST_MESSAGE_ID: AtomicLong = atomic(0L) +@OptIn(ExperimentalAtomicApi::class) +private val REQUEST_MESSAGE_ID: AtomicLong = AtomicLong(0L) /** * A progress token, used to associate progress notifications with the original request. @@ -132,7 +133,7 @@ internal fun Request.toJSON(): JSONRPCRequest { */ internal fun JSONRPCRequest.fromJSON(): Request? { val serializer = selectRequestDeserializer(method) - val params = params ?: return null + val params = params return McpJson.decodeFromJsonElement(serializer, params) } @@ -211,9 +212,10 @@ public sealed interface JSONRPCMessage /** * A request that expects a response. */ +@OptIn(ExperimentalAtomicApi::class) @Serializable public data class JSONRPCRequest( - val id: RequestId = RequestId.NumberId(REQUEST_MESSAGE_ID.incrementAndGet()), + val id: RequestId = RequestId.NumberId(REQUEST_MESSAGE_ID.incrementAndFetch()), val method: String, val params: JsonElement = EmptyJsonObject, val jsonrpc: String = JSONRPC_VERSION,