@@ -7,14 +7,24 @@ import io.ktor.server.response.*
77import io.ktor.server.sse.*
88import io.modelcontextprotocol.kotlin.sdk.*
99import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
10+ import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SESSION_ID
1011import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
1112import kotlinx.serialization.encodeToString
13+ import kotlinx.serialization.json.JsonArray
14+ import kotlinx.serialization.json.JsonElement
15+ import kotlinx.serialization.json.JsonObject
16+ import kotlinx.serialization.json.decodeFromJsonElement
1217import kotlin.collections.HashMap
1318import kotlin.concurrent.atomics.AtomicBoolean
1419import kotlin.concurrent.atomics.ExperimentalAtomicApi
1520import kotlin.uuid.ExperimentalUuidApi
1621import kotlin.uuid.Uuid
1722
23+ /* *
24+ * Server transport for StreamableHttp: this allows server to respond to GET, POST and DELETE requests. Server can optionally make use of Server-Sent Events (SSE) to stream multiple server messages.
25+ *
26+ * Creates a new StreamableHttp server transport.
27+ */
1828@OptIn(ExperimentalAtomicApi ::class )
1929public class StreamableHttpServerTransport (
2030 private val isStateful : Boolean = false ,
@@ -55,7 +65,8 @@ public class StreamableHttpServerTransport(
5565 }
5666
5767 val streamId = requestToStreamMapping[requestId] ? : error(" No connection established for request id $requestId " )
58- val correspondingStream = streamMapping[streamId] ? : error(" No connection established for request id $requestId " )
68+ val correspondingStream =
69+ streamMapping[streamId] ? : error(" No connection established for request id $requestId " )
5970 val correspondingCall = callMapping[streamId] ? : error(" No connection established for request id $requestId " )
6071
6172 if (! enableJSONResponse) {
@@ -66,32 +77,33 @@ public class StreamableHttpServerTransport(
6677 }
6778
6879 requestResponseMapping[requestId] = message
69- val relatedIds = requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
80+ val relatedIds =
81+ requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
7082 val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null }
7183
72- if (allResponsesReady) {
73- if (enableJSONResponse) {
74- correspondingCall.response.headers.append(ContentType .toString(), ContentType .Application .Json .toString())
75- correspondingCall.response.status(HttpStatusCode .OK )
76- if (sessionId != null ) {
77- correspondingCall.response.header(" Mcp-Session-Id" , sessionId!! )
78- }
79- val responses = relatedIds.map{ requestResponseMapping[it] }
80- if (responses.size == 1 ) {
81- correspondingCall.respond(responses[0 ]!! )
82- } else {
83- correspondingCall.respond(responses)
84- }
85- callMapping.remove(streamId)
84+ if (! allResponsesReady) return
85+
86+ if (enableJSONResponse) {
87+ correspondingCall.response.headers.append(ContentType .toString(), ContentType .Application .Json .toString())
88+ correspondingCall.response.status(HttpStatusCode .OK )
89+ if (sessionId != null ) {
90+ correspondingCall.response.header(MCP_SESSION_ID , sessionId!! )
91+ }
92+ val responses = relatedIds.map { requestResponseMapping[it] }
93+ if (responses.size == 1 ) {
94+ correspondingCall.respond(responses[0 ]!! )
8695 } else {
87- correspondingStream.close()
88- streamMapping.remove(streamId)
96+ correspondingCall.respond(responses)
8997 }
98+ callMapping.remove(streamId)
99+ } else {
100+ correspondingStream.close()
101+ streamMapping.remove(streamId)
102+ }
90103
91- for (id in relatedIds) {
92- requestToStreamMapping.remove(id)
93- requestResponseMapping.remove(id)
94- }
104+ for (id in relatedIds) {
105+ requestToStreamMapping.remove(id)
106+ requestResponseMapping.remove(id)
95107 }
96108
97109 }
@@ -110,47 +122,13 @@ public class StreamableHttpServerTransport(
110122 @OptIn(ExperimentalUuidApi ::class )
111123 public suspend fun handlePostRequest (call : ApplicationCall , session : ServerSSESession ) {
112124 try {
113- val acceptHeader = call.request.headers[ " Accept " ]?.split( " , " ) ? : listOf ()
125+ if ( ! validateHeaders(call)) return
114126
115- if (! acceptHeader.contains(" text/event-stream" ) || ! acceptHeader.contains(" application/json" )) {
116- call.response.status(HttpStatusCode .NotAcceptable )
117- call.respond(
118- JSONRPCResponse (
119- id = null ,
120- error = JSONRPCError (
121- code = ErrorCode .Unknown (- 32000 ),
122- message = " Not Acceptable: Client must accept both application/json and text/event-stream"
123- )
124- )
125- )
126- return
127- }
127+ val messages = parseBody(call)
128128
129- val contentType = call.request.contentType()
130- if (contentType != ContentType .Application .Json ) {
131- call.response.status(HttpStatusCode .UnsupportedMediaType )
132- call.respond(
133- JSONRPCResponse (
134- id = null ,
135- error = JSONRPCError (
136- code = ErrorCode .Unknown (- 32000 ),
137- message = " Unsupported Media Type: Content-Type must be application/json"
138- )
139- )
140- )
141- return
142- }
143-
144- val body = call.receiveText()
145- val messages = mutableListOf<JSONRPCMessage >()
146-
147- if (body.startsWith(" [" )) {
148- messages.addAll(McpJson .decodeFromString<List <JSONRPCMessage >>(body))
149- } else {
150- messages.add(McpJson .decodeFromString(body))
151- }
129+ if (messages.isEmpty()) return
152130
153- val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == " initialize " }
131+ val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == Method . Defined . Initialize .value }
154132 if (hasInitializationRequest) {
155133 if (initialized.load() && sessionId != null ) {
156134 call.response.status(HttpStatusCode .BadRequest )
@@ -184,38 +162,37 @@ public class StreamableHttpServerTransport(
184162 sessionId = Uuid .random().toString()
185163 }
186164 initialized.store(true )
165+ }
187166
188- if (! validateSession(call)) {
189- return
190- }
191-
192- val hasRequests = messages.any { it is JSONRPCRequest }
193- val streamId = Uuid .random().toString()
167+ if (! validateSession(call)) {
168+ return
169+ }
194170
195- if (! hasRequests){
196- call.respondNullable(HttpStatusCode .Accepted )
197- } else {
198- if (! enableJSONResponse) {
199- call.response.headers.append(ContentType .toString(), ContentType .Text .EventStream .toString())
171+ val hasRequests = messages.any { it is JSONRPCRequest }
172+ val streamId = Uuid .random().toString()
200173
201- if (sessionId != null ) {
202- call.response.header(" Mcp-Session-Id" , sessionId!! )
203- }
204- }
174+ if (! hasRequests) {
175+ call.respondNullable(HttpStatusCode .Accepted )
176+ } else {
177+ if (! enableJSONResponse) {
178+ call.response.headers.append(ContentType .toString(), ContentType .Text .EventStream .toString())
205179
206- for (message in messages) {
207- if (message is JSONRPCRequest ) {
208- streamMapping[streamId] = session
209- callMapping[streamId] = call
210- requestToStreamMapping[message.id] = streamId
211- }
180+ if (sessionId != null ) {
181+ call.response.header(MCP_SESSION_ID , sessionId!! )
212182 }
213183 }
184+
214185 for (message in messages) {
215- _onMessage .invoke(message)
186+ if (message is JSONRPCRequest ) {
187+ streamMapping[streamId] = session
188+ callMapping[streamId] = call
189+ requestToStreamMapping[message.id] = streamId
190+ }
216191 }
217192 }
218-
193+ for (message in messages) {
194+ _onMessage .invoke(message)
195+ }
219196 } catch (e: Exception ) {
220197 call.response.status(HttpStatusCode .BadRequest )
221198 call.respond(
@@ -251,7 +228,7 @@ public class StreamableHttpServerTransport(
251228 }
252229
253230 if (sessionId != null ) {
254- call.response.header(" Mcp-Session-Id " , sessionId!! )
231+ call.response.header(MCP_SESSION_ID , sessionId!! )
255232 }
256233
257234 if (streamMapping[standalone] != null ) {
@@ -281,7 +258,7 @@ public class StreamableHttpServerTransport(
281258 call.respondNullable(HttpStatusCode .OK )
282259 }
283260
284- public suspend fun validateSession (call : ApplicationCall ): Boolean {
261+ private suspend fun validateSession (call : ApplicationCall ): Boolean {
285262 if (sessionId == null ) {
286263 return true
287264 }
@@ -301,4 +278,65 @@ public class StreamableHttpServerTransport(
301278 }
302279 return true
303280 }
281+
282+ private suspend fun validateHeaders (call : ApplicationCall ): Boolean {
283+ val acceptHeader = call.request.headers[" Accept" ]?.split(" ," ) ? : listOf ()
284+
285+ if (! acceptHeader.contains(" text/event-stream" ) || ! acceptHeader.contains(" application/json" )) {
286+ call.response.status(HttpStatusCode .NotAcceptable )
287+ call.respond(
288+ JSONRPCResponse (
289+ id = null ,
290+ error = JSONRPCError (
291+ code = ErrorCode .Unknown (- 32000 ),
292+ message = " Not Acceptable: Client must accept both application/json and text/event-stream"
293+ )
294+ )
295+ )
296+ return false
297+ }
298+
299+ val contentType = call.request.contentType()
300+ if (contentType != ContentType .Application .Json ) {
301+ call.response.status(HttpStatusCode .UnsupportedMediaType )
302+ call.respond(
303+ JSONRPCResponse (
304+ id = null ,
305+ error = JSONRPCError (
306+ code = ErrorCode .Unknown (- 32000 ),
307+ message = " Unsupported Media Type: Content-Type must be application/json"
308+ )
309+ )
310+ )
311+ return false
312+ }
313+
314+ return true
315+ }
316+
317+ private suspend fun parseBody (
318+ call : ApplicationCall ,
319+ ): List <JSONRPCMessage > {
320+ val messages = mutableListOf<JSONRPCMessage >()
321+ when (val body = call.receive<JsonElement >()) {
322+ is JsonObject -> messages.add(McpJson .decodeFromJsonElement(body))
323+ is JsonArray -> messages.addAll(McpJson .decodeFromJsonElement<List <JSONRPCMessage >>(body))
324+ else -> {
325+ call.response.status(HttpStatusCode .BadRequest )
326+ call.respond(
327+ JSONRPCResponse (
328+ id = null ,
329+ error = JSONRPCError (
330+ code = ErrorCode .Defined .InvalidRequest ,
331+ message = " Invalid Request: Server already initialized"
332+ )
333+ )
334+ )
335+ return listOf ()
336+ }
337+ }
338+ return messages
339+ }
340+
341+
304342}
0 commit comments