diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 8ee5af28..9c4fded9 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -45,9 +45,10 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredTool { } public class io/modelcontextprotocol/kotlin/sdk/server/Server { - public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Ljava/lang/String;)V - public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function0;)V - public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function0;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun addPrompt (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V public final fun addPrompt (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;)V public static synthetic fun addPrompt$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Lkotlin/jvm/functions/Function2;ILjava/lang/Object;)V @@ -61,6 +62,7 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server { public final fun addTools (Ljava/util/List;)V public final fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun createSession (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; protected final fun getInstructionsProvider ()Lkotlin/jvm/functions/Function0; protected final fun getOptions ()Lio/modelcontextprotocol/kotlin/sdk/server/ServerOptions; public final fun getPrompts ()Ljava/util/Map; diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 934ba049..06e65284 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -88,7 +88,7 @@ internal suspend fun ServerSSESession.mcpSseEndpoint( sseTransportManager.removeTransport(transport.sessionId) } - server.connect(transport) + server.createSession(transport) logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 8c010560..15d61c19 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -57,12 +57,14 @@ public class ServerOptions(public val capabilities: ServerCapabilities, enforceS * @param options Configuration options for the server. * @param instructionsProvider Optional provider for instructions from the server to the client about how to use * this server. The provider is called each time a new session is started to support dynamic instructions. + * @param block A block to configure the mcp server. */ public open class Server( protected val serverInfo: Implementation, protected val options: ServerOptions, protected val instructionsProvider: (() -> String)? = null, + block: Server.() -> Unit = {}, ) { /** * Alternative constructor that provides the instructions directly as a string. @@ -70,12 +72,14 @@ public open class Server( * @param serverInfo Information about this server implementation (name, version). * @param options Configuration options for the server. * @param instructions Instructions from the server to the client about how to use this server. + * @param block A block to configure the mcp server. */ public constructor( serverInfo: Implementation, options: ServerOptions, instructions: String, - ) : this(serverInfo, options, { instructions }) + block: Server.() -> Unit = {}, + ) : this(serverInfo, options, { instructions }, block) private val sessions = atomic(persistentListOf()) @@ -98,6 +102,10 @@ public open class Server( public val resources: Map get() = _resources.value + init { + block(this) + } + public suspend fun close() { logger.debug { "Closing MCP server" } sessions.value.forEach { it.close() } @@ -111,7 +119,21 @@ public open class Server( * @param transport The transport layer to connect the session with. * @return The initialized and connected server session. */ - public suspend fun connect(transport: Transport): ServerSession { + @Deprecated( + "Use createSession(transport) instead.", + ReplaceWith("createSession(transport)"), + DeprecationLevel.WARNING, + ) + public suspend fun connect(transport: Transport): ServerSession = createSession(transport) + + /** + * Starts a new server session with the given transport and initializes + * internal request handlers based on the server's capabilities. + * + * @param transport The transport layer to connect the session with. + * @return The initialized and connected server session. + */ + public suspend fun createSession(transport: Transport): ServerSession { val session = ServerSession(serverInfo, options, instructionsProvider?.invoke()) // Internal handlers for tools diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index 1ceb764b..a5a24927 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -69,7 +69,7 @@ internal suspend fun WebSocketServerSession.mcpWebSocketEndpoint(block: () -> Se val server = block() var session: ServerSession? = null try { - session = server.connect(transport) + session = server.createSession(transport) awaitCancellation() } catch (e: CancellationException) { session?.close() @@ -103,7 +103,7 @@ public fun Route.mcpWebSocket(options: ServerOptions? = null, handler: suspend S ) public fun Route.mcpWebSocket(block: () -> Server) { webSocket { - block().connect(createMcpTransport(this)) + block().createSession(createMcpTransport(this)) } } @@ -190,7 +190,7 @@ private suspend fun Route.createMcpServer( ), ) - server.connect(transport) + server.createSession(transport) handler(server) server.close() } diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index c7f8b60d..41bf5458 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -24,7 +24,6 @@ kotlin { } jvmTest { dependencies { - implementation(kotlin("test-junit5")) implementation(libs.awaitility) runtimeOnly(libs.slf4j.simple) } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index d132ae5f..b563ba05 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -258,7 +258,7 @@ class ClientTest { client.connect(clientTransport) }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) }, ).joinAll() @@ -322,7 +322,7 @@ class ClientTest { println("Client connected") }, launch { - server.connect(serverTransport) + server.createSession(serverTransport) println("Server connected") }, ).joinAll() @@ -379,7 +379,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -439,7 +439,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -502,7 +502,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -594,7 +594,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -680,7 +680,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -812,7 +812,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -859,7 +859,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() @@ -939,7 +939,7 @@ class ClientTest { println("Client connected") }, launch { - serverSessionResult.complete(server.connect(serverTransport)) + serverSessionResult.complete(server.createSession(serverTransport)) println("Server connected") }, ).joinAll() diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index fa4ae020..e81332e6 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -48,7 +48,7 @@ class SseTransportTest : BaseTransportTest() { val actualPort = server.actualPort() - val client = HttpClient { + val transport = HttpClient { install(ClientSSE) }.mcpSseTransport { url { @@ -58,7 +58,7 @@ class SseTransportTest : BaseTransportTest() { } try { - testClientOpenClose(client) + testTransportRead(transport) } finally { server.stopSuspend() } @@ -76,7 +76,7 @@ class SseTransportTest : BaseTransportTest() { val actualPort = server.actualPort() - val client = HttpClient { + val transport = HttpClient { install(ClientSSE) }.mcpSseTransport { url { @@ -86,7 +86,7 @@ class SseTransportTest : BaseTransportTest() { } try { - testClientRead(client) + testTransportRead(transport) } finally { server.stopSuspend() } @@ -104,7 +104,7 @@ class SseTransportTest : BaseTransportTest() { val actualPort = server.actualPort() - val client = HttpClient { + val transport = HttpClient { install(ClientSSE) }.mcpSseTransport { url { @@ -115,7 +115,7 @@ class SseTransportTest : BaseTransportTest() { } try { - testClientRead(client) + testTransportRead(transport) } finally { server.stopSuspend() } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketTransportTest.kt index a4bf7a3f..97330718 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketTransportTest.kt @@ -22,7 +22,7 @@ class WebSocketTransportTest : BaseTransportTest() { install(io.ktor.client.plugins.websocket.WebSockets) }.mcpWebSocketTransport() - testClientOpenClose(client) + testTransportOpenClose(client) } @Test @@ -41,11 +41,11 @@ class WebSocketTransportTest : BaseTransportTest() { } } - val client = createClient { + val transport = createClient { install(io.ktor.client.plugins.websocket.WebSockets) }.mcpWebSocketTransport() - testClientRead(client) + testTransportRead(transport) clientFinished.complete(Unit) } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt deleted file mode 100644 index 29718128..00000000 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ /dev/null @@ -1,199 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration - -import io.ktor.client.HttpClient -import io.ktor.client.plugins.sse.SSE -import io.ktor.server.application.install -import io.ktor.server.cio.CIOApplicationEngine -import io.ktor.server.engine.EmbeddedServer -import io.ktor.server.engine.embeddedServer -import io.ktor.server.routing.routing -import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest -import io.modelcontextprotocol.kotlin.sdk.GetPromptResult -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptMessage -import io.modelcontextprotocol.kotlin.sdk.Role -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.TextContent -import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport -import io.modelcontextprotocol.kotlin.sdk.server.Server -import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions -import io.modelcontextprotocol.kotlin.sdk.server.mcp -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.withContext -import kotlin.test.Ignore -import kotlin.test.Test -import kotlin.test.assertTrue -import kotlin.time.Duration.Companion.seconds -import io.ktor.client.engine.cio.CIO as ClientCIO -import io.ktor.server.cio.CIO as ServerCIO -import io.ktor.server.sse.SSE as ServerSSE - -class SseIntegrationTest { - @Test - @Ignore // Ignored because it doesn’t work with wasm/js in Ktor 3.2.3 - fun `client should be able to connect to sse server`() = runTest(timeout = 5.seconds) { - var server: EmbeddedServer? = null - var client: Client? = null - - try { - withContext(Dispatchers.Default) { - server = initServer() - val port = server.engine.resolvedConnectors().first().port - client = initClient(serverPort = port) - } - } finally { - client?.close() - server?.stopSuspend(1000, 2000) - } - } - - /** - * Test Case #1: One opened connection, a client gets a prompt - * - * 1. Open SSE from Client A. - * 2. Send a POST request from Client A to POST /prompts/get. - * 3. Observe that Client A receives a response related to it. - */ - @Test - @Ignore // Ignored because it doesn’t work with wasm/js in Ktor 3.2.3 - fun `single sse connection`() = runTest(timeout = 5.seconds) { - var server: EmbeddedServer? = null - var client: Client? = null - try { - withContext(Dispatchers.Default) { - server = initServer() - val port = server.engine.resolvedConnectors().first().port - client = initClient("Client A", port) - - val promptA = getPrompt(client, "Client A") - assertTrue { "Client A" in promptA } - } - } finally { - client?.close() - server?.stopSuspend(1000, 2000) - } - } - - /** - * Test Case #1: Two open connections, each client gets a client-specific prompt - * - * 1. Open SSE connection #1 from Client A and note the sessionId= value. - * 2. Open SSE connection #2 from Client B and note the sessionId= value. - * 3. Send a POST request to POST /message with the corresponding sessionId#1. - * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. - */ - @Test - @Ignore // Ignored because it doesn’t work with wasm/js in Ktor 3.2.3 - fun `multiple sse connections`() = runTest(timeout = 5.seconds) { - var server: EmbeddedServer? = null - var clientA: Client? = null - var clientB: Client? = null - - try { - withContext(Dispatchers.Default) { - server = initServer() - val port = server.engine.resolvedConnectors().first().port - - clientA = initClient("Client A", port) - clientB = initClient("Client B", port) - - // Step 3: Send a prompt request from Client A - val promptA = getPrompt(clientA, "Client A") - // Step 4: Send a prompt request from Client B - val promptB = getPrompt(clientB, "Client B") - - assertTrue { "Client A" in promptA } - assertTrue { "Client B" in promptB } - } - } finally { - clientA?.close() - clientB?.close() - server?.stopSuspend(1000, 2000) - } - } - - private suspend fun initClient(name: String = "", serverPort: Int): Client { - val client = Client( - Implementation(name = name, version = "1.0.0"), - ) - - val httpClient = HttpClient(ClientCIO) { - install(SSE) - } - - // Create a transport wrapper that captures the session ID and received messages - val transport = httpClient.mcpSseTransport { - url { - host = URL - port = serverPort - } - } - - client.connect(transport) - - return client - } - - private suspend fun initServer(): EmbeddedServer { - val server = Server( - Implementation(name = "sse-server", version = "1.0.0"), - ServerOptions( - capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), - ), - ) - - server.addPrompt( - name = "prompt", - description = "Prompt description", - arguments = listOf( - PromptArgument( - name = "client", - description = "Client name who requested a prompt", - required = true, - ), - ), - ) { request -> - GetPromptResult( - "Prompt for ${request.name}", - messages = listOf( - PromptMessage( - role = Role.user, - content = TextContent("Prompt for client ${request.arguments?.get("client")}"), - ), - ), - ) - } - - val ktorServer = embeddedServer(ServerCIO, host = URL, port = PORT) { - install(ServerSSE) - routing { - mcp { server } - } - } - - return ktorServer.startSuspend(wait = false) - } - - /** - * Retrieves a prompt result using the provided client and client name. - */ - private suspend fun getPrompt(client: Client, clientName: String): String { - val response = client.getPrompt( - GetPromptRequest( - "prompt", - arguments = mapOf("client" to clientName), - ), - ) - - return (response.messages.first().content as? TextContent)?.text - ?: error("Failed to receive prompt for Client $clientName") - } - - companion object { - private const val URL = "127.0.0.1" - private const val PORT = 0 - } -} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index acb2f278..26404b62 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -11,23 +11,24 @@ import kotlin.test.assertTrue import kotlin.test.fail abstract class BaseTransportTest { - protected suspend fun testClientOpenClose(client: Transport) { - client.onError { error -> + + protected suspend fun testTransportOpenClose(transport: Transport) { + transport.onError { error -> fail("Unexpected error: $error") } var didClose = false - client.onClose { didClose = true } + transport.onClose { didClose = true } - client.start() + transport.start() assertFalse(didClose, "Transport should not be closed immediately after start") - client.close() + transport.close() assertTrue(didClose, "Transport should be closed after close() call") } - protected suspend fun testClientRead(client: Transport) { - client.onError { error -> + protected suspend fun testTransportRead(transport: Transport) { + transport.onError { error -> error.printStackTrace() fail("Unexpected error: $error") } @@ -40,23 +41,23 @@ abstract class BaseTransportTest { val readMessages = mutableListOf() val finished = CompletableDeferred() - client.onMessage { message -> + transport.onMessage { message -> readMessages.add(message) if (message == messages.last()) { finished.complete(Unit) } } - client.start() + transport.start() for (message in messages) { - client.send(message) + transport.send(message) } finished.await() assertEquals(messages, readMessages, "Assert messages received") - client.close() + transport.close() } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt index 6d0da329..7d0fa360 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt @@ -22,7 +22,7 @@ class StdioClientTransportTest : BaseTransportTest() { output = output, ) - testClientOpenClose(client) + testTransportOpenClose(client) process.destroy() } @@ -40,7 +40,7 @@ class StdioClientTransportTest : BaseTransportTest() { output = output, ) - testClientRead(client) + testTransportRead(client) process.waitFor() process.destroy() @@ -60,7 +60,7 @@ class StdioClientTransportTest : BaseTransportTest() { output = output, ) - testClientRead(client) + testTransportRead(client) process.waitFor() process.destroy() diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt index 563fa853..66b166cc 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -139,7 +139,7 @@ abstract class KotlinTestBase { // Start server transport by connecting the server runBlocking { - server.connect(serverTransport) + server.createSession(serverTransport) } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/AbstractSseIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/AbstractSseIntegrationTest.kt new file mode 100644 index 00000000..26289250 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/AbstractSseIntegrationTest.kt @@ -0,0 +1,105 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.sse + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.install +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +open class AbstractSseIntegrationTest { + + suspend fun EmbeddedServer<*, *>.actualPort() = engine.resolvedConnectors().single().port + + suspend fun initTestClient(serverPort: Int, name: String? = null): Client { + val client = Client( + Implementation(name = name ?: DEFAULT_CLIENT_NAME, version = VERSION), + ) + + val httpClient = HttpClient(ClientCIO) { + install(SSE) + } + + // Create a transport wrapper that captures the session ID and received messages + val transport = httpClient.mcpSseTransport { + url { + host = URL + port = serverPort + } + } + + client.connect(transport) + + return client + } + + suspend fun initTestServer( + name: String? = null, + ): EmbeddedServer { + val server = Server( + Implementation(name = name ?: DEFAULT_SERVER_NAME, version = VERSION), + ServerOptions( + capabilities = ServerCapabilities(prompts = ServerCapabilities.Prompts(listChanged = true)), + ), + ) { + addPrompt( + name = "prompt", + description = "Prompt description", + arguments = listOf( + PromptArgument( + name = "client", + description = "Client name who requested a prompt", + required = true, + ), + ), + ) { request -> + GetPromptResult( + "Prompt for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent("Prompt for client ${request.arguments?.get("client")}"), + ), + ), + ) + } + } + + val ktorServer = embeddedServer( + ServerCIO, + host = URL, + port = PORT, + ) { + install(ServerSSE) + routing { + mcp { server } + } + } + + return ktorServer.startSuspend(wait = false) + } + + companion object { + private const val DEFAULT_CLIENT_NAME = "sse-test-client" + private const val DEFAULT_SERVER_NAME = "sse-test-server" + private const val VERSION = "1.0.0" + private const val URL = "127.0.0.1" + private const val PORT = 0 + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseIntegrationTest.kt new file mode 100644 index 00000000..33e80705 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseIntegrationTest.kt @@ -0,0 +1,111 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.sse + +import io.ktor.server.cio.CIOApplicationEngine +import io.ktor.server.engine.EmbeddedServer +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlin.test.Test +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class SSeIntegrationTest : AbstractSseIntegrationTest() { + + @Test + fun `client should be able to connect to sse server`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var client: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initTestServer() + val port = server.engine.resolvedConnectors().single().port + client = initTestClient(serverPort = port) + } + } finally { + client?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Test Case #1: One opened connection, a client gets a prompt + * + * 1. Open SSE from Client A. + * 2. Send a POST request from Client A to POST /prompts/get. + * 3. Observe that Client A receives a response related to it. + */ + @Test + fun `single sse connection`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var client: Client? = null + try { + withContext(Dispatchers.Default) { + server = initTestServer() + val port = server.engine.resolvedConnectors().single().port + client = initTestClient(port, "Client A") + + val promptA = getPrompt(client, "Client A") + assertTrue { "Client A" in promptA } + } + } finally { + client?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Test Case #1: Two open connections, each client gets a client-specific prompt + * + * 1. Open SSE connection #1 from Client A and note the sessionId= value. + * 2. Open SSE connection #2 from Client B and note the sessionId= value. + * 3. Send a POST request to POST /message with the corresponding sessionId#1. + * 4. Observe that Client B (connection #2) receives a response related to sessionId#1. + */ + @Test + fun `multiple sse connections`() = runTest(timeout = 5.seconds) { + var server: EmbeddedServer? = null + var clientA: Client? = null + var clientB: Client? = null + + try { + withContext(Dispatchers.Default) { + server = initTestServer() + val port = server.engine.resolvedConnectors().first().port + + clientA = initTestClient(port, "Client A") + clientB = initTestClient(port, "Client B") + + // Step 3: Send a prompt request from Client A + val promptA = getPrompt(clientA, "Client A") + // Step 4: Send a prompt request from Client B + val promptB = getPrompt(clientB, "Client B") + + assertTrue { "Client A" in promptA } + assertTrue { "Client B" in promptB } + } + } finally { + clientA?.close() + clientB?.close() + server?.stopSuspend(1000, 2000) + } + } + + /** + * Retrieves a prompt result using the provided client and client name. + */ + private suspend fun getPrompt(client: Client, clientName: String): String { + val response = client.getPrompt( + GetPromptRequest( + "prompt", + arguments = mapOf("client" to clientName), + ), + ) + + return (response.messages.first().content as? TextContent)?.text + ?: error("Failed to receive prompt for Client $clientName") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt index 801f5820..21593fe2 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt @@ -412,7 +412,7 @@ abstract class TsTestBase { // Connect server in a background thread to avoid blocking val serverThread = Thread { try { - kotlinx.coroutines.runBlocking { server.connect(transport) } + kotlinx.coroutines.runBlocking { server.createSession(transport) } } catch (e: Exception) { println("[STDIO-SERVER] Error connecting: ${e.message}") } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt index 5cb0f61b..c51454df 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt @@ -53,7 +53,6 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.contentOrNull import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.jsonPrimitive -import org.awaitility.Awaitility.await import java.util.UUID import java.util.concurrent.ConcurrentHashMap @@ -134,7 +133,7 @@ class KotlinServerForTsClient { val serverThread = Thread { runBlocking { - mcpServer.connect(transport) + mcpServer.createSession(transport) } } serverThread.start() diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt index 088444d6..e0bf1c64 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt @@ -53,7 +53,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Remove the tool val result = server.removeTool("test-tool") @@ -90,7 +90,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Try to remove a non-existent tool val result = server.removeTool("non-existent-tool") @@ -147,7 +147,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Remove the tools val result = server.removeTools(listOf("test-tool-1", "test-tool-2")) @@ -186,7 +186,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Remove the prompt val result = server.removePrompt(testPrompt.name) @@ -232,7 +232,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Remove the prompts val result = server.removePrompts(listOf(testPrompt1.name, testPrompt2.name)) @@ -281,7 +281,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Remove the resource val result = server.removeResource(testResourceUri) @@ -347,7 +347,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Remove the resources val result = server.removeResources(listOf(testResourceUri1, testResourceUri2)) @@ -384,7 +384,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Try to remove a non-existent prompt val result = server.removePrompt("non-existent-prompt") @@ -442,7 +442,7 @@ class ServerTest { // Connect client and server launch { client.connect(clientTransport) } - launch { server.connect(serverTransport) } + launch { server.createSession(serverTransport) } // Try to remove a non-existent resource val result = server.removeResource("non-existent-resource") @@ -486,7 +486,7 @@ class ServerTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client(clientInfo = Implementation(name = "test client", version = "1.0")) - server.connect(serverTransport) + server.createSession(serverTransport) client.connect(clientTransport) assertEquals(instructions, client.serverInstructions) @@ -505,7 +505,7 @@ class ServerTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client(clientInfo = Implementation(name = "test client", version = "1.0")) - server.connect(serverTransport) + server.createSession(serverTransport) client.connect(clientTransport) assertEquals(instructions, client.serverInstructions) @@ -522,7 +522,7 @@ class ServerTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client(clientInfo = Implementation(name = "test client", version = "1.0")) - server.connect(serverTransport) + server.createSession(serverTransport) client.connect(clientTransport) assertNull(client.serverInstructions)