diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 8ee5af28..707ec7cd 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -1,8 +1,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { - public static final fun MCP (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V - public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V - public static final fun mcp (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V - public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function1;)V + public static final fun MCP (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function2;)V + public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function2;)V + public static final fun mcp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function2;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 934ba049..9507bb37 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -5,10 +5,9 @@ import io.ktor.http.HttpStatusCode import io.ktor.server.application.Application import io.ktor.server.application.install import io.ktor.server.response.respond -import io.ktor.server.routing.Routing +import io.ktor.server.routing.Route import io.ktor.server.routing.RoutingContext import io.ktor.server.routing.post -import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession @@ -36,18 +35,11 @@ internal class SseTransportManager(transports: Map = } } -@KtorDsl -public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) { - route(path) { - mcp(block) - } -} - /* * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE). */ @KtorDsl -public fun Routing.mcp(block: ServerSSESession.() -> Server) { +public fun Route.mcp(block: suspend ServerSSESession.() -> Server) { val sseTransportManager = SseTransportManager() sse { @@ -61,12 +53,12 @@ public fun Routing.mcp(block: ServerSSESession.() -> Server) { @Suppress("FunctionName") @Deprecated("Use mcp() instead", ReplaceWith("mcp(block)"), DeprecationLevel.ERROR) -public fun Application.MCP(block: ServerSSESession.() -> Server) { +public fun Application.MCP(block: suspend ServerSSESession.() -> Server) { mcp(block) } @KtorDsl -public fun Application.mcp(block: ServerSSESession.() -> Server) { +public fun Application.mcp(block: suspend ServerSSESession.() -> Server) { install(SSE) routing { @@ -77,7 +69,7 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) { internal suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, sseTransportManager: SseTransportManager, - block: ServerSSESession.() -> Server, + block: suspend ServerSSESession.() -> Server, ) { val transport = mcpSseTransport(postEndpoint, sseTransportManager) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractSseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractSseIntegrationTest.kt new file mode 100644 index 00000000..5c13f9e4 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractSseIntegrationTest.kt @@ -0,0 +1,198 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlin.test.Test +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds +import io.ktor.client.engine.cio.CIO as ClientCIO + +typealias CIOEmbeddedServer = EmbeddedServer + +abstract class AbstractSseIntegrationTest { + @Test + fun `client should be able to connect to sse server`() = runTest(timeout = 5.seconds) { + var server: CIOEmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + val (s, path) = initServer() + server = s + + val port = server.engine.resolvedConnectors().first().port + client = initClient(serverPort = port, path = path) + } + } finally { + client?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open SSE from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single sse connection`() = runTest(timeout = 5.seconds) { + var server: CIOEmbeddedServer? = null + var client: Client? = null + try { + withContext(Dispatchers.Default) { + val (s, path) = initServer() + server = s + + val port = server.engine.resolvedConnectors().first().port + client = initClient("Client A", port, path) + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } finally { + client?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open SSE connection #1 from Client A and note the sessionId= value. + * 2. Open SSE connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple sse connections`() = runTest(timeout = 5.seconds) { + var server: CIOEmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + val (s, path) = initServer() + server = s + val port = server.engine.resolvedConnectors().first().port + + clientA = initClient("Client A", port, path) + clientB = initClient("Client B", port, path) + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } finally { + clientA?.close() + clientB?.close() + server?.stopSuspend(1000, 2000) + } + } + + private suspend fun initClient(name: String = "", serverPort: Int, path: List): Client { + val client = Client( + Implementation(name = name, version = "1.0.0"), + ) + + val httpClient = HttpClient(ClientCIO) { + install(SSE) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpSseTransport { + url { + host = URL + port = serverPort + pathSegments = path + } + } + + client.connect(transport) + + return client + } + + /** + * Create initialise the webserver for testing. + * Concrete test classes implement this. + */ + protected abstract suspend fun initServer(): Pair> + + /** + * Construct a new instance of the mcp server under test + */ + protected fun newMcpServer(): Server { + val server = Server( + Implementation(name = "sse-server", version = "1.0.0"), + ServerOptions( + capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), + ), + ) + + server.addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true, + ), + ), + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}"), + ), + ), + ) + } + return server + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName), + ), + ) + + return (response.messages.first().content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } + + companion object { + protected const val URL = "127.0.0.1" + protected const val PORT = 0 + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 31b0fd0d..85a7170d 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -1,167 +1,17 @@ package io.modelcontextprotocol.kotlin.sdk.integration -import io.ktor.client.HttpClient -import io.ktor.client.plugins.sse.SSE import io.ktor.server.application.install -import io.ktor.server.cio.CIOApplicationEngine -import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.routing -import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest -import io.modelcontextprotocol.kotlin.sdk.GetPromptResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptMessage -import io.modelcontextprotocol.kotlin.sdk.Role -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.TextContent -import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport -import io.modelcontextprotocol.kotlin.sdk.server.Server -import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.mcp -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.withContext -import kotlin.test.Test -import kotlin.test.assertTrue -import kotlin.time.Duration.Companion.seconds -import io.ktor.client.engine.cio.CIO as ClientCIO +import kotlin.collections.emptyList import io.ktor.server.cio.CIO as ServerCIO import io.ktor.server.sse.SSE as ServerSSE -class SseIntegrationTest { - @Test - fun `client should be able to connect to sse server`() = runTest(timeout = 5.seconds) { - var server: EmbeddedServer? = null - var client: Client? = null +class SseIntegrationTest : AbstractSseIntegrationTest() { - try { - withContext(Dispatchers.Default) { - server = initServer() - val port = server.engine.resolvedConnectors().first().port - client = initClient(serverPort = port) - } - } finally { - client?.close() - server?.stopSuspend(1000, 2000) - } - } - - /** - * Test Case #1: One opened connection, a client gets a prompt - * - * 1. Open SSE from Client A. - * 2. Send a POST request from Client A to POST /prompts/get. - * 3. Observe that Client A receives a response related to it. - */ - @Test - fun `single sse connection`() = runTest(timeout = 5.seconds) { - var server: EmbeddedServer? = null - var client: Client? = null - try { - withContext(Dispatchers.Default) { - server = initServer() - val port = server.engine.resolvedConnectors().first().port - client = initClient("Client A", port) - - val promptA = getPrompt(client, "Client A") - assertTrue { "Client A" in promptA } - } - } finally { - client?.close() - server?.stopSuspend(1000, 2000) - } - } - - /** - * Test Case #1: Two open connections, each client gets a client-specific prompt - * - * 1. Open SSE connection #1 from Client A and note the sessionId= value. - * 2. Open SSE connection #2 from Client B and note the sessionId= value. - * 3. Send a POST request to POST /message with the corresponding sessionId#1. - * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. - */ - @Test - fun `multiple sse connections`() = runTest(timeout = 5.seconds) { - var server: EmbeddedServer? = null - var clientA: Client? = null - var clientB: Client? = null - - try { - withContext(Dispatchers.Default) { - server = initServer() - val port = server.engine.resolvedConnectors().first().port - - clientA = initClient("Client A", port) - clientB = initClient("Client B", port) - - // Step 3: Send a prompt request from Client A - val promptA = getPrompt(clientA, "Client A") - // Step 4: Send a prompt request from Client B - val promptB = getPrompt(clientB, "Client B") - - assertTrue { "Client A" in promptA } - assertTrue { "Client B" in promptB } - } - } finally { - clientA?.close() - clientB?.close() - server?.stopSuspend(1000, 2000) - } - } - - private suspend fun initClient(name: String = "", serverPort: Int): Client { - val client = Client( - Implementation(name = name, version = "1.0.0"), - ) - - val httpClient = HttpClient(ClientCIO) { - install(SSE) - } - - // Create a transport wrapper that captures the session ID and received messages - val transport = httpClient.mcpSseTransport { - url { - host = URL - port = serverPort - } - } - - client.connect(transport) - - return client - } - - private suspend fun initServer(): EmbeddedServer { - val server = Server( - Implementation(name = "sse-server", version = "1.0.0"), - ServerOptions( - capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), - ), - ) - - server.addPrompt( - name = "prompt", - description = "Prompt description", - arguments = listOf( - PromptArgument( - name = "client", - description = "Client name who requested a prompt", - required = true, - ), - ), - ) { request -> - GetPromptResult( - "Prompt for ${request.name}", - messages = listOf( - PromptMessage( - role = Role.user, - content = TextContent("Prompt for client ${request.arguments?.get("client")}"), - ), - ), - ) - } + protected override suspend fun initServer(): Pair> { + val server = newMcpServer() val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { install(ServerSSE) @@ -170,26 +20,6 @@ class SseIntegrationTest { } } - return ktorServer.startSuspend(wait = false) - } - - /** - * Retrieves a prompt result using the provided client and client name. - */ - private suspend fun getPrompt(client: Client, clientName: String): String { - val response = client.getPrompt( - GetPromptRequest( - "prompt", - arguments = mapOf("client" to clientName), - ), - ) - - return (response.messages.first().content as? TextContent)?.text - ?: error("Failed to receive prompt for Client $clientName") - } - - companion object { - private const val URL = "127.0.0.1" - private const val PORT = 0 + return ktorServer.startSuspend(wait = false) to emptyList() } } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithPath.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithPath.kt new file mode 100644 index 00000000..edf75cb9 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithPath.kt @@ -0,0 +1,27 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.server.application.install +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.route +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +class SseIntegrationTestWithPath : AbstractSseIntegrationTest() { + + protected override suspend fun initServer(): Pair> { + val server = newMcpServer() + + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerSSE) + routing { + route("/some-path") { + mcp { server } + } + } + } + + return ktorServer.startSuspend(wait = false) to listOf("some-path") + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithPathAndSuspend.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithPathAndSuspend.kt new file mode 100644 index 00000000..bfe15d98 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithPathAndSuspend.kt @@ -0,0 +1,30 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.server.application.install +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.route +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +class SseIntegrationTestWithPathAndSuspend : AbstractSseIntegrationTest() { + + private suspend fun suspendNewMcpServer(): Server = newMcpServer() + + protected override suspend fun initServer(): Pair> { + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerSSE) + routing { + route("/some-path") { + mcp { + suspendNewMcpServer() + } + } + } + } + + return ktorServer.startSuspend(wait = false) to listOf("some-path") + } +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithSuspend.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithSuspend.kt new file mode 100644 index 00000000..bc962a85 --- /dev/null +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTestWithSuspend.kt @@ -0,0 +1,28 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.ktor.server.application.install +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import kotlin.collections.emptyList +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +class SseIntegrationTestWithSuspend : AbstractSseIntegrationTest() { + + private suspend fun suspendNewMcpServer(): Server = newMcpServer() + + protected override suspend fun initServer(): Pair> { + val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { + install(ServerSSE) + routing { + mcp { + suspendNewMcpServer() + } + } + } + + return ktorServer.startSuspend(wait = false) to emptyList() + } +}