diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 0283a646..1f76a52f 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -18,7 +18,7 @@ kotlinx-io = "0.8.0" ktor = "3.2.3" logging = "7.0.13" slf4j = "2.0.17" -kotest = "6.0.4" +kotest = "5.9.1" # for JVM 1.8 awaitility = "4.3.0" mokksy = "0.6.1" diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt new file mode 100644 index 00000000..55e09da8 --- /dev/null +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt @@ -0,0 +1,54 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import kotlinx.coroutines.CompletableDeferred + +/** + * Implements [onClose], [onError] and [onMessage] functions of [Transport] providing + * corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation. + */ +@Suppress("PropertyName") +public abstract class AbstractTransport : Transport { + protected var _onClose: (() -> Unit) = {} + private set + protected var _onError: ((Throwable) -> Unit) = {} + private set + + // to not skip messages + private val _onMessageInitialized = CompletableDeferred() + protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { + _onMessageInitialized.await() + _onMessage.invoke(it) + } + private set + + override fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + override fun onError(block: (Throwable) -> Unit) { + val old = _onError + _onError = { e -> + old(e) + block(e) + } + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) { + true -> _onMessage + false -> { _ -> } + } + + _onMessage = { message -> + old(message) + block(message) + } + + _onMessageInitialized.complete(Unit) + } +} diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 6eedfe62..dc15c058 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -116,6 +116,7 @@ internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) } * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ +@Suppress("TooManyFunctions") public abstract class Protocol(@PublishedApi internal val options: ProtocolOptions?) { public var transport: Transport? = null private set @@ -190,7 +191,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio /** * Attaches to the given transport, starts it, and starts listening for messages. * - * The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + * The Protocol object assumes ownership of the Transport, + * replacing any callbacks that have already been set, + * and expects that it is the only user of the Transport instance going forward. */ public open suspend fun connect(transport: Transport) { this.transport = transport @@ -237,6 +240,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio logger.trace { "No handler found for notification: ${notification.method}" } return } + @Suppress("TooGenericExceptionCaught") try { handler(notification) } catch (cause: Throwable) { @@ -252,6 +256,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio if (handler === null) { logger.trace { "No handler found for request: ${request.method}" } + @Suppress("TooGenericExceptionCaught") try { transport?.send( JSONRPCResponse( @@ -269,6 +274,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio return } + @Suppress("TooGenericExceptionCaught") try { val result = handler(request, RequestHandlerExtra()) logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } @@ -303,7 +309,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio private fun onProgress(notification: ProgressNotification) { logger.trace { - "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" + "Received progress notification: token=${notification.params.progressToken}, " + + "progress=${notification.params.progress}/${notification.params.total}" } val progress = notification.params.progress val total = notification.params.total @@ -392,7 +399,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio public suspend fun request(request: Request, options: RequestOptions? = null): T { logger.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() - val transport = transport ?: throw Error("Not connected") + val transport = checkNotNull(transport) { + "No transport connected" + } if (this@Protocol.options?.enforceStrictCapabilities == true) { assertCapabilityForMethod(request.method) @@ -420,6 +429,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio return@put } + @Suppress("TooGenericExceptionCaught") try { @Suppress("UNCHECKED_CAST") result.complete(response!!.result as T) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt index ba460f94..2ae5b700 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt @@ -1,7 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import kotlinx.coroutines.CompletableDeferred /** * Describes the minimal contract for MCP transport that a client or server can communicate over. @@ -47,53 +46,3 @@ public interface Transport { */ public fun onMessage(block: suspend (JSONRPCMessage) -> Unit) } - -/** - * Implements [onClose], [onError] and [onMessage] functions of [Transport] providing - * corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation. - */ -@Suppress("PropertyName") -public abstract class AbstractTransport : Transport { - protected var _onClose: (() -> Unit) = {} - private set - protected var _onError: ((Throwable) -> Unit) = {} - private set - - // to not skip messages - private val _onMessageInitialized = CompletableDeferred() - protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { - _onMessageInitialized.await() - _onMessage.invoke(it) - } - private set - - override fun onClose(block: () -> Unit) { - val old = _onClose - _onClose = { - old() - block() - } - } - - override fun onError(block: (Throwable) -> Unit) { - val old = _onError - _onError = { e -> - old(e) - block(e) - } - } - - override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) { - true -> _onMessage - false -> { _ -> } - } - - _onMessage = { message -> - old(message) - block(message) - } - - _onMessageInitialized.complete(Unit) - } -} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 99f7aa84..d6a865df 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -257,7 +257,7 @@ public open class Server( title: String? = null, outputSchema: Tool.Output? = null, toolAnnotations: ToolAnnotations? = null, - @Suppress("LocalVariableName") _meta: JsonObject? = null, + @Suppress("LocalVariableName", "FunctionParameterNaming") _meta: JsonObject? = null, handler: suspend (CallToolRequest) -> CallToolResult, ) { val tool = Tool( diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index b8532e1e..ec44904d 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -14,6 +14,7 @@ kotlin { implementation(dependencies.platform(libs.ktor.bom)) implementation(project(":kotlin-sdk")) implementation(kotlin("test")) + implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) implementation(libs.kotlin.logging) implementation(libs.kotlinx.coroutines.test) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt index d5644bbc..65e6d65d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin +import io.kotest.matchers.throwable.shouldHaveMessage import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.PromptArgument @@ -132,7 +133,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { ) } - // complext prompt + // complex prompt server.addPrompt( name = complexPromptName, description = complexPromptDescription, @@ -152,8 +153,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { // validate required arguments val requiredArgs = listOf("arg1", "arg2", "arg3") for (argName in requiredArgs) { - if (request.arguments?.get(argName) == null) { - throw IllegalArgumentException("Missing required argument: $argName") + require(request.arguments?.get(argName) != null) { + "Missing required argument: $argName" } } @@ -665,9 +666,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { } } - val msg = exception.message ?: "" - val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})" - - assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt") + exception shouldHaveMessage + "JSONRPCError(code=InternalError, message=Prompt not found: $nonExistentPromptName, data={})" } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt index ffaff6b4..4e18d683 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt @@ -1,5 +1,15 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.collections.shouldBeEmpty +import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.kotest.matchers.throwable.shouldHaveMessage +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.Method @@ -8,6 +18,7 @@ import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals @@ -21,9 +32,51 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { ) @Test - fun `removePrompt should remove a prompt`() = runTest { + fun `Should list no prompts by default`() = runTest { + client.listPrompts() shouldNotBeNull { + prompts.shouldBeEmpty() + } + } + + @Test + fun `Should add a prompt`() = runTest { // Add a prompt - val testPrompt = Prompt("test-prompt", "Test Prompt", null) + val testPrompt = Prompt( + name = "test-prompt-with-custom-handler", + description = "Test Prompt", + arguments = null, + ) + val expectedPromptResult = GetPromptResult( + description = "Test prompt description", + messages = listOf(), + ) + + server.addPrompt(testPrompt) { + expectedPromptResult + } + + client.getPrompt( + GetPromptRequest( + name = "test-prompt-with-custom-handler", + arguments = null, + ), + ) shouldBe expectedPromptResult + + client.listPrompts() shouldNotBeNull { + prompts shouldContainExactly listOf(testPrompt) + nextCursor shouldBe null + _meta shouldBe EmptyJsonObject + } + } + + @Test + fun `Should remove a prompt`() = runTest { + // given + val testPrompt = Prompt( + name = "test-prompt-to-remove", + description = "Test Prompt", + arguments = null, + ) server.addPrompt(testPrompt) { GetPromptResult( description = "Test prompt description", @@ -31,15 +84,33 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { ) } - // Remove the prompt + client.listPrompts() shouldNotBeNull { + prompts shouldContain testPrompt + } + + // when val result = server.removePrompt(testPrompt.name) - // Verify the prompt was removed + // then assertTrue(result, "Prompt should be removed successfully") + val mcpException = shouldThrow { + client.getPrompt( + GetPromptRequest( + name = testPrompt.name, + arguments = null, + ), + ) + } + mcpException shouldHaveMessage + "JSONRPCError(code=InternalError, message=Prompt not found: ${testPrompt.name}, data={})" + + client.listPrompts() shouldNotBeNull { + prompts.firstOrNull { it.name == testPrompt.name } shouldBe null + } } @Test - fun `removePrompts should remove multiple prompts and send notification`() = runTest { + fun `Should remove multiple prompts and send notification`() = runTest { // Add prompts val testPrompt1 = Prompt("test-prompt-1", "Test Prompt 1", null) val testPrompt2 = Prompt("test-prompt-2", "Test Prompt 2", null) @@ -56,11 +127,17 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { ) } + client.listPrompts() shouldNotBeNull { + prompts shouldHaveSize 2 + } // Remove the prompts val result = server.removePrompts(listOf(testPrompt1.name, testPrompt2.name)) // Verify the prompts were removed assertEquals(2, result, "Both prompts should be removed") + client.listPrompts() shouldNotBeNull { + prompts.shouldBeEmpty() + } } @Test @@ -80,21 +157,55 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { assertFalse(promptListChangedNotificationReceived, "No notification should be sent when prompt doesn't exist") } - @Test - fun `removePrompt should throw when prompts capability is not supported`() = runTest { + @Nested + inner class NoPromptsCapabilitiesTests { // Create server without prompts capability - val serverOptions = ServerOptions( - capabilities = ServerCapabilities(), - ) - val server = Server( + val serverWithoutPrompts = Server( Implementation(name = "test server", version = "1.0"), - serverOptions, + ServerOptions( + capabilities = ServerCapabilities(), + ), ) - // Verify that removing a prompt throws an exception - val exception = assertThrows { - server.removePrompt("test-prompt") + @Test + fun `RemovePrompt should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.removePrompt("test-prompt") + } + assertEquals("Server does not support prompts capability.", exception.message) + } + + @Test + fun `Remove Prompts should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.removePrompts(emptyList()) + } + assertEquals("Server does not support prompts capability.", exception.message) + } + + @Test + fun `Add Prompt should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.addPrompt(name = "test-prompt") { + GetPromptResult( + description = "Test prompt description", + messages = listOf(), + ) + } + } + assertEquals("Server does not support prompts capability.", exception.message) + } + + @Test + fun `Add Prompts should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.addPrompts(emptyList()) + } + assertEquals("Server does not support prompts capability.", exception.message) } - assertEquals("Server does not support prompts capability.", exception.message) } }