diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index b78db2a5..c69cd1c7 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -17,6 +17,7 @@ jreleaser = "1.19.0" binaryCompatibilityValidatorPlugin = "0.18.1" slf4j = "2.0.17" kotest = "5.9.1" +awaitility = "4.3.0" # Samples mcp-kotlin = "0.7.0" @@ -45,11 +46,12 @@ ktor-server-websockets = { group = "io.ktor", name = "ktor-server-websockets", v ktor-server-core = { group = "io.ktor", name = "ktor-server-core", version.ref = "ktor" } # Testing +awaitility = { group = "org.awaitility", name = "awaitility-kotlin", version.ref = "awaitility" } +kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json", version.ref = "kotest" } kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "coroutines" } -ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" } ktor-client-mock = { group = "io.ktor", name = "ktor-client-mock", version.ref = "ktor" } +ktor-server-test-host = { group = "io.ktor", name = "ktor-server-test-host", version.ref = "ktor" } slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" } -kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json", version.ref = "kotest" } # Samples ktor-client-cio = { group = "io.ktor", name = "ktor-client-cio", version.ref = "ktor" } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 580c3982..675bfea0 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -36,11 +36,10 @@ import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.encodeToJsonElement import kotlinx.serialization.serializer -import kotlin.collections.get import kotlin.reflect.KType import kotlin.reflect.typeOf import kotlin.time.Duration -import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds private val LOGGER = KotlinLogging.logger { } @@ -85,7 +84,7 @@ public open class ProtocolOptions( /** * The default request timeout. */ -public val DEFAULT_REQUEST_TIMEOUT: Duration = 60000.milliseconds +public val DEFAULT_REQUEST_TIMEOUT: Duration = 60.seconds /** * Options that can be given per request. diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index 36496c37..012619b9 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -24,6 +24,7 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) + implementation(libs.awaitility) runtimeOnly(libs.slf4j.simple) } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt index c367cf12..e0ccd39b 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -17,6 +17,7 @@ import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.mcp import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout +import org.awaitility.kotlin.await import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import kotlin.time.Duration.Companion.seconds @@ -27,7 +28,7 @@ import io.ktor.server.sse.SSE as ServerSSE abstract class KotlinTestBase { protected val host = "localhost" - protected abstract val port: Int + protected var port: Int = 0 protected lateinit var server: Server protected lateinit var client: Client @@ -39,6 +40,12 @@ abstract class KotlinTestBase { @BeforeEach fun setUp() { setupServer() + await + .ignoreExceptions() + .until { + port = runBlocking { serverEngine.engine.resolvedConnectors().first().port } + port != 0 + } runBlocking { setupClient() } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt index 559129c3..31376332 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt @@ -7,9 +7,10 @@ import io.modelcontextprotocol.kotlin.sdk.PromptMessage import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals @@ -18,8 +19,6 @@ import kotlin.test.assertTrue class PromptEdgeCasesTest : KotlinTestBase() { - override val port = 3008 - private val basicPromptName = "basic-prompt" private val basicPromptDescription = "A basic prompt for testing" @@ -183,7 +182,7 @@ class PromptEdgeCasesTest : KotlinTestBase() { } @Test - fun testBasicPrompt() = runTest { + fun testBasicPrompt() = runBlocking(Dispatchers.IO) { val testName = "Alice" val result = client.getPrompt( GetPromptRequest( @@ -215,7 +214,7 @@ class PromptEdgeCasesTest : KotlinTestBase() { } @Test - fun testComplexPromptWithManyArguments() = runTest { + fun testComplexPromptWithManyArguments() = runBlocking(Dispatchers.IO) { val arguments = (1..10).associate { i -> "arg$i" to "value$i" } val result = client.getPrompt( @@ -253,7 +252,7 @@ class PromptEdgeCasesTest : KotlinTestBase() { } @Test - fun testLargePrompt() = runTest { + fun testLargePrompt() = runBlocking(Dispatchers.IO) { val result = client.getPrompt( GetPromptRequest( name = largePromptName, @@ -275,7 +274,7 @@ class PromptEdgeCasesTest : KotlinTestBase() { } @Test - fun testSpecialCharacters() = runTest { + fun testSpecialCharacters() = runBlocking(Dispatchers.IO) { val result = client.getPrompt( GetPromptRequest( name = specialCharsPromptName, diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt index a609c2ba..54fd5fc8 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt @@ -9,7 +9,7 @@ import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows @@ -19,7 +19,6 @@ import kotlin.test.assertTrue class PromptIntegrationTest : KotlinTestBase() { - override val port = 3004 private val testPromptName = "greeting" private val testPromptDescription = "A simple greeting prompt" private val complexPromptName = "multimodal-prompt" @@ -219,7 +218,7 @@ class PromptIntegrationTest : KotlinTestBase() { } @Test - fun testListPrompts() = runTest { + fun testListPrompts() = runBlocking(Dispatchers.IO) { val result = client.listPrompts() assertNotNull(result, "List prompts result should not be null") @@ -247,7 +246,7 @@ class PromptIntegrationTest : KotlinTestBase() { } @Test - fun testGetPrompt() = runTest { + fun testGetPrompt() = runBlocking(Dispatchers.IO) { val testName = "Alice" val result = client.getPrompt( GetPromptRequest( @@ -290,7 +289,7 @@ class PromptIntegrationTest : KotlinTestBase() { } @Test - fun testMissingRequiredArguments() = runTest { + fun testMissingRequiredArguments() = runBlocking(Dispatchers.IO) { val promptsList = client.listPrompts() assertNotNull(promptsList, "Prompts list should not be null") val strictPrompt = promptsList.prompts.find { it.name == strictPromptName } @@ -364,7 +363,7 @@ class PromptIntegrationTest : KotlinTestBase() { } @Test - fun testComplexContentTypes() = runTest { + fun testComplexContentTypes() = runBlocking(Dispatchers.IO) { val topic = "artificial intelligence" val result = client.getPrompt( GetPromptRequest( @@ -418,7 +417,7 @@ class PromptIntegrationTest : KotlinTestBase() { } @Test - fun testMultipleMessagesAndRoles() = runTest { + fun testMultipleMessagesAndRoles() = runBlocking(Dispatchers.IO) { val topic = "climate change" val result = client.getPrompt( GetPromptRequest( diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt index 165e6936..9e23a3c0 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt @@ -9,9 +9,10 @@ import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest import io.modelcontextprotocol.kotlin.sdk.TextResourceContents import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import java.util.concurrent.atomic.AtomicBoolean @@ -21,8 +22,6 @@ import kotlin.test.assertTrue class ResourceEdgeCasesTest : KotlinTestBase() { - override val port = 3007 - private val testResourceUri = "test://example.txt" private val testResourceName = "Test Resource" private val testResourceDescription = "A test resource for integration testing" @@ -129,7 +128,7 @@ class ResourceEdgeCasesTest : KotlinTestBase() { } @Test - fun testBinaryResource() = runTest { + fun testBinaryResource() = runBlocking(Dispatchers.IO) { val result = client.readResource(ReadResourceRequest(uri = binaryResourceUri)) assertNotNull(result, "Read resource result should not be null") @@ -142,7 +141,7 @@ class ResourceEdgeCasesTest : KotlinTestBase() { } @Test - fun testLargeResource() = runTest { + fun testLargeResource() = runBlocking(Dispatchers.IO) { val result = client.readResource(ReadResourceRequest(uri = largeResourceUri)) assertNotNull(result, "Read resource result should not be null") @@ -172,7 +171,7 @@ class ResourceEdgeCasesTest : KotlinTestBase() { } @Test - fun testDynamicResource() = runTest { + fun testDynamicResource() = runBlocking(Dispatchers.IO) { val initialResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) assertNotNull(initialResult, "Initial read result should not be null") val initialContent = (initialResult.contents.firstOrNull() as? TextResourceContents)?.text @@ -188,7 +187,7 @@ class ResourceEdgeCasesTest : KotlinTestBase() { } @Test - fun testResourceAddAndRemove() = runTest { + fun testResourceAddAndRemove() = runBlocking(Dispatchers.IO) { val initialList = client.listResources() assertNotNull(initialList, "Initial list result should not be null") val initialCount = initialList.resources.size @@ -261,7 +260,7 @@ class ResourceEdgeCasesTest : KotlinTestBase() { @Test fun testSubscribeAndUnsubscribe() { - runTest { + runBlocking(Dispatchers.IO) { val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) assertNotNull(subscribeResult, "Subscribe result should not be null") diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt index c467b2a1..5ea9bbd0 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt @@ -8,7 +8,8 @@ import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest import io.modelcontextprotocol.kotlin.sdk.TextResourceContents import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -16,7 +17,6 @@ import kotlin.test.assertTrue class ResourceIntegrationTest : KotlinTestBase() { - override val port = 3005 private val testResourceUri = "test://example.txt" private val testResourceName = "Test Resource" private val testResourceDescription = "A test resource for integration testing" @@ -57,7 +57,7 @@ class ResourceIntegrationTest : KotlinTestBase() { } @Test - fun testListResources() = runTest { + fun testListResources() = runBlocking(Dispatchers.IO) { val result = client.listResources() assertNotNull(result, "List resources result should not be null") @@ -70,7 +70,7 @@ class ResourceIntegrationTest : KotlinTestBase() { } @Test - fun testReadResource() = runTest { + fun testReadResource() = runBlocking(Dispatchers.IO) { val result = client.readResource(ReadResourceRequest(uri = testResourceUri)) assertNotNull(result, "Read resource result should not be null") @@ -83,7 +83,7 @@ class ResourceIntegrationTest : KotlinTestBase() { @Test fun testSubscribeAndUnsubscribe() { - runTest { + runBlocking(Dispatchers.IO) { val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) assertNotNull(subscribeResult, "Subscribe result should not be null") diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt index a0dc9ba0..c30cffb0 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -7,10 +7,11 @@ import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive @@ -26,8 +27,6 @@ import kotlin.test.assertTrue class ToolEdgeCasesTest : KotlinTestBase() { - override val port = 3009 - private val basicToolName = "basic-tool" private val basicToolDescription = "A basic tool for testing" @@ -279,63 +278,60 @@ class ToolEdgeCasesTest : KotlinTestBase() { } @Test - fun testBasicTool() { - runTest { - val testText = "Hello, world!" - val arguments = mapOf("text" to testText) + fun testBasicTool(): Unit = runBlocking(Dispatchers.IO) { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) - val result = client.callTool(basicToolName, arguments) as CallToolResultBase + val result = client.callTool(basicToolName, arguments) as CallToolResultBase - val expectedToolResult = "[TextContent(text=Echo: Hello, world!, annotations=null)]" - assertEquals(expectedToolResult, result.content.toString(), "Unexpected tool result") + val expectedToolResult = "[TextContent(text=Echo: Hello, world!, annotations=null)]" + assertEquals(expectedToolResult, result.content.toString(), "Unexpected tool result") - val actualContent = result.structuredContent.toString() - val expectedContent = """ - { - "result" : "Hello, world!" - } - """.trimIndent() + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "result" : "Hello, world!" + } + """.trimIndent() - actualContent shouldEqualJson expectedContent - } + actualContent shouldEqualJson expectedContent } @Test - fun testComplexNestedSchema() { - runTest { - val userJson = buildJsonObject { - put("name", JsonPrimitive("John Galt")) - put("age", JsonPrimitive(30)) - put( - "address", - buildJsonObject { - put("street", JsonPrimitive("123 Main St")) - put("city", JsonPrimitive("New York")) - put("country", JsonPrimitive("USA")) - }, - ) - } + fun testComplexNestedSchema(): Unit = runBlocking(Dispatchers.IO) { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Galt")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, + ) + } - val optionsJson = buildJsonArray { - add(JsonPrimitive("option1")) - add(JsonPrimitive("option2")) - add(JsonPrimitive("option3")) - } + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } - val arguments = buildJsonObject { - put("user", userJson) - put("options", optionsJson) - } + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) + } - val result = client.callTool( - CallToolRequest( - name = complexToolName, - arguments = arguments, - ), - ) as CallToolResultBase + val result = client.callTool( + CallToolRequest( + name = complexToolName, + arguments = arguments, + ), + ) as CallToolResultBase - val actualContent = result.structuredContent.toString() - val expectedContent = """ + val actualContent = result.structuredContent.toString() + val expectedContent = """ { "name" : "John Galt", "age" : 30, @@ -346,63 +342,58 @@ class ToolEdgeCasesTest : KotlinTestBase() { }, "options" : [ "option1", "option2", "option3" ] } - """.trimIndent() + """.trimIndent() - actualContent shouldEqualJson expectedContent - } + actualContent shouldEqualJson expectedContent } @Test - fun testLargeResponse() { - runTest { - val size = 10 - val arguments = mapOf("size" to size) + fun testLargeResponse(): Unit = runBlocking(Dispatchers.IO) { + val size = 10 + val arguments = mapOf("size" to size) - val result = client.callTool(largeToolName, arguments) as CallToolResultBase + val result = client.callTool(largeToolName, arguments) as CallToolResultBase - val content = result.content.firstOrNull() as TextContent - assertNotNull(content, "Tool result content should be TextContent") + val content = result.content.firstOrNull() as TextContent + assertNotNull(content, "Tool result content should be TextContent") - val actualContent = result.structuredContent.toString() - val expectedContent = """ + val actualContent = result.structuredContent.toString() + val expectedContent = """ { "size" : 10000 } - """.trimIndent() + """.trimIndent() - actualContent shouldEqualJson expectedContent - } + actualContent shouldEqualJson expectedContent } @Test - fun testSlowTool() { - runTest { - val delay = 500 - val arguments = mapOf("delay" to delay) + fun testSlowTool(): Unit = runBlocking(Dispatchers.IO) { + val delay = 500 + val arguments = mapOf("delay" to delay) - val startTime = System.currentTimeMillis() - val result = client.callTool(slowToolName, arguments) as CallToolResultBase - val endTime = System.currentTimeMillis() + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) as CallToolResultBase + val endTime = System.currentTimeMillis() - val content = result.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") - assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") - val actualContent = result.structuredContent.toString() - val expectedContent = """ + val actualContent = result.structuredContent.toString() + val expectedContent = """ { "delay" : 500 } - """.trimIndent() + """.trimIndent() - actualContent shouldEqualJson expectedContent - } + actualContent shouldEqualJson expectedContent } @Test fun testSpecialCharacters() { - runTest { + runBlocking(Dispatchers.IO) { val arguments = mapOf("special" to specialCharsContent) val result = client.callTool(specialCharsToolName, arguments) as CallToolResultBase diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt index 044237a2..84ae233d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -8,7 +8,7 @@ import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonPrimitive @@ -26,8 +26,6 @@ import kotlin.test.assertNotNull import kotlin.test.assertTrue class ToolIntegrationTest : KotlinTestBase() { - - override val port = 3006 private val testToolName = "echo" private val testToolDescription = "A simple echo tool that returns the input text" private val complexToolName = "calculator" @@ -302,7 +300,7 @@ class ToolIntegrationTest : KotlinTestBase() { } @Test - fun testListTools() = runTest { + fun testListTools(): Unit = runBlocking(Dispatchers.IO) { val result = client.listTools() assertNotNull(result, "List utils result should not be null") @@ -319,43 +317,40 @@ class ToolIntegrationTest : KotlinTestBase() { } @Test - fun testCallTool() { - runTest { - val testText = "Hello, world!" - val arguments = mapOf("text" to testText) + fun testCallTool(): Unit = runBlocking(Dispatchers.IO) { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) - val result = client.callTool(testToolName, arguments) as CallToolResultBase + val result = client.callTool(testToolName, arguments) as CallToolResultBase - val actualContent = result.structuredContent.toString() - val expectedContent = """ + val actualContent = result.structuredContent.toString() + val expectedContent = """ {"result":"Hello, world!"} - """.trimIndent() + """.trimIndent() - actualContent shouldEqualJson expectedContent - } + actualContent shouldEqualJson expectedContent } @Test - fun testComplexInputSchemaTool() { - runTest { - val toolsList = client.listTools() - assertNotNull(toolsList, "Tools list should not be null") - val calculatorTool = toolsList.tools.find { it.name == complexToolName } - assertNotNull(calculatorTool, "Calculator tool should be in the list") - - val arguments = mapOf( - "operation" to "multiply", - "a" to 5.5, - "b" to 2.0, - "precision" to 3, - "showSteps" to true, - "tags" to listOf("test", "calculator", "integration"), - ) + fun testComplexInputSchemaTool(): Unit = runBlocking(Dispatchers.IO) { + val toolsList = client.listTools() + assertNotNull(toolsList, "Tools list should not be null") + val calculatorTool = toolsList.tools.find { it.name == complexToolName } + assertNotNull(calculatorTool, "Calculator tool should be in the list") - val result = client.callTool(complexToolName, arguments) as CallToolResultBase + val arguments = mapOf( + "operation" to "multiply", + "a" to 5.5, + "b" to 2.0, + "precision" to 3, + "showSteps" to true, + "tags" to listOf("test", "calculator", "integration"), + ) + + val result = client.callTool(complexToolName, arguments) as CallToolResultBase - val actualContent = result.structuredContent.toString() - val expectedContent = """ + val actualContent = result.structuredContent.toString() + val expectedContent = """ { "operation" : "multiply", "a" : 5.5, @@ -365,14 +360,13 @@ class ToolIntegrationTest : KotlinTestBase() { "precision" : 3, "tags" : [ ] } - """.trimIndent() + """.trimIndent() - actualContent shouldEqualJson expectedContent - } + actualContent shouldEqualJson expectedContent } @Test - fun testToolErrorHandling() = runTest { + fun testToolErrorHandling(): Unit = runBlocking(Dispatchers.IO) { val successArgs = mapOf("errorType" to "none") val successResult = client.callTool(errorToolName, successArgs) @@ -420,7 +414,7 @@ class ToolIntegrationTest : KotlinTestBase() { } @Test - fun testMultiContentTool() = runTest { + fun testMultiContentTool(): Unit = runBlocking(Dispatchers.IO) { val testText = "Test multi-content" val arguments = mapOf( "text" to testText, diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt index 25ead220..7b15fbc7 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -3,7 +3,7 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.runBlocking @@ -64,7 +64,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testNonExistentTool() = runTest { + fun testNonExistentTool(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val nonExistentToolName = "non-existent-tool" val arguments = mapOf("name" to "TestUser") @@ -85,7 +85,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testSpecialCharactersInArguments() = runTest { + fun testSpecialCharactersInArguments(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" val arguments = mapOf("name" to specialChars) @@ -107,7 +107,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testLargePayload() = runTest { + fun testLargePayload(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val largeName = "A".repeat(10 * 1024) val arguments = mapOf("name" to largeName) @@ -129,7 +129,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(60, unit = TimeUnit.SECONDS) - fun testConcurrentRequests() = runTest { + fun testConcurrentRequests(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val concurrentCount = 5 val responses = kotlinx.coroutines.coroutineScope { @@ -165,7 +165,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testInvalidArguments() = runTest { + fun testInvalidArguments(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val invalidArguments = mapOf( "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), @@ -196,7 +196,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testMultipleToolCalls() = runTest { + fun testMultipleToolCalls(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> repeat(10) { i -> val name = "SequentialClient$i" diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt index 13aaa73d..eca06be1 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -3,7 +3,7 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import org.junit.jupiter.api.AfterEach @@ -59,7 +59,7 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testKotlinClientConnectsToTypeScriptServer() = runTest { + fun testKotlinClientConnectsToTypeScriptServer(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> assertNotNull(client, "Client should be initialized") @@ -74,7 +74,7 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testListTools() = runTest { + fun testListTools(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val result = client.listTools() assertNotNull(result, "Tools list should not be null") @@ -92,7 +92,7 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testToolCall() = runTest { + fun testToolCall(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val testName = "TestUser" val arguments = mapOf("name" to testName) @@ -113,7 +113,7 @@ class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { @Test @Timeout(30, unit = TimeUnit.SECONDS) - fun testMultipleClients() = runTest { + fun testMultipleClients(): Unit = runBlocking(Dispatchers.IO) { val client1 = newClient(serverUrl) val client2 = newClient(serverUrl) try { diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt index 351406ef..ff4de3ca 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -1,6 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt index 7f27c5b7..6504b49e 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -1,6 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.integration.typescript -import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt deleted file mode 100644 index 46515610..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt +++ /dev/null @@ -1,13 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.utils - -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withContext - -object TestUtils { - fun runTest(block: suspend () -> T): T = runBlocking { - withContext(Dispatchers.IO) { - block() - } - } -} diff --git a/samples/kotlin-mcp-server/src/commonMain/kotlin/server.kt b/samples/kotlin-mcp-server/src/commonMain/kotlin/server.kt index a544d0ea..15e8cb0c 100644 --- a/samples/kotlin-mcp-server/src/commonMain/kotlin/server.kt +++ b/samples/kotlin-mcp-server/src/commonMain/kotlin/server.kt @@ -25,21 +25,20 @@ import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport import io.modelcontextprotocol.kotlin.sdk.server.mcp -import kotlin.collections.set fun configureServer(): Server { val server = Server( Implementation( name = "mcp-kotlin test server", - version = "0.1.0" + version = "0.1.0", ), ServerOptions( capabilities = ServerCapabilities( prompts = ServerCapabilities.Prompts(listChanged = true), resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), tools = ServerCapabilities.Tools(listChanged = true), - ) - ) + ), + ), ) server.addPrompt( @@ -49,18 +48,20 @@ fun configureServer(): Server { PromptArgument( name = "Project Name", description = "Project name for the new project", - required = true - ) - ) + required = true, + ), + ), ) { request -> GetPromptResult( "Description for ${request.name}", messages = listOf( PromptMessage( role = Role.user, - content = TextContent("Develop a kotlin project named ${request.arguments?.get("Project Name")}") - ) - ) + content = TextContent( + "Develop a kotlin project named ${request.arguments?.get("Project Name")}", + ), + ), + ), ) } @@ -68,10 +69,10 @@ fun configureServer(): Server { server.addTool( name = "kotlin-sdk-tool", description = "A test tool", - inputSchema = Tool.Input() + inputSchema = Tool.Input(), ) { request -> CallToolResult( - content = listOf(TextContent("Hello, world!")) + content = listOf(TextContent("Hello, world!")), ) } @@ -80,19 +81,19 @@ fun configureServer(): Server { uri = "https://search.com/", name = "Web Search", description = "Web search engine", - mimeType = "text/html" + mimeType = "text/html", ) { request -> ReadResourceResult( contents = listOf( - TextResourceContents("Placeholder content for ${request.uri}", request.uri, "text/html") - ) + TextResourceContents("Placeholder content for ${request.uri}", request.uri, "text/html"), + ), ) } return server } -suspend fun runSseMcpServerWithPlainConfiguration(port: Int): Unit { +suspend fun runSseMcpServerWithPlainConfiguration(port: Int) { val servers = ConcurrentMap() println("Starting sse server on port $port. ") println("Use inspector to connect to the http://localhost:$port/sse") @@ -139,7 +140,7 @@ suspend fun runSseMcpServerWithPlainConfiguration(port: Int): Unit { * @param port The port number on which the SSE MCP server will listen for client connections. * @return Unit This method does not return a value. */ -suspend fun runSseMcpServerUsingKtorPlugin(port: Int): Unit { +suspend fun runSseMcpServerUsingKtorPlugin(port: Int) { println("Starting sse server on port $port") println("Use inspector to connect to the http://localhost:$port/sse") @@ -148,4 +149,4 @@ suspend fun runSseMcpServerUsingKtorPlugin(port: Int): Unit { return@mcp configureServer() } }.startSuspend(wait = true) -} \ No newline at end of file +}