@@ -13,8 +13,12 @@ import io.ktor.server.routing.routing
1313import io.ktor.server.sse.SSE
1414import io.ktor.server.sse.ServerSSESession
1515import io.ktor.server.sse.sse
16- import io.ktor.util.collections.ConcurrentMap
1716import io.ktor.utils.io.KtorDsl
17+ import kotlinx.atomicfu.AtomicRef
18+ import kotlinx.atomicfu.atomic
19+ import kotlinx.atomicfu.update
20+ import kotlinx.collections.immutable.PersistentMap
21+ import kotlinx.collections.immutable.persistentMapOf
1822
1923private val logger = KotlinLogging .logger {}
2024
@@ -30,7 +34,7 @@ public fun Routing.mcp(path: String, block: () -> Server) {
3034 */
3135@KtorDsl
3236public fun Routing.mcp (block : () -> Server ) {
33- val transports = ConcurrentMap <String , SseServerTransport >()
37+ val transports = atomic(persistentMapOf <String , SseServerTransport >() )
3438
3539 sse {
3640 mcpSseEndpoint(" " , transports, block)
@@ -49,24 +53,16 @@ public fun Application.MCP(block: () -> Server) {
4953
5054@KtorDsl
5155public fun Application.mcp (block : () -> Server ) {
52- val transports = ConcurrentMap <String , SseServerTransport >()
53-
5456 install(SSE )
5557
5658 routing {
57- sse(" /sse" ) {
58- mcpSseEndpoint(" /message" , transports, block)
59- }
60-
61- post(" /message" ) {
62- mcpPostEndpoint(transports)
63- }
59+ mcp(block)
6460 }
6561}
6662
67- private suspend fun ServerSSESession.mcpSseEndpoint (
63+ internal suspend fun ServerSSESession.mcpSseEndpoint (
6864 postEndpoint : String ,
69- transports : ConcurrentMap < String , SseServerTransport >,
65+ transports : AtomicRef < PersistentMap < String , SseServerTransport > >,
7066 block : () -> Server ,
7167) {
7268 val transport = mcpSseTransport(postEndpoint, transports)
@@ -75,27 +71,27 @@ private suspend fun ServerSSESession.mcpSseEndpoint(
7571
7672 server.onClose {
7773 logger.info { " Server connection closed for sessionId: ${transport.sessionId} " }
78- transports.remove(transport.sessionId)
74+ transports.update { it. remove(transport.sessionId) }
7975 }
8076
81- server.connect(transport)
77+ server.connectSession(transport)
78+
8279 logger.debug { " Server connected to transport for sessionId: ${transport.sessionId} " }
8380}
8481
8582internal fun ServerSSESession.mcpSseTransport (
8683 postEndpoint : String ,
87- transports : ConcurrentMap < String , SseServerTransport >,
84+ transports : AtomicRef < PersistentMap < String , SseServerTransport > >,
8885): SseServerTransport {
8986 val transport = SseServerTransport (postEndpoint, this )
90- transports[transport.sessionId] = transport
91-
87+ transports.update { it.put(transport.sessionId, transport) }
9288 logger.info { " New SSE connection established and stored with sessionId: ${transport.sessionId} " }
9389
9490 return transport
9591}
9692
9793internal suspend fun RoutingContext.mcpPostEndpoint (
98- transports : ConcurrentMap < String , SseServerTransport >,
94+ transports : AtomicRef < PersistentMap < String , SseServerTransport > >,
9995) {
10096 val sessionId: String = call.request.queryParameters[" sessionId" ]
10197 ? : run {
@@ -105,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(
105101
106102 logger.debug { " Received message for sessionId: $sessionId " }
107103
108- val transport = transports[sessionId]
104+ val transport = transports.value [sessionId]
109105 if (transport == null ) {
110106 logger.warn { " Session not found for sessionId: $sessionId " }
111107 call.respond(HttpStatusCode .NotFound , " Session not found" )
0 commit comments