@@ -8,7 +8,6 @@ import io.ktor.server.response.respond
88import io.ktor.server.routing.Routing
99import io.ktor.server.routing.RoutingContext
1010import io.ktor.server.routing.post
11- import io.ktor.server.routing.route
1211import io.ktor.server.routing.routing
1312import io.ktor.server.sse.SSE
1413import io.ktor.server.sse.ServerSSESession
@@ -18,39 +17,40 @@ import kotlinx.atomicfu.AtomicRef
1817import kotlinx.atomicfu.atomic
1918import kotlinx.atomicfu.update
2019import kotlinx.collections.immutable.PersistentMap
21- import kotlinx.collections.immutable.persistentMapOf
20+ import kotlinx.collections.immutable.toPersistentMap
2221
2322private val logger = KotlinLogging .logger {}
2423
25- @KtorDsl
26- public fun Routing.mcp (path : String , block : () -> Server ) {
27- route(path) {
28- mcp(block)
24+ internal class SseTransportManager (transports : Map <String , SseServerTransport > = emptyMap()) {
25+ private val transports: AtomicRef <PersistentMap <String , SseServerTransport >> = atomic(transports.toPersistentMap())
26+
27+ fun getTransport (sessionId : String ): SseServerTransport ? = transports.value[sessionId]
28+
29+ fun addTransport (transport : SseServerTransport ) {
30+ transports.update { it.put(transport.sessionId, transport) }
31+ }
32+
33+ fun removeTransport (sessionId : String ) {
34+ transports.update { it.remove(sessionId) }
2935 }
3036}
3137
32- /* *
33- * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
34- */
38+ /*
39+ * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
40+ */
3541@KtorDsl
3642public fun Routing.mcp (block : () -> Server ) {
37- val transports = atomic(persistentMapOf< String , SseServerTransport >() )
43+ val sseTransportManager = SseTransportManager ( )
3844
3945 sse {
40- mcpSseEndpoint(" " , transports , block)
46+ mcpSseEndpoint(" " , sseTransportManager , block)
4147 }
4248
4349 post {
44- mcpPostEndpoint(transports )
50+ mcpPostEndpoint(sseTransportManager )
4551 }
4652}
4753
48- @Suppress(" FunctionName" )
49- @Deprecated(" Use mcp() instead" , ReplaceWith (" mcp(block)" ), DeprecationLevel .WARNING )
50- public fun Application.MCP (block : () -> Server ) {
51- mcp(block)
52- }
53-
5454@KtorDsl
5555public fun Application.mcp (block : () -> Server ) {
5656 install(SSE )
@@ -62,16 +62,16 @@ public fun Application.mcp(block: () -> Server) {
6262
6363internal suspend fun ServerSSESession.mcpSseEndpoint (
6464 postEndpoint : String ,
65- transports : AtomicRef < PersistentMap < String , SseServerTransport >> ,
65+ sseTransportManager : SseTransportManager ,
6666 block : () -> Server ,
6767) {
68- val transport = mcpSseTransport(postEndpoint, transports )
68+ val transport = mcpSseTransport(postEndpoint, sseTransportManager )
6969
7070 val server = block()
7171
7272 server.onClose {
7373 logger.info { " Server connection closed for sessionId: ${transport.sessionId} " }
74- transports.update { it.remove (transport.sessionId) }
74+ sseTransportManager.removeTransport (transport.sessionId)
7575 }
7676
7777 server.connectSession(transport)
@@ -81,17 +81,17 @@ internal suspend fun ServerSSESession.mcpSseEndpoint(
8181
8282internal fun ServerSSESession.mcpSseTransport (
8383 postEndpoint : String ,
84- transports : AtomicRef < PersistentMap < String , SseServerTransport >> ,
84+ sseTransportManager : SseTransportManager ,
8585): SseServerTransport {
8686 val transport = SseServerTransport (postEndpoint, this )
87- transports.update { it.put (transport.sessionId, transport) }
87+ sseTransportManager.addTransport (transport)
8888 logger.info { " New SSE connection established and stored with sessionId: ${transport.sessionId} " }
8989
9090 return transport
9191}
9292
9393internal suspend fun RoutingContext.mcpPostEndpoint (
94- transports : AtomicRef < PersistentMap < String , SseServerTransport >> ,
94+ sseTransportManager : SseTransportManager ,
9595) {
9696 val sessionId: String = call.request.queryParameters[" sessionId" ]
9797 ? : run {
@@ -101,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(
101101
102102 logger.debug { " Received message for sessionId: $sessionId " }
103103
104- val transport = transports.value[ sessionId]
104+ val transport = sseTransportManager.getTransport( sessionId)
105105 if (transport == null ) {
106106 logger.warn { " Session not found for sessionId: $sessionId " }
107107 call.respond(HttpStatusCode .NotFound , " Session not found" )
0 commit comments