Skip to content

Commit b01e472

Browse files
committed
update from comments
1 parent cfc188e commit b01e472

File tree

5 files changed

+39
-38
lines changed

5 files changed

+39
-38
lines changed

kotlin-sdk-core/api/kotlin-sdk-core.api

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,16 +1071,18 @@ public final class io/modelcontextprotocol/kotlin/sdk/InitializedNotification$Pa
10711071

10721072
public final class io/modelcontextprotocol/kotlin/sdk/JSONRPCError : io/modelcontextprotocol/kotlin/sdk/JSONRPCMessage {
10731073
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError$Companion;
1074-
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)V
1075-
public synthetic fun <init> (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
1076-
public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;
1077-
public final fun component2 ()Ljava/lang/String;
1078-
public final fun component3 ()Lkotlinx/serialization/json/JsonObject;
1079-
public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;
1080-
public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;
1074+
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)V
1075+
public synthetic fun <init> (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
1076+
public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/RequestId;
1077+
public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;
1078+
public final fun component3 ()Ljava/lang/String;
1079+
public final fun component4 ()Lkotlinx/serialization/json/JsonObject;
1080+
public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;
1081+
public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;Lio/modelcontextprotocol/kotlin/sdk/RequestId;Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/JSONRPCError;
10811082
public fun equals (Ljava/lang/Object;)Z
10821083
public final fun getCode ()Lio/modelcontextprotocol/kotlin/sdk/ErrorCode;
10831084
public final fun getData ()Lkotlinx/serialization/json/JsonObject;
1085+
public final fun getId ()Lio/modelcontextprotocol/kotlin/sdk/RequestId;
10841086
public final fun getMessage ()Ljava/lang/String;
10851087
public fun hashCode ()I
10861088
public fun toString ()Ljava/lang/String;

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,14 @@ public data class JSONRPCNotification(
249249
*/
250250
@Serializable
251251
public class JSONRPCResponse(
252-
public val id: RequestId?,
252+
public val id: RequestId,
253253
public val jsonrpc: String = JSONRPC_VERSION,
254254
public val result: RequestResult? = null,
255255
public val error: JSONRPCError? = null,
256256
) : JSONRPCMessage {
257257

258258
public fun copy(
259-
id: RequestId? = this.id,
259+
id: RequestId = this.id,
260260
jsonrpc: String = this.jsonrpc,
261261
result: RequestResult? = this.result,
262262
error: JSONRPCError? = this.error,

kotlin-sdk-server/api/kotlin-sdk-server.api

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
public final class io/modelcontextprotocol/kotlin/sdk/LibVersionKt {
2-
public static final field LIB_VERSION Ljava/lang/String;
3-
}
4-
51
public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventStore {
62
public abstract fun replayEventsAfter (Ljava/lang/String;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
73
public abstract fun storeEvent (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,26 @@ import io.ktor.server.routing.routing
1414
import io.ktor.server.sse.SSE
1515
import io.ktor.server.sse.ServerSSESession
1616
import io.ktor.server.sse.sse
17-
import io.ktor.util.collections.ConcurrentMap
1817
import io.ktor.utils.io.KtorDsl
1918
import kotlinx.atomicfu.AtomicRef
2019
import kotlinx.atomicfu.atomic
2120
import kotlinx.atomicfu.update
2221
import kotlinx.collections.immutable.PersistentMap
2322
import kotlinx.collections.immutable.toPersistentMap
2423
import io.modelcontextprotocol.kotlin.sdk.ErrorCode
24+
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
2525

2626
private val logger = KotlinLogging.logger {}
2727

28-
internal class SseTransportManager(transports: Map<String, SseServerTransport> = emptyMap()) {
29-
private val transports: AtomicRef<PersistentMap<String, SseServerTransport>> = atomic(transports.toPersistentMap())
28+
internal class TransportManager(transports: Map<String, AbstractTransport> = emptyMap()) {
29+
private val transports: AtomicRef<PersistentMap<String, AbstractTransport>> = atomic(transports.toPersistentMap())
3030

31-
fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId]
31+
fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId)
3232

33-
fun addTransport(transport: SseServerTransport) {
34-
transports.update { it.put(transport.sessionId, transport) }
33+
fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId]
34+
35+
fun addTransport(sessionId: String, transport: AbstractTransport) {
36+
transports.update { it.put(sessionId, transport) }
3537
}
3638

3739
fun removeTransport(sessionId: String) {
@@ -51,14 +53,14 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) {
5153
*/
5254
@KtorDsl
5355
public fun Routing.mcp(block: ServerSSESession.() -> Server) {
54-
val sseTransportManager = SseTransportManager()
56+
val transportManager = TransportManager()
5557

5658
sse {
57-
mcpSseEndpoint("", sseTransportManager, block)
59+
mcpSseEndpoint("", transportManager, block)
5860
}
5961

6062
post {
61-
mcpPostEndpoint(sseTransportManager)
63+
mcpPostEndpoint(transportManager)
6264
}
6365
}
6466

@@ -85,12 +87,12 @@ public fun Application.mcpStreamableHttp(
8587
eventStore: EventStore? = null,
8688
block: RoutingContext.() -> Server,
8789
) {
88-
val transports = ConcurrentMap<String, StreamableHttpServerTransport>()
90+
val transportManager = TransportManager()
8991

9092
routing {
9193
post("/mcp") {
9294
mcpStreamableHttpEndpoint(
93-
transports,
95+
transportManager,
9496
enableDnsRebindingProtection,
9597
allowedHosts,
9698
allowedOrigins,
@@ -124,16 +126,16 @@ public fun Application.mcpStatelessStreamableHttp(
124126

125127
private suspend fun ServerSSESession.mcpSseEndpoint(
126128
postEndpoint: String,
127-
sseTransportManager: SseTransportManager,
129+
transportManager: TransportManager,
128130
block: ServerSSESession.() -> Server,
129131
) {
130-
val transport = mcpSseTransport(postEndpoint, sseTransportManager)
132+
val transport = mcpSseTransport(postEndpoint, transportManager)
131133

132134
val server = block()
133135

134136
server.onClose {
135137
logger.info { "Server connection closed for sessionId: ${transport.sessionId}" }
136-
sseTransportManager.removeTransport(transport.sessionId)
138+
transportManager.removeTransport(transport.sessionId)
137139
}
138140

139141
server.connect(transport)
@@ -143,26 +145,26 @@ private suspend fun ServerSSESession.mcpSseEndpoint(
143145

144146
internal fun ServerSSESession.mcpSseTransport(
145147
postEndpoint: String,
146-
sseTransportManager: SseTransportManager,
148+
transportManager: TransportManager,
147149
): SseServerTransport {
148150
val transport = SseServerTransport(postEndpoint, this)
149-
sseTransportManager.addTransport(transport)
151+
transportManager.addTransport(transport.sessionId, transport)
150152
logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
151153

152154
return transport
153155
}
154156

155157
internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
156-
transports: ConcurrentMap<String, StreamableHttpServerTransport>,
158+
transportManager: TransportManager,
157159
enableDnsRebindingProtection: Boolean = false,
158160
allowedHosts: List<String>? = null,
159161
allowedOrigins: List<String>? = null,
160162
eventStore: EventStore? = null,
161163
block: RoutingContext.() -> Server,
162164
) {
163165
val sessionId = this.call.request.header(MCP_SESSION_ID_HEADER)
164-
val transport = if (sessionId != null && transports.containsKey(sessionId)) {
165-
transports[sessionId]!!
166+
val transport = if (sessionId != null && transportManager.hasTransport(sessionId)) {
167+
transportManager.getTransport(sessionId)
166168
} else if (sessionId == null) {
167169
val transport = StreamableHttpServerTransport(
168170
enableDnsRebindingProtection = enableDnsRebindingProtection,
@@ -173,7 +175,7 @@ internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
173175
)
174176

175177
transport.setOnSessionInitialized { sessionId ->
176-
transports[sessionId] = transport
178+
transportManager.addTransport(sessionId, transport)
177179

178180
logger.info { "New StreamableHttp connection established and stored with sessionId: $sessionId" }
179181
}
@@ -199,7 +201,7 @@ internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
199201
return
200202
}
201203

202-
transport.handleRequest(null, this.call)
204+
(transport as StreamableHttpServerTransport).handleRequest(null, this.call)
203205
logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" }
204206
}
205207

@@ -234,15 +236,15 @@ internal suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint(
234236
logger.debug { "Server connected to transport without sessionId" }
235237
}
236238

237-
internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTransportManager) {
239+
internal suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) {
238240
val sessionId: String = call.request.queryParameters["sessionId"] ?: run {
239241
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
240242
return
241243
}
242244

243245
logger.debug { "Received message for sessionId: $sessionId" }
244246

245-
val transport = sseTransportManager.getTransport(sessionId)
247+
val transport = transportManager.getTransport(sessionId) as SseServerTransport?
246248
if (transport == null) {
247249
logger.warn { "Session not found for sessionId: $sessionId" }
248250
call.respond(HttpStatusCode.NotFound, "Session not found")

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import io.ktor.server.request.httpMethod
1212
import io.ktor.server.request.receiveText
1313
import io.ktor.server.response.header
1414
import io.ktor.server.response.respond
15+
import io.ktor.server.response.respondBytes
1516
import io.ktor.server.response.respondNullable
1617
import io.ktor.server.sse.ServerSSESession
1718
import io.ktor.util.collections.ConcurrentMap
@@ -332,7 +333,7 @@ public class StreamableHttpServerTransport(
332333

333334
val hasRequest = messages.any { it is JSONRPCRequest }
334335
if (!hasRequest) {
335-
call.respondNullable(status = HttpStatusCode.Accepted, message = null)
336+
call.respondBytes(status = HttpStatusCode.Accepted, bytes = ByteArray(0))
336337
messages.forEach { message -> _onMessage(message) }
337338
return
338339
}
@@ -568,7 +569,7 @@ internal suspend fun ApplicationCall.reject(status: HttpStatusCode, code: ErrorC
568569
this.response.status(status)
569570
this.respond(
570571
JSONRPCResponse(
571-
id = null,
572+
id = RequestId.StringId("server-error"),
572573
error = JSONRPCError(message = message, code = code),
573574
),
574575
)

0 commit comments

Comments
 (0)