Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ server.connect(transport)

### Using SSE Transport

Directly in Ktor's `Application`:
```kotlin
import io.ktor.server.application.*
import io.modelcontextprotocol.kotlin.sdk.server.MCP
import io.modelcontextprotocol.kotlin.sdk.server.mcp

fun Application.module() {
MCP {
mcp {
Server(
serverInfo = Implementation(
name = "example-sse-server",
Expand All @@ -136,6 +137,35 @@ fun Application.module() {
}
```

Inside a custom Ktor's `Route`:
```kotlin
import io.ktor.server.application.*
import io.ktor.server.sse.SSE
import io.modelcontextprotocol.kotlin.sdk.server.mcp

fun Application.module() {
install(SSE)

routing {
route("myRoute") {
mcp {
Server(
serverInfo = Implementation(
name = "example-sse-server",
version = "1.0.0"
),
options = ServerOptions(
capabilities = ServerCapabilities(
prompts = ServerCapabilities.Prompts(listChanged = null),
resources = ServerCapabilities.Resources(subscribe = null, listChanged = null)
)
)
)
}
}
}
}
```
## Contributing

Please see the [contribution guide](CONTRIBUTING.md) and the [Code of conduct](CODE_OF_CONDUCT.md) before contributing.
Expand Down
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ kotlin {
jvmTest {
dependencies {
implementation(libs.mockk)
implementation(libs.slf4j.simple)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mockk = "1.13.13"
logging = "7.0.0"
jreleaser = "1.15.0"
binaryCompatibilityValidatorPlugin = "0.17.0"
slf4j = "2.0.16"

[libraries]
# Kotlinx libraries
Expand All @@ -30,6 +31,7 @@ kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-cor
kotlinx-coroutines-debug = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-debug", version.ref = "coroutines" }
ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" }
mockk = { group = "io.mockk", name = "mockk", version.ref = "mockk" }
slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" }



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public fun HttpClient.mcpSseTransport(
urlString: String? = null,
reconnectionTime: Duration? = null,
requestBuilder: HttpRequestBuilder.() -> Unit = {},
): SSEClientTransport = SSEClientTransport(this, urlString, reconnectionTime, requestBuilder)
): SseClientTransport = SseClientTransport(this, urlString, reconnectionTime, requestBuilder)

/**
* Creates and connects an MCP client over SSE using the provided HttpClient.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,28 @@ import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.*
import kotlinx.serialization.encodeToString
import kotlin.properties.Delegates
import kotlin.time.Duration

@Deprecated("Use SseClientTransport instead", ReplaceWith("SseClientTransport"), DeprecationLevel.WARNING)
public typealias SSEClientTransport = SseClientTransport

/**
* Client transport for SSE: this will connect to a server using Server-Sent Events for receiving
* messages and make separate POST requests for sending messages.
*/
public class SSEClientTransport(
public class SseClientTransport(
private val client: HttpClient,
private val urlString: String?,
private val reconnectionTime: Duration? = null,
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
) : Transport {
) : AbstractTransport() {
private val scope by lazy {
CoroutineScope(session.coroutineContext + SupervisorJob())
}
Expand All @@ -33,10 +36,6 @@ public class SSEClientTransport(
private var session: ClientSSESession by Delegates.notNull()
private val endpoint = CompletableDeferred<String>()

private var _onClose: (() -> Unit) = {}
private var _onError: ((Throwable) -> Unit) = {}
private var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {}

private var job: Job? = null

private val baseUrl by lazy {
Expand Down Expand Up @@ -136,28 +135,4 @@ public class SSEClientTransport(
_onClose()
job?.cancelAndJoin()
}

override fun onClose(block: () -> Unit) {
val old = _onClose
_onClose = {
old()
block()
}
}

override fun onError(block: (Throwable) -> Unit) {
val old = _onError
_onError = { e ->
old(e)
block(e)
}
}

override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
val old = _onMessage
_onMessage = { message ->
old(message)
block(message)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
Expand All @@ -30,7 +30,7 @@ import kotlin.coroutines.CoroutineContext
public class StdioClientTransport(
private val input: Source,
private val output: Sink
) : Transport {
) : AbstractTransport() {
private val logger = KotlinLogging.logger {}
private val ioCoroutineContext: CoroutineContext = Dispatchers.IO
private val scope by lazy {
Expand All @@ -41,10 +41,6 @@ public class StdioClientTransport(
private val sendChannel = Channel<JSONRPCMessage>(Channel.UNLIMITED)
private val readBuffer = ReadBuffer()

override var onClose: (() -> Unit)? = null
override var onError: ((Throwable) -> Unit)? = null
override var onMessage: (suspend ((JSONRPCMessage) -> Unit))? = null

override suspend fun start() {
if (!initialized.compareAndSet(false, true)) {
error("StdioClientTransport already started!")
Expand All @@ -70,7 +66,7 @@ public class StdioClientTransport(
}
}
} catch (e: Exception) {
onError?.invoke(e)
_onError.invoke(e)
logger.error(e) { "Error reading from input stream" }
}
}
Expand All @@ -85,7 +81,7 @@ public class StdioClientTransport(
}
} catch (e: Throwable) {
if (isActive) {
onError?.invoke(e)
_onError.invoke(e)
logger.error(e) { "Error writing to output stream" }
}
} finally {
Expand All @@ -95,7 +91,7 @@ public class StdioClientTransport(

readJob.join()
writeJob.cancelAndJoin()
onClose?.invoke()
_onClose.invoke()
}
}

Expand All @@ -116,16 +112,16 @@ public class StdioClientTransport(
output.close()
readBuffer.clear()
sendChannel.close()
onClose?.invoke()
_onClose.invoke()
}

private suspend fun processReadBuffer() {
while (true) {
val msg = readBuffer.readMessage() ?: break
try {
onMessage?.invoke(msg)
_onMessage.invoke(msg)
} catch (e: Throwable) {
onError?.invoke(e)
_onError.invoke(e)
logger.error(e) { "Error processing message." }
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package io.modelcontextprotocol.kotlin.sdk.server

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sse.*
import io.ktor.util.collections.*
import io.ktor.utils.io.KtorDsl

private val logger = KotlinLogging.logger {}

@KtorDsl
public fun Routing.mcp(path: String, block: () -> Server) {
route(path) {
mcp(block)
}
}

/**
* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
*/
@KtorDsl
public fun Routing.mcp(block: () -> Server) {
val transports = ConcurrentMap<String, SseServerTransport>()

sse {
mcpSseEndpoint("", transports, block)
}

post {
mcpPostEndpoint(transports)
}
}

@Suppress("FunctionName")
@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.WARNING)
public fun Application.MCP(block: () -> Server) {
mcp(block)
}

@KtorDsl
public fun Application.mcp(block: () -> Server) {
val transports = ConcurrentMap<String, SseServerTransport>()

install(SSE)

routing {
sse("/sse") {
mcpSseEndpoint("/message", transports, block)
}

post("/message") {
mcpPostEndpoint(transports)
}
}
}

private suspend fun ServerSSESession.mcpSseEndpoint(
postEndpoint: String,
transports: ConcurrentMap<String, SseServerTransport>,
block: () -> Server,
) {
val transport = mcpSseTransport(postEndpoint, transports)

val server = block()

server.onClose {
logger.info { "Server connection closed for sessionId: ${transport.sessionId}" }
transports.remove(transport.sessionId)
}

server.connect(transport)
logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" }
}

internal fun ServerSSESession.mcpSseTransport(
postEndpoint: String,
transports: ConcurrentMap<String, SseServerTransport>,
): SseServerTransport {
val transport = SseServerTransport(postEndpoint, this)
transports[transport.sessionId] = transport

logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }

return transport
}

internal suspend fun RoutingContext.mcpPostEndpoint(
transports: ConcurrentMap<String, SseServerTransport>,
) {
val sessionId: String = call.request.queryParameters["sessionId"]
?: run {
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
return
}

logger.debug { "Received message for sessionId: $sessionId" }

val transport = transports[sessionId]
if (transport == null) {
logger.warn { "Session not found for sessionId: $sessionId" }
call.respond(HttpStatusCode.NotFound, "Session not found")
return
}

transport.handlePostMessage(call)
logger.trace { "Message handled for sessionId: $sessionId" }
}

This file was deleted.

Loading