@@ -36,9 +36,6 @@ import io.modelcontextprotocol.kotlin.sdk.ToolAnnotations
3636import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
3737import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
3838import io.modelcontextprotocol.kotlin.sdk.shared.Transport
39- import kotlinx.atomicfu.atomic
40- import kotlinx.atomicfu.update
41- import kotlinx.collections.immutable.persistentMapOf
4239import kotlinx.coroutines.CancellationException
4340import kotlinx.serialization.json.JsonObject
4441
@@ -88,24 +85,13 @@ public open class Server(
8885 block: Server.() -> Unit = {},
8986 ) : this(serverInfo, options, { instructions }, block)
9087
91- private val sessionRegistry = atomic(persistentMapOf<String, ServerSession>() )
88+ private val sessionRegistry = ServerSessionRegistry( )
9289
9390 /**
94- * Returns a read-only view of the current server sessions.
91+ * Provides a snapshot of all sessions currently registered in the server
9592 */
96- public val sessions: Map<String, ServerSession>
97- get() = sessionRegistry.value
98-
99- /**
100- * Gets a server session by its ID.
101- */
102- public fun getSession(sessionId: String): ServerSession? = sessions[sessionId]
103-
104- /**
105- * Gets a server session by its ID or throws an exception if the session doesn't exist.
106- */
107- public fun getSessionOrThrow(sessionId: String): ServerSession =
108- sessions[sessionId] ?: throw IllegalArgumentException("Session not found: $sessionId")
93+ public val sessions: Map<ServerSessionKey, ServerSession>
94+ get() = sessionRegistry.sessions
10995
11096 @Suppress("ktlint:standard:backing-property-naming")
11197 private var _onInitialized: (() -> Unit) = {}
@@ -200,12 +186,12 @@ public open class Server(
200186 // Register cleanup handler to remove session from list when it closes
201187 session.onClose {
202188 logger.debug { "Removing closed session from active sessions list" }
203- sessionRegistry.update { sessions -> sessions.remove (session.sessionId) }
189+ sessionRegistry.removeSession (session.sessionId)
204190 }
205191 logger.debug { "Server session connecting to transport" }
206192 session.connect(transport)
207193 logger.debug { "Server session successfully connected to transport" }
208- sessionRegistry.update { sessions -> sessions.put (session.sessionId, session) }
194+ sessionRegistry.addSession (session)
209195
210196 _onConnect()
211197 return session
@@ -571,9 +557,8 @@ public open class Server(
571557 * Triggers [ServerSession.ping] request for session by provided [sessionId].
572558 * @param sessionId The session ID to ping
573559 */
574- public suspend fun ping(sessionId: String): EmptyRequestResult {
575- val session = getSessionOrThrow(sessionId)
576- return session.ping()
560+ public suspend fun ping(sessionId: String): EmptyRequestResult = with(sessionRegistry.getSession(sessionId)) {
561+ ping()
577562 }
578563
579564 /**
@@ -589,9 +574,8 @@ public open class Server(
589574 sessionId: String,
590575 params: CreateMessageRequest,
591576 options: RequestOptions? = null,
592- ): CreateMessageResult {
593- val session = getSessionOrThrow(sessionId)
594- return session.request(params, options)
577+ ): CreateMessageResult = with(sessionRegistry.getSession(sessionId)) {
578+ request(params, options)
595579 }
596580
597581 /**
@@ -607,9 +591,8 @@ public open class Server(
607591 sessionId: String,
608592 params: JsonObject = EmptyJsonObject,
609593 options: RequestOptions? = null,
610- ): ListRootsResult {
611- val session = getSessionOrThrow(sessionId)
612- return session.listRoots(params, options)
594+ ): ListRootsResult = with(sessionRegistry.getSession(sessionId)) {
595+ listRoots(params, options)
613596 }
614597
615598 /**
@@ -627,9 +610,8 @@ public open class Server(
627610 message: String,
628611 requestedSchema: RequestedSchema,
629612 options: RequestOptions? = null,
630- ): CreateElicitationResult {
631- val session = getSessionOrThrow(sessionId)
632- return session.createElicitation(message, requestedSchema, options)
613+ ): CreateElicitationResult = with(sessionRegistry.getSession(sessionId)) {
614+ createElicitation(message, requestedSchema, options)
633615 }
634616
635617 /**
@@ -639,8 +621,9 @@ public open class Server(
639621 * @param notification The logging message notification.
640622 */
641623 public suspend fun sendLoggingMessage(sessionId: String, notification: LoggingMessageNotification) {
642- val session = getSessionOrThrow(sessionId)
643- session.sendLoggingMessage(notification)
624+ with(sessionRegistry.getSession(sessionId)) {
625+ sendLoggingMessage(notification)
626+ }
644627 }
645628
646629 /**
@@ -650,8 +633,9 @@ public open class Server(
650633 * @param notification Details of the updated resource.
651634 */
652635 public suspend fun sendResourceUpdated(sessionId: String, notification: ResourceUpdatedNotification) {
653- val session = getSessionOrThrow(sessionId)
654- session.sendResourceUpdated(notification)
636+ with(sessionRegistry.getSession(sessionId)) {
637+ sendResourceUpdated(notification)
638+ }
655639 }
656640
657641 /**
@@ -660,8 +644,9 @@ public open class Server(
660644 * @param sessionId The session ID to send the resource list changed notification to.
661645 */
662646 public suspend fun sendResourceListChanged(sessionId: String) {
663- val session = getSessionOrThrow(sessionId)
664- session.sendResourceListChanged()
647+ with(sessionRegistry.getSession(sessionId)) {
648+ sendResourceListChanged()
649+ }
665650 }
666651
667652 /**
@@ -670,8 +655,9 @@ public open class Server(
670655 * @param sessionId The session ID to send the tool list changed notification to.
671656 */
672657 public suspend fun sendToolListChanged(sessionId: String) {
673- val session = getSessionOrThrow(sessionId)
674- session.sendToolListChanged()
658+ with(sessionRegistry.getSession(sessionId)) {
659+ sendToolListChanged()
660+ }
675661 }
676662
677663 /**
@@ -680,8 +666,9 @@ public open class Server(
680666 * @param sessionId The session ID to send the prompt list changed notification to.
681667 */
682668 public suspend fun sendPromptListChanged(sessionId: String) {
683- val session = getSessionOrThrow(sessionId)
684- session.sendPromptListChanged()
669+ with(sessionRegistry.getSession(sessionId)) {
670+ sendPromptListChanged()
671+ }
685672 }
686673 // End the ServerSession redirection section
687674}
0 commit comments