Skip to content

Commit 97038b3

Browse files
committed
Move session management logic to ServerSessionRegistry.kt
Rewrite server session redirection to receivers
1 parent f356f94 commit 97038b3

File tree

3 files changed

+89
-44
lines changed

3 files changed

+89
-44
lines changed

kotlin-sdk-server/api/kotlin-sdk-server.api

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server {
7575
public final fun getPrompts ()Ljava/util/Map;
7676
public final fun getResources ()Ljava/util/Map;
7777
protected final fun getServerInfo ()Lio/modelcontextprotocol/kotlin/sdk/Implementation;
78-
public final fun getSession (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;
79-
public final fun getSessionOrThrow (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/server/ServerSession;
8078
public final fun getSessions ()Ljava/util/Map;
8179
public final fun getTools ()Ljava/util/Map;
8280
public final fun listRoots (Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations
3636
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
3737
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
3838
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
39-
import kotlinx.atomicfu.atomic
40-
import kotlinx.atomicfu.update
41-
import kotlinx.collections.immutable.persistentMapOf
4239
import kotlinx.coroutines.CancellationException
4340
import kotlinx.serialization.json.JsonObject
4441

@@ -88,24 +85,13 @@ public open class Server(
8885
block: Server.() -> Unit = {},
8986
) : this(serverInfo, options, { instructions }, block)
9087

91-
private val sessionRegistry = atomic(persistentMapOf<String, ServerSession>())
88+
private val sessionRegistry = ServerSessionRegistry()
9289

9390
/**
94-
* Returns a read-only view of the current server sessions.
91+
* Provides a snapshot of all sessions currently registered in the server
9592
*/
96-
public val sessions: Map<String, ServerSession>
97-
get() = sessionRegistry.value
98-
99-
/**
100-
* Gets a server session by its ID.
101-
*/
102-
public fun getSession(sessionId: String): ServerSession? = sessions[sessionId]
103-
104-
/**
105-
* Gets a server session by its ID or throws an exception if the session doesn't exist.
106-
*/
107-
public fun getSessionOrThrow(sessionId: String): ServerSession =
108-
sessions[sessionId] ?: throw IllegalArgumentException("Session not found: $sessionId")
93+
public val sessions: Map<ServerSessionKey, ServerSession>
94+
get() = sessionRegistry.sessions
10995

11096
@Suppress("ktlint:standard:backing-property-naming")
11197
private var _onInitialized: (() -> Unit) = {}
@@ -200,12 +186,12 @@ public open class Server(
200186
// Register cleanup handler to remove session from list when it closes
201187
session.onClose {
202188
logger.debug { "Removing closed session from active sessions list" }
203-
sessionRegistry.update { sessions -> sessions.remove(session.sessionId) }
189+
sessionRegistry.removeSession(session.sessionId)
204190
}
205191
logger.debug { "Server session connecting to transport" }
206192
session.connect(transport)
207193
logger.debug { "Server session successfully connected to transport" }
208-
sessionRegistry.update { sessions -> sessions.put(session.sessionId, session) }
194+
sessionRegistry.addSession(session)
209195

210196
_onConnect()
211197
return session
@@ -571,9 +557,8 @@ public open class Server(
571557
* Triggers [ServerSession.ping] request for session by provided [sessionId].
572558
* @param sessionId The session ID to ping
573559
*/
574-
public suspend fun ping(sessionId: String): EmptyRequestResult {
575-
val session = getSessionOrThrow(sessionId)
576-
return session.ping()
560+
public suspend fun ping(sessionId: String): EmptyRequestResult = with(sessionRegistry.getSession(sessionId)) {
561+
ping()
577562
}
578563

579564
/**
@@ -589,9 +574,8 @@ public open class Server(
589574
sessionId: String,
590575
params: CreateMessageRequest,
591576
options: RequestOptions? = null,
592-
): CreateMessageResult {
593-
val session = getSessionOrThrow(sessionId)
594-
return session.request(params, options)
577+
): CreateMessageResult = with(sessionRegistry.getSession(sessionId)) {
578+
request(params, options)
595579
}
596580

597581
/**
@@ -607,9 +591,8 @@ public open class Server(
607591
sessionId: String,
608592
params: JsonObject = EmptyJsonObject,
609593
options: RequestOptions? = null,
610-
): ListRootsResult {
611-
val session = getSessionOrThrow(sessionId)
612-
return session.listRoots(params, options)
594+
): ListRootsResult = with(sessionRegistry.getSession(sessionId)) {
595+
listRoots(params, options)
613596
}
614597

615598
/**
@@ -627,9 +610,8 @@ public open class Server(
627610
message: String,
628611
requestedSchema: RequestedSchema,
629612
options: RequestOptions? = null,
630-
): CreateElicitationResult {
631-
val session = getSessionOrThrow(sessionId)
632-
return session.createElicitation(message, requestedSchema, options)
613+
): CreateElicitationResult = with(sessionRegistry.getSession(sessionId)) {
614+
createElicitation(message, requestedSchema, options)
633615
}
634616

635617
/**
@@ -639,8 +621,9 @@ public open class Server(
639621
* @param notification The logging message notification.
640622
*/
641623
public suspend fun sendLoggingMessage(sessionId: String, notification: LoggingMessageNotification) {
642-
val session = getSessionOrThrow(sessionId)
643-
session.sendLoggingMessage(notification)
624+
with(sessionRegistry.getSession(sessionId)) {
625+
sendLoggingMessage(notification)
626+
}
644627
}
645628

646629
/**
@@ -650,8 +633,9 @@ public open class Server(
650633
* @param notification Details of the updated resource.
651634
*/
652635
public suspend fun sendResourceUpdated(sessionId: String, notification: ResourceUpdatedNotification) {
653-
val session = getSessionOrThrow(sessionId)
654-
session.sendResourceUpdated(notification)
636+
with(sessionRegistry.getSession(sessionId)) {
637+
sendResourceUpdated(notification)
638+
}
655639
}
656640

657641
/**
@@ -660,8 +644,9 @@ public open class Server(
660644
* @param sessionId The session ID to send the resource list changed notification to.
661645
*/
662646
public suspend fun sendResourceListChanged(sessionId: String) {
663-
val session = getSessionOrThrow(sessionId)
664-
session.sendResourceListChanged()
647+
with(sessionRegistry.getSession(sessionId)) {
648+
sendResourceListChanged()
649+
}
665650
}
666651

667652
/**
@@ -670,8 +655,9 @@ public open class Server(
670655
* @param sessionId The session ID to send the tool list changed notification to.
671656
*/
672657
public suspend fun sendToolListChanged(sessionId: String) {
673-
val session = getSessionOrThrow(sessionId)
674-
session.sendToolListChanged()
658+
with(sessionRegistry.getSession(sessionId)) {
659+
sendToolListChanged()
660+
}
675661
}
676662

677663
/**
@@ -680,8 +666,9 @@ public open class Server(
680666
* @param sessionId The session ID to send the prompt list changed notification to.
681667
*/
682668
public suspend fun sendPromptListChanged(sessionId: String) {
683-
val session = getSessionOrThrow(sessionId)
684-
session.sendPromptListChanged()
669+
with(sessionRegistry.getSession(sessionId)) {
670+
sendPromptListChanged()
671+
}
685672
}
686673
// End the ServerSession redirection section
687674
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package io.modelcontextprotocol.kotlin.sdk.server
2+
3+
import io.github.oshai.kotlinlogging.KotlinLogging
4+
import kotlinx.atomicfu.atomic
5+
import kotlinx.atomicfu.update
6+
import kotlinx.collections.immutable.persistentMapOf
7+
8+
internal typealias ServerSessionKey = String
9+
10+
/**
11+
* Represents a registry for managing server sessions.
12+
*/
13+
internal class ServerSessionRegistry {
14+
15+
private val logger = KotlinLogging.logger {}
16+
17+
/**
18+
* Atomic variable used to maintain a thread-safe registry of sessions.
19+
* Stores a persistent map where each session is identified by its unique key.
20+
*/
21+
private val registry = atomic(persistentMapOf<String, ServerSession>())
22+
23+
/**
24+
* Returns a read-only view of the current server sessions.
25+
*/
26+
internal val sessions: Map<ServerSessionKey, ServerSession>
27+
get() = registry.value
28+
29+
/**
30+
* Returns a server session by its ID.
31+
* @param sessionId The ID of the session to retrieve.
32+
* @throws IllegalArgumentException If the session doesn't exist.
33+
*/
34+
internal fun getSession(sessionId: ServerSessionKey): ServerSession =
35+
sessions[sessionId] ?: throw IllegalArgumentException("Session not found: $sessionId")
36+
37+
/**
38+
* Returns a server session by its ID, or null if it doesn't exist.
39+
* @param sessionId The ID of the session to retrieve.
40+
*/
41+
internal fun getSessionOrNull(sessionId: ServerSessionKey): ServerSession? = sessions[sessionId]
42+
43+
/**
44+
* Registers a server session.
45+
* @param session The session to register.
46+
*/
47+
internal fun addSession(session: ServerSession) {
48+
logger.info { "Adding session: ${session.sessionId}" }
49+
registry.update { sessions -> sessions.put(session.sessionId, session) }
50+
}
51+
52+
/**
53+
* Removes a server session by its ID.
54+
* @param sessionId The ID of the session to remove.
55+
*/
56+
internal fun removeSession(sessionId: ServerSessionKey) {
57+
logger.info { "Removing session: $sessionId" }
58+
registry.update { sessions -> sessions.remove(sessionId) }
59+
}
60+
}

0 commit comments

Comments
 (0)