Skip to content

Commit 4b4c06c

Browse files
committed
SSE transport for server does not work inside custom Ktor's Route #94
This should also fix Issues #236 and #237
1 parent ba18b2b commit 4b4c06c

File tree

7 files changed

+302
-190
lines changed

7 files changed

+302
-190
lines changed

kotlin-sdk-server/api/kotlin-sdk-server.api

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt {
22
public static final fun MCP (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V
33
public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V
4-
public static final fun mcp (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V
5-
public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function1;)V
4+
public static final fun mcp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;)V
65
}
76

87
public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt {

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@ import io.ktor.http.HttpStatusCode
55
import io.ktor.server.application.Application
66
import io.ktor.server.application.install
77
import io.ktor.server.response.respond
8-
import io.ktor.server.routing.Routing
8+
import io.ktor.server.routing.Route
99
import io.ktor.server.routing.RoutingContext
1010
import io.ktor.server.routing.post
11-
import io.ktor.server.routing.route
1211
import io.ktor.server.routing.routing
1312
import io.ktor.server.sse.SSE
1413
import io.ktor.server.sse.ServerSSESession
@@ -36,18 +35,11 @@ internal class SseTransportManager(transports: Map<String, SseServerTransport> =
3635
}
3736
}
3837

39-
@KtorDsl
40-
public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) {
41-
route(path) {
42-
mcp(block)
43-
}
44-
}
45-
4638
/*
4739
* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
4840
*/
4941
@KtorDsl
50-
public fun Routing.mcp(block: ServerSSESession.() -> Server) {
42+
public fun Route.mcp(block: suspend ServerSSESession.() -> Server) {
5143
val sseTransportManager = SseTransportManager()
5244

5345
sse {
@@ -61,12 +53,12 @@ public fun Routing.mcp(block: ServerSSESession.() -> Server) {
6153

6254
@Suppress("FunctionName")
6355
@Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.ERROR)
64-
public fun Application.MCP(block: ServerSSESession.() -> Server) {
56+
public fun Application.MCP(block: suspend ServerSSESession.() -> Server) {
6557
mcp(block)
6658
}
6759

6860
@KtorDsl
69-
public fun Application.mcp(block: ServerSSESession.() -> Server) {
61+
public fun Application.mcp(block: suspend ServerSSESession.() -> Server) {
7062
install(SSE)
7163

7264
routing {
@@ -77,7 +69,7 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) {
7769
internal suspend fun ServerSSESession.mcpSseEndpoint(
7870
postEndpoint: String,
7971
sseTransportManager: SseTransportManager,
80-
block: ServerSSESession.() -> Server,
72+
block: suspend ServerSSESession.() -> Server,
8173
) {
8274
val transport = mcpSseTransport(postEndpoint, sseTransportManager)
8375

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
package io.modelcontextprotocol.kotlin.sdk.integration
2+
3+
import io.ktor.client.HttpClient
4+
import io.ktor.client.plugins.sse.SSE
5+
import io.ktor.server.cio.CIOApplicationEngine
6+
import io.ktor.server.engine.EmbeddedServer
7+
import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest
8+
import io.modelcontextprotocol.kotlin.sdk.GetPromptResult
9+
import io.modelcontextprotocol.kotlin.sdk.Implementation
10+
import io.modelcontextprotocol.kotlin.sdk.PromptArgument
11+
import io.modelcontextprotocol.kotlin.sdk.PromptMessage
12+
import io.modelcontextprotocol.kotlin.sdk.Role
13+
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
14+
import io.modelcontextprotocol.kotlin.sdk.TextContent
15+
import io.modelcontextprotocol.kotlin.sdk.client.Client
16+
import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport
17+
import io.modelcontextprotocol.kotlin.sdk.server.Server
18+
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
19+
import kotlinx.coroutines.Dispatchers
20+
import kotlinx.coroutines.test.runTest
21+
import kotlinx.coroutines.withContext
22+
import kotlin.test.Test
23+
import kotlin.test.assertTrue
24+
import kotlin.time.Duration.Companion.seconds
25+
import io.ktor.client.engine.cio.CIO as ClientCIO
26+
27+
typealias CIOEmbeddedServer = EmbeddedServer<CIOApplicationEngine, CIOApplicationEngine.Configuration>
28+
29+
abstract class AbstractSseIntegrationTest {
30+
@Test
31+
fun `client should be able to connect to sse server`() = runTest(timeout = 5.seconds) {
32+
var server: CIOEmbeddedServer? = null
33+
var client: Client? = null
34+
35+
try {
36+
withContext(Dispatchers.Default) {
37+
val (s, path) = initServer()
38+
server = s
39+
40+
val port = server.engine.resolvedConnectors().first().port
41+
client = initClient(serverPort = port, path = path)
42+
}
43+
} finally {
44+
client?.close()
45+
server?.stopSuspend(1000, 2000)
46+
}
47+
}
48+
49+
/**
50+
* Test Case #1: One opened connection, a client gets a prompt
51+
*
52+
* 1. Open SSE from Client A.
53+
* 2. Send a POST request from Client A to POST /prompts/get.
54+
* 3. Observe that Client A receives a response related to it.
55+
*/
56+
@Test
57+
fun `single sse connection`() = runTest(timeout = 5.seconds) {
58+
var server: CIOEmbeddedServer? = null
59+
var client: Client? = null
60+
try {
61+
withContext(Dispatchers.Default) {
62+
val (s, path) = initServer()
63+
server = s
64+
65+
val port = server.engine.resolvedConnectors().first().port
66+
client = initClient("Client A", port, path)
67+
68+
val promptA = getPrompt(client, "Client A")
69+
assertTrue { "Client A" in promptA }
70+
}
71+
} finally {
72+
client?.close()
73+
server?.stopSuspend(1000, 2000)
74+
}
75+
}
76+
77+
/**
78+
* Test Case #1: Two open connections, each client gets a client-specific prompt
79+
*
80+
* 1. Open SSE connection #1 from Client A and note the sessionId=<sessionId#1> value.
81+
* 2. Open SSE connection #2 from Client B and note the sessionId=<sessionId#2> value.
82+
* 3. Send a POST request to POST /message with the corresponding sessionId#1.
83+
* 4. Observe that Client B (connection #2) receives a response related to sessionId#1.
84+
*/
85+
@Test
86+
fun `multiple sse connections`() = runTest(timeout = 5.seconds) {
87+
var server: CIOEmbeddedServer? = null
88+
var clientA: Client? = null
89+
var clientB: Client? = null
90+
91+
try {
92+
withContext(Dispatchers.Default) {
93+
val (s, path) = initServer()
94+
server = s
95+
val port = server.engine.resolvedConnectors().first().port
96+
97+
clientA = initClient("Client A", port, path)
98+
clientB = initClient("Client B", port, path)
99+
100+
// Step 3: Send a prompt request from Client A
101+
val promptA = getPrompt(clientA, "Client A")
102+
// Step 4: Send a prompt request from Client B
103+
val promptB = getPrompt(clientB, "Client B")
104+
105+
assertTrue { "Client A" in promptA }
106+
assertTrue { "Client B" in promptB }
107+
}
108+
} finally {
109+
clientA?.close()
110+
clientB?.close()
111+
server?.stopSuspend(1000, 2000)
112+
}
113+
}
114+
115+
private suspend fun initClient(name: String = "", serverPort: Int, path: List<String>): Client {
116+
val client = Client(
117+
Implementation(name = name, version = "1.0.0"),
118+
)
119+
120+
val httpClient = HttpClient(ClientCIO) {
121+
install(SSE)
122+
}
123+
124+
// Create a transport wrapper that captures the session ID and received messages
125+
val transport = httpClient.mcpSseTransport {
126+
url {
127+
host = URL
128+
port = serverPort
129+
pathSegments = path
130+
}
131+
}
132+
133+
client.connect(transport)
134+
135+
return client
136+
}
137+
138+
/**
139+
* Create initialise the webserver for testing.
140+
* Concrete test classes implement this.
141+
*/
142+
protected abstract suspend fun initServer(): Pair<CIOEmbeddedServer, List<String>>
143+
144+
/**
145+
* Construct a new instance of the mcp server under test
146+
*/
147+
protected fun newMcpServer(): Server {
148+
val server = Server(
149+
Implementation(name = "sse-server", version = "1.0.0"),
150+
ServerOptions(
151+
capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)),
152+
),
153+
)
154+
155+
server.addPrompt(
156+
name = "prompt",
157+
description = "Prompt description",
158+
arguments = listOf(
159+
PromptArgument(
160+
name = "client",
161+
description = "Client name who requested a prompt",
162+
required = true,
163+
),
164+
),
165+
) { request ->
166+
GetPromptResult(
167+
"Prompt for ${request.name}",
168+
messages = listOf(
169+
PromptMessage(
170+
role = Role.user,
171+
content = TextContent("Prompt for client ${request.arguments?.get("client")}"),
172+
),
173+
),
174+
)
175+
}
176+
return server
177+
}
178+
179+
/**
180+
* Retrieves a prompt result using the provided client and client name.
181+
*/
182+
private suspend fun getPrompt(client: Client, clientName: String): String {
183+
val response = client.getPrompt(
184+
GetPromptRequest(
185+
"prompt",
186+
arguments = mapOf("client" to clientName),
187+
),
188+
)
189+
190+
return (response.messages.first().content as? TextContent)?.text
191+
?: error("Failed to receive prompt for Client $clientName")
192+
}
193+
194+
companion object {
195+
protected const val URL = "127.0.0.1"
196+
protected const val PORT = 0
197+
}
198+
}

0 commit comments

Comments
 (0)