diff --git a/build.gradle.kts b/build.gradle.kts index 616b514e..f24e17ae 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -46,7 +46,7 @@ kover { } verify { rule { - minBound(65) + minBound(75) } } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 0283a646..1f76a52f 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -18,7 +18,7 @@ kotlinx-io = "0.8.0" ktor = "3.2.3" logging = "7.0.13" slf4j = "2.0.17" -kotest = "6.0.4" +kotest = "5.9.1" # for JVM 1.8 awaitility = "4.3.0" mokksy = "0.6.1" diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 4f230ab4..a667ce2f 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -1,6 +1,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/Annotations { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/Annotations$Companion; + public fun ()V public fun (Ljava/util/List;Lkotlin/time/Instant;Ljava/lang/Double;)V + public synthetic fun (Ljava/util/List;Lkotlin/time/Instant;Ljava/lang/Double;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/util/List; public final fun component2 ()Lkotlin/time/Instant; public final fun component3 ()Ljava/lang/Double; @@ -1589,12 +1591,11 @@ public final class io/modelcontextprotocol/kotlin/sdk/LoggingMessageNotification public final fun serializer ()Lkotlinx/serialization/KSerializer; } -public final class io/modelcontextprotocol/kotlin/sdk/McpError : java/lang/Exception { +public final class io/modelcontextprotocol/kotlin/sdk/McpException : java/lang/Exception { public fun (ILjava/lang/String;Lkotlinx/serialization/json/JsonObject;)V public synthetic fun (ILjava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getCode ()I public final fun getData ()Lkotlinx/serialization/json/JsonObject; - public fun getMessage ()Ljava/lang/String; } public abstract interface class io/modelcontextprotocol/kotlin/sdk/Method { @@ -1699,7 +1700,9 @@ public final class io/modelcontextprotocol/kotlin/sdk/ModelHint$Companion { public final class io/modelcontextprotocol/kotlin/sdk/ModelPreferences { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ModelPreferences$Companion; + public fun ()V public fun (Ljava/util/List;Ljava/lang/Double;Ljava/lang/Double;Ljava/lang/Double;)V + public synthetic fun (Ljava/util/List;Ljava/lang/Double;Ljava/lang/Double;Ljava/lang/Double;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getCostPriority ()Ljava/lang/Double; public final fun getHints ()Ljava/util/List; public final fun getIntelligencePriority ()Ljava/lang/Double; @@ -1883,6 +1886,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/ProgressNotification$Param public final class io/modelcontextprotocol/kotlin/sdk/Prompt { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/Prompt$Companion; public fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getArguments ()Ljava/util/List; public final fun getDescription ()Ljava/lang/String; public final fun getName ()Ljava/lang/String; @@ -2472,6 +2476,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/Role$Companion { public final class io/modelcontextprotocol/kotlin/sdk/Root { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/Root$Companion; public fun (Ljava/lang/String;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; public final fun component2 ()Ljava/lang/String; public final fun copy (Ljava/lang/String;Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/Root; @@ -2995,6 +3000,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/Tool$Output$Companion { public final class io/modelcontextprotocol/kotlin/sdk/ToolAnnotations { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/ToolAnnotations$Companion; + public fun ()V public fun (Ljava/lang/String;Ljava/lang/Boolean;Ljava/lang/Boolean;Ljava/lang/Boolean;Ljava/lang/Boolean;)V public synthetic fun (Ljava/lang/String;Ljava/lang/Boolean;Ljava/lang/Boolean;Ljava/lang/Boolean;Ljava/lang/Boolean;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/exceptions.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/exceptions.kt new file mode 100644 index 00000000..455c6898 --- /dev/null +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/exceptions.kt @@ -0,0 +1,29 @@ +@file:Suppress("unused", "EnumEntryName") + +package io.modelcontextprotocol.kotlin.sdk + +import kotlinx.serialization.json.JsonObject + +@Deprecated("Use McpException instead", ReplaceWith("McpException")) +public typealias McpError = McpException + +/** + * Represents an error specific to the MCP protocol. + * + * @property code The error code. + * @property message The error message. + * @property data Additional error data as a JSON object. + */ +public class McpException(public val code: Int, message: String, public val data: JsonObject = EmptyJsonObject) : + Exception("MCP error $code: \"$message\"") + +/** + * Converts a `JSONRPCError` instance to an [McpException] instance. + * + * @return An [McpException] containing the code, message, and data from the `JSONRPCError`. + */ +internal fun JSONRPCError.toMcpException(): McpException = McpException( + code = this.code.code, + message = this.message, + data = this.data, +) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt new file mode 100644 index 00000000..55e09da8 --- /dev/null +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt @@ -0,0 +1,54 @@ +package io.modelcontextprotocol.kotlin.sdk.shared + +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import kotlinx.coroutines.CompletableDeferred + +/** + * Implements [onClose], [onError] and [onMessage] functions of [Transport] providing + * corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation. + */ +@Suppress("PropertyName") +public abstract class AbstractTransport : Transport { + protected var _onClose: (() -> Unit) = {} + private set + protected var _onError: ((Throwable) -> Unit) = {} + private set + + // to not skip messages + private val _onMessageInitialized = CompletableDeferred() + protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { + _onMessageInitialized.await() + _onMessage.invoke(it) + } + private set + + override fun onClose(block: () -> Unit) { + val old = _onClose + _onClose = { + old() + block() + } + } + + override fun onError(block: (Throwable) -> Unit) { + val old = _onError + _onError = { e -> + old(e) + block(e) + } + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) { + true -> _onMessage + false -> { _ -> } + } + + _onMessage = { message -> + old(message) + block(message) + } + + _onMessageInitialized.complete(Unit) + } +} diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 6eedfe62..41786f69 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -8,7 +8,7 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCError import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse -import io.modelcontextprotocol.kotlin.sdk.McpError +import io.modelcontextprotocol.kotlin.sdk.McpException import io.modelcontextprotocol.kotlin.sdk.Method import io.modelcontextprotocol.kotlin.sdk.Notification import io.modelcontextprotocol.kotlin.sdk.PingRequest @@ -19,6 +19,7 @@ import io.modelcontextprotocol.kotlin.sdk.RequestId import io.modelcontextprotocol.kotlin.sdk.RequestResult import io.modelcontextprotocol.kotlin.sdk.fromJSON import io.modelcontextprotocol.kotlin.sdk.toJSON +import io.modelcontextprotocol.kotlin.sdk.toMcpException import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic import kotlinx.atomicfu.getAndUpdate @@ -97,7 +98,7 @@ public data class RequestOptions( val onProgress: ProgressCallback? = null, /** - * A timeout for this request. If exceeded, an McpError with code `RequestTimeout` + * A timeout for this request. If exceeded, an [McpException] with code `RequestTimeout` * will be raised from request(). * * If not specified, `DEFAULT_REQUEST_TIMEOUT` will be used as the timeout. @@ -116,6 +117,7 @@ internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) } * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ +@Suppress("TooManyFunctions") public abstract class Protocol(@PublishedApi internal val options: ProtocolOptions?) { public var transport: Transport? = null private set @@ -190,7 +192,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio /** * Attaches to the given transport, starts it, and starts listening for messages. * - * The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + * The Protocol object assumes ownership of the Transport, + * replacing any callbacks that have already been set, + * and expects that it is the only user of the Transport instance going forward. */ public open suspend fun connect(transport: Transport) { this.transport = transport @@ -222,7 +226,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio transport = null onClose() - val error = McpError(ErrorCode.Defined.ConnectionClosed.code, "Connection closed") + val error = McpException(ErrorCode.Defined.ConnectionClosed.code, "Connection closed") for (handler in handlersToNotify) { handler(null, error) } @@ -237,6 +241,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio logger.trace { "No handler found for notification: ${notification.method}" } return } + @Suppress("TooGenericExceptionCaught") try { handler(notification) } catch (cause: Throwable) { @@ -252,6 +257,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio if (handler === null) { logger.trace { "No handler found for request: ${request.method}" } + @Suppress("TooGenericExceptionCaught") try { transport?.send( JSONRPCResponse( @@ -268,7 +274,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } return } - + @Suppress("TooGenericExceptionCaught") try { val result = handler(request, RequestHandlerExtra()) logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" } @@ -303,7 +309,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio private fun onProgress(notification: ProgressNotification) { logger.trace { - "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" + "Received progress notification: token=${notification.params.progressToken}, " + + "progress=${notification.params.progress}/${notification.params.total}" } val progress = notification.params.progress val total = notification.params.total @@ -347,7 +354,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio handler(response, null) } else { check(error != null) - val error = McpError( + val error = McpException( error.code.code, error.message, error.data, @@ -392,7 +399,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio public suspend fun request(request: Request, options: RequestOptions? = null): T { logger.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() - val transport = transport ?: throw Error("Not connected") + val transport = transport ?: throw McpException(ErrorCode.Defined.ConnectionClosed.code, "Not connected") if (this@Protocol.options?.enforceStrictCapabilities == true) { assertCapabilityForMethod(request.method) @@ -415,11 +422,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio return@put } - if (response?.error != null) { - result.completeExceptionally(IllegalStateException(response.error.toString())) + response?.error?.let { + result.completeExceptionally(it.toMcpException()) return@put } + @Suppress("TooGenericExceptionCaught") try { @Suppress("UNCHECKED_CAST") result.complete(response!!.result as T) @@ -459,7 +467,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } catch (cause: TimeoutCancellationException) { logger.error { "Request timed out after ${timeout.inWholeMilliseconds}ms: ${request.method}" } cancel( - McpError( + McpException( ErrorCode.Defined.RequestTimeout.code, "Request timed out", JsonObject(mutableMapOf("timeout" to JsonPrimitive(timeout.inWholeMilliseconds))), diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt index ba460f94..2ae5b700 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Transport.kt @@ -1,7 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage -import kotlinx.coroutines.CompletableDeferred /** * Describes the minimal contract for MCP transport that a client or server can communicate over. @@ -47,53 +46,3 @@ public interface Transport { */ public fun onMessage(block: suspend (JSONRPCMessage) -> Unit) } - -/** - * Implements [onClose], [onError] and [onMessage] functions of [Transport] providing - * corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation. - */ -@Suppress("PropertyName") -public abstract class AbstractTransport : Transport { - protected var _onClose: (() -> Unit) = {} - private set - protected var _onError: ((Throwable) -> Unit) = {} - private set - - // to not skip messages - private val _onMessageInitialized = CompletableDeferred() - protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { - _onMessageInitialized.await() - _onMessage.invoke(it) - } - private set - - override fun onClose(block: () -> Unit) { - val old = _onClose - _onClose = { - old() - block() - } - } - - override fun onError(block: (Throwable) -> Unit) { - val old = _onError - _onError = { e -> - old(e) - block(e) - } - } - - override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) { - true -> _onMessage - false -> { _ -> } - } - - _onMessage = { message -> - old(message) - block(message) - } - - _onMessageInitialized.complete(Unit) - } -} diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index 60adcca4..9ef9b719 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -52,7 +52,7 @@ public sealed interface WithMeta { * The protocol reserves this result property * to allow clients and servers to attach additional metadata to their responses. */ - @Suppress("PropertyName") + @Suppress("PropertyName", "VariableNaming") public val _meta: JsonObject public companion object { @@ -271,6 +271,7 @@ public sealed interface ErrorCode { public val code: Int @Serializable + @Suppress("MagicNumber") public enum class Defined(override val code: Int) : ErrorCode { // SDK error codes ConnectionClosed(-1), @@ -302,10 +303,12 @@ public data class JSONRPCError(val code: ErrorCode, val message: String, val dat public sealed interface NotificationParams : WithMeta /* Cancellation */ + /** * This notification can be sent by either side to indicate that it is cancelling a previously issued request. * - * The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. + * The request SHOULD still be in-flight, but due to communication latency, + * it is always possible that this notification MAY arrive after the request has already finished. * * This notification indicates that the result will be unused, so any associated processing SHOULD cease. * @@ -334,6 +337,7 @@ public data class CancelledNotification(override val params: Params) : } /* Initialization */ + /** * Describes the name and version of an MCP implementation. */ @@ -507,7 +511,10 @@ public data class ServerCapabilities( @Serializable public data class InitializeResult( /** - * The version of the Model Context Protocol that the server wants to use. This may not match the version that the client requested. If the client cannot support this version, it MUST disconnect. + * The version of the Model Context Protocol that the server wants to use. + * + * This may not match the version that the client requested. + * If the client cannot support this version, it MUST disconnect. */ val protocolVersion: String = LATEST_PROTOCOL_VERSION, val capabilities: ServerCapabilities = ServerCapabilities(), @@ -531,6 +538,7 @@ public data class InitializedNotification(override val params: Params = Params() } /* Ping */ + /** * A ping, issued by either the server or the client, to check that the other party is still alive. * The receiver must promptly respond, or else it may be disconnected. @@ -564,6 +572,7 @@ public sealed interface ProgressBase { } /* Progress notifications */ + /** * Represents a progress notification. * @@ -623,6 +632,7 @@ public data class ProgressNotification(override val params: Params) : } /* Pagination */ + /** * Represents a request supporting pagination. */ @@ -650,6 +660,7 @@ public sealed interface PaginatedResult : RequestResult { } /* Resources */ + /** * The contents of a specific resource or sub-resource. */ @@ -669,7 +680,8 @@ public sealed interface ResourceContents { /** * Represents the text contents of a resource. * - * @property text The text of the item. This must only be set if the item can actually be represented as text (not binary data). + * @property text The text of the item. This must only be set if the item can actually be represented as text + * (not binary data). */ @Serializable public data class TextResourceContents(val text: String, override val uri: String, override val mimeType: String?) : @@ -753,7 +765,9 @@ public data class ResourceTemplate( */ val description: String?, /** - * The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type. + * The MIME type for all resources that match this template. + * + * This should only be included if all resources matching this template have the same type. */ val mimeType: String?, /** @@ -845,12 +859,15 @@ public data class ResourceListChangedNotification(override val params: Params = } /** - * Sent from the client to request resources/updated notifications from the server whenever a particular resource changes. + * Sent from the client to request resources/updated notifications from the server + * whenever a particular resource changes. */ @Serializable public data class SubscribeRequest( /** - * The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it. + * The URI of the resource to subscribe to. + * + * The URI can use any protocol; it is up to the server how to interpret it. */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, @@ -860,7 +877,8 @@ public data class SubscribeRequest( } /** - * Sent from the client to request cancellation of resources/updated notifications from the server. This should follow a previous resources/subscribe request. + * Sent from the client to request cancellation of resources/updated notifications from the server. + * This should follow a previous resources/subscribe request. */ @Serializable public data class UnsubscribeRequest( @@ -875,7 +893,8 @@ public data class UnsubscribeRequest( } /** - * A notification from the server to the client, informing it that a resource has changed and may need to be read again. This should only be sent if the client previously sent a resources/subscribe request. + * A notification from the server to the client, informing it that a resource has changed and may need to be read again. + * This should only be sent if the client previously sent a resources/subscribe request. */ @Serializable public data class ResourceUpdatedNotification(override val params: Params) : ServerNotification { @@ -884,7 +903,8 @@ public data class ResourceUpdatedNotification(override val params: Params) : Ser @Serializable public data class Params( /** - * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + * The URI of the resource that has been updated. + * This might be a sub-resource of the one that the client actually subscribed to. */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, @@ -892,6 +912,7 @@ public data class ResourceUpdatedNotification(override val params: Params) : Ser } /* Prompts */ + /** * Describes an argument that a prompt can accept. */ @@ -923,11 +944,11 @@ public class Prompt( /** * An optional description of what this prompt provides */ - public val description: String?, + public val description: String? = null, /** * A list of arguments to use for templating the prompt. */ - public val arguments: List?, + public val arguments: List? = null, ) /** @@ -966,7 +987,7 @@ public data class GetPromptRequest( /** * Arguments to use for templating the prompt. */ - val arguments: Map?, + val arguments: Map? = null, override val _meta: JsonObject = EmptyJsonObject, ) : ClientRequest, @@ -983,7 +1004,7 @@ public sealed interface PromptMessageContent { } /** - * Represents prompt message content that is either text, image or audio. + * Represents prompt message content that is either text, image, or audio. */ @Serializable(with = PromptMessageContentMultimodalPolymorphicSerializer::class) public sealed interface PromptMessageContentMultimodal : PromptMessageContent @@ -1095,7 +1116,7 @@ public data class EmbeddedResource( /** * Enum representing the role of a participant. */ -@Suppress("EnumEntryName") +@Suppress("EnumEntryName", "EnumNaming") @Serializable public enum class Role { user, @@ -1111,19 +1132,19 @@ public data class Annotations( /** * Describes who the intended customer of this object or data is. */ - val audience: List?, + val audience: List? = null, /** * The moment the resource was last modified. */ @OptIn(ExperimentalTime::class) - val lastModified: Instant?, + val lastModified: Instant? = null, /** * Describes how important this data is for operating the server. * - * A value of 1.0 means "most important", and indicates that the data is effectively required, - * while 0.0 means "less important", and indicates that the data is entirely optional. + * A value of 1.0 means "most important" and indicates that the data is effectively required, + * while 0.0 means "less important" and indicates that the data is entirely optional. */ - val priority: Double?, + val priority: Double? = null, ) { init { require(priority == null || priority in 0.0..1.0) { "Priority must be between 0.0 and 1.0" } @@ -1144,7 +1165,7 @@ public class GetPromptResult( /** * An optional description for the prompt. */ - public val description: String?, + public val description: String? = null, public val messages: List, override val _meta: JsonObject = EmptyJsonObject, ) : ServerResult @@ -1162,6 +1183,7 @@ public data class PromptListChangedNotification(override val params: Params = Pa } /* Tools */ + /** * Additional properties describing a Tool to clients. * @@ -1177,7 +1199,7 @@ public data class ToolAnnotations( /** * A human-readable title for the tool. */ - val title: String?, + val title: String? = null, /** * If true, the tool does not modify its environment. * @@ -1195,7 +1217,7 @@ public data class ToolAnnotations( val destructiveHint: Boolean? = true, /** * If true, calling the tool repeatedly with the same arguments - * will have no additional effect on the its environment. + * will have no additional effect on its environment. * * (This property is meaningful only when `readOnlyHint == false`) * @@ -1225,11 +1247,11 @@ public data class Tool( /** * The title of the tool. */ - val title: String?, + val title: String? = null, /** * A human-readable description of the tool. */ - val description: String?, + val description: String? = null, /** * A JSON object defining the expected parameters for the tool. */ @@ -1237,11 +1259,11 @@ public data class Tool( /** * An optional JSON object defining the expected output schema for the tool. */ - val outputSchema: Output?, + val outputSchema: Output? = null, /** * Optional additional tool information. */ - val annotations: ToolAnnotations?, + val annotations: ToolAnnotations? = null, /** * Optional metadata for the tool. @@ -1345,10 +1367,11 @@ public data class ToolListChangedNotification(override val params: Params = Para } /* Logging */ + /** * The severity of a log message. */ -@Suppress("EnumEntryName") +@Suppress("EnumEntryName", "EnumNaming") @Serializable public enum class LoggingLevel { debug, @@ -1393,7 +1416,9 @@ public data class LoggingMessageNotification(override val params: Params) : Serv @Serializable public data class SetLevelRequest( /** - * The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/logging/message. + * The level of logging that the client wants to receive from the server. + * The server should send all logs at this level and higher (i.e., more severe) + * to the client as notifications/logging/message. */ val level: LoggingLevel, override val _meta: JsonObject = EmptyJsonObject, @@ -1404,6 +1429,7 @@ public data class LoggingMessageNotification(override val params: Params) : Serv } /* Sampling */ + /** * Hints to use for model selection. */ @@ -1424,19 +1450,19 @@ public class ModelPreferences( /** * Optional hints to use for model selection. */ - public val hints: List?, + public val hints: List? = null, /** * How much to prioritize cost when selecting a model. */ - public val costPriority: Double?, + public val costPriority: Double? = null, /** * How much to prioritize sampling speed (latency) when selecting a model. */ - public val speedPriority: Double?, + public val speedPriority: Double? = null, /** * How much to prioritize intelligence and capabilities when selecting a model. */ - public val intelligencePriority: Double?, + public val intelligencePriority: Double? = null, ) { init { require(costPriority == null || costPriority in 0.0..1.0) { @@ -1471,17 +1497,20 @@ public data class CreateMessageRequest( /** * An optional system prompt the server wants to use it for sampling. The client MAY modify or omit this prompt. */ - val systemPrompt: String?, + val systemPrompt: String? = null, /** - * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request. + * A request to include context from one or more MCP servers (including the caller), + * to be attached to the prompt. The client MAY ignore this request. */ - val includeContext: IncludeContext?, - val temperature: Double?, + val includeContext: IncludeContext? = null, + val temperature: Double? = null, /** - * The maximum number of tokens to sample, as requested by the server. The client MAY choose to sample fewer tokens than requested. + * The maximum number of tokens to sample, as requested by the server. + * + * The client MAY choose to sample fewer tokens than requested. */ val maxTokens: Int, - val stopSequences: List?, + val stopSequences: List? = null, /** * Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific. */ @@ -1489,13 +1518,14 @@ public data class CreateMessageRequest( /** * The server's preferences for which model to select. */ - val modelPreferences: ModelPreferences?, + val modelPreferences: ModelPreferences? = null, override val _meta: JsonObject = EmptyJsonObject, ) : ServerRequest, WithMeta { override val method: Method = Method.Defined.SamplingCreateMessage @Serializable + @Suppress("EnumEntryName", "EnumNaming") public enum class IncludeContext { none, thisServer, allServers } } @@ -1623,7 +1653,6 @@ public data class CompleteRequest( @Serializable public data class CompleteResult(val completion: Completion, override val _meta: JsonObject = EmptyJsonObject) : ServerResult { - @Suppress("CanBeParameter") @Serializable public class Completion( /** @@ -1631,15 +1660,19 @@ public data class CompleteResult(val completion: Completion, override val _meta: */ public val values: List, /** - * The total number of completion options available. This can exceed the number of values actually sent in the response. + * The total number of completion options available. + * + * This can exceed the number of values actually sent in the response. */ public val total: Int?, /** - * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. + * Indicates whether there are additional completion options beyond those provided in the current response, + * even if the exact total is unknown. */ public val hasMore: Boolean?, ) { init { + @Suppress("MagicNumber") require(values.size <= 100) { "'values' field must not exceed 100 items" } @@ -1648,6 +1681,7 @@ public data class CompleteResult(val completion: Completion, override val _meta: } /* Roots */ + /** * Represents a root directory or file that the server can operate on. */ @@ -1661,7 +1695,7 @@ public data class Root( /** * An optional name for the root. */ - val name: String?, + val name: String? = null, ) { init { require(uri.startsWith("file://")) { @@ -1699,7 +1733,7 @@ public data class RootsListChangedNotification(override val params: Params = Par } /** - * Sent from the server to create an elicitation from the client. + * Sent from the server to create elicitation from the client. */ @Serializable public data class CreateElicitationRequest( @@ -1738,17 +1772,6 @@ public data class CreateElicitationResult( } @Serializable + @Suppress("EnumEntryName", "EnumNaming") public enum class Action { accept, decline, cancel } } - -/** - * Represents an error specific to the MCP protocol. - * - * @property code The error code. - * @property message The error message. - * @property data Additional error data as a JSON object. - */ -public class McpError(public val code: Int, message: String, public val data: JsonObject = EmptyJsonObject) : - Exception() { - override val message: String = "MCP error $code: $message" -} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 99f7aa84..d286dc1f 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -257,7 +257,7 @@ public open class Server( title: String? = null, outputSchema: Tool.Output? = null, toolAnnotations: ToolAnnotations? = null, - @Suppress("LocalVariableName") _meta: JsonObject? = null, + @Suppress("LocalVariableName", "FunctionParameterNaming") _meta: JsonObject? = null, handler: suspend (CallToolRequest) -> CallToolResult, ) { val tool = Tool( @@ -468,7 +468,7 @@ public open class Server( } // --- Internal Handlers --- - private suspend fun handleListTools(): ListToolsResult { + private fun handleListTools(): ListToolsResult { val toolList = tools.values.map { it.tool } return ListToolsResult(tools = toolList, nextCursor = null) } @@ -501,7 +501,7 @@ public open class Server( } } - private suspend fun handleListPrompts(): ListPromptsResult { + private fun handleListPrompts(): ListPromptsResult { logger.debug { "Handling list prompts request" } return ListPromptsResult(prompts = prompts.values.map { it.prompt }) } @@ -510,13 +510,13 @@ public open class Server( logger.debug { "Handling get prompt request for: ${request.name}" } val prompt = promptRegistry.get(request.name) ?: run { - logger.error { "Prompt not found: ${request.name}" } - throw IllegalArgumentException("Prompt not found: ${request.name}") + logger.error { "Prompt not found: '${request.name}'" } + throw IllegalArgumentException("Prompt not found: '${request.name}'") } return prompt.messageProvider(request) } - private suspend fun handleListResources(): ListResourcesResult { + private fun handleListResources(): ListResourcesResult { logger.debug { "Handling list resources request" } return ListResourcesResult(resources = resources.values.map { it.resource }) } @@ -531,7 +531,7 @@ public open class Server( return resource.readHandler(request) } - private suspend fun handleListResourceTemplates(): ListResourceTemplatesResult { + private fun handleListResourceTemplates(): ListResourceTemplatesResult { // If you have resource templates, return them here. For now, return empty. return ListResourceTemplatesResult(listOf()) } diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index b8532e1e..ec44904d 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -14,6 +14,7 @@ kotlin { implementation(dependencies.platform(libs.ktor.bom)) implementation(project(":kotlin-sdk")) implementation(kotlin("test")) + implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) implementation(libs.kotlin.logging) implementation(libs.kotlinx.coroutines.test) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt index d5644bbc..6465bb16 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -1,7 +1,11 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin +import io.kotest.matchers.shouldBe +import io.kotest.matchers.throwable.shouldHaveMessage +import io.modelcontextprotocol.kotlin.sdk.ErrorCode import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.McpException import io.modelcontextprotocol.kotlin.sdk.PromptArgument import io.modelcontextprotocol.kotlin.sdk.PromptMessage import io.modelcontextprotocol.kotlin.sdk.Role @@ -132,7 +136,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { ) } - // complext prompt + // complex prompt server.addPrompt( name = complexPromptName, description = complexPromptDescription, @@ -152,8 +156,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { // validate required arguments val requiredArgs = listOf("arg1", "arg2", "arg3") for (argName in requiredArgs) { - if (request.arguments?.get(argName) == null) { - throw IllegalArgumentException("Missing required argument: $argName") + require(request.arguments?.get(argName) != null) { + "Missing required argument: $argName" } } @@ -373,7 +377,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { ) // test missing required arg - val exception = assertThrows { + val exception = assertThrows { runBlocking { client.getPrompt( GetPromptRequest( @@ -391,7 +395,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { ) // test with no args - val exception2 = assertThrows { + val exception2 = assertThrows { runBlocking { client.getPrompt( GetPromptRequest( @@ -654,7 +658,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { fun testNonExistentPrompt() = runTest { val nonExistentPromptName = "non-existent-prompt" - val exception = assertThrows { + val exception = assertThrows { runBlocking { client.getPrompt( GetPromptRequest( @@ -665,9 +669,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { } } - val msg = exception.message ?: "" - val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})" - - assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt") + exception.code shouldBe ErrorCode.Defined.InternalError.code + exception shouldHaveMessage "MCP error -32603: \"Prompt not found: '$nonExistentPromptName'\"" } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index ad80567a..030bdcd2 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin import io.modelcontextprotocol.kotlin.sdk.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.McpException import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities @@ -202,7 +203,7 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { fun testInvalidResourceUri() = runTest { val invalidUri = "test://nonexistent.txt" - val exception = assertThrows { + val exception = assertThrows { runBlocking { client.readResource(ReadResourceRequest(uri = invalidUri)) } @@ -210,7 +211,7 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { val msg = exception.message ?: "" val expectedMessage = - "JSONRPCError(code=InternalError, message=Resource not found: test://nonexistent.txt, data={})" + "MCP error -32603: \"Resource not found: test://nonexistent.txt\"" assertEquals(expectedMessage, msg, "Unexpected error message for invalid resource URI") } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt index ffaff6b4..d0558d29 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerPromptsTest.kt @@ -1,13 +1,26 @@ package io.modelcontextprotocol.kotlin.sdk.server +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.collections.shouldBeEmpty +import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.kotest.matchers.throwable.shouldHaveMessage +import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.McpException import io.modelcontextprotocol.kotlin.sdk.Method import io.modelcontextprotocol.kotlin.sdk.Prompt import io.modelcontextprotocol.kotlin.sdk.PromptListChangedNotification import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals @@ -21,9 +34,38 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { ) @Test - fun `removePrompt should remove a prompt`() = runTest { + fun `Should list no prompts by default`() = runTest { + client.listPrompts() shouldNotBeNull { + prompts.shouldBeEmpty() + } + } + + @Test + fun `Should add a prompt`() = runTest { // Add a prompt - val testPrompt = Prompt("test-prompt", "Test Prompt", null) + val testPrompt = Prompt(name = "test-prompt-with-custom-handler", description = "Test Prompt") + val expectedPromptResult = GetPromptResult( + description = "Test prompt description", + messages = listOf(), + ) + + server.addPrompt(testPrompt) { + expectedPromptResult + } + + client.getPrompt(GetPromptRequest("test-prompt-with-custom-handler")) shouldBe expectedPromptResult + + client.listPrompts() shouldNotBeNull { + prompts shouldContainExactly listOf(testPrompt) + nextCursor shouldBe null + _meta shouldBe EmptyJsonObject + } + } + + @Test + fun `Should remove a prompt`() = runTest { + // given + val testPrompt = Prompt(name = "test-prompt-to-remove", description = "Test Prompt") server.addPrompt(testPrompt) { GetPromptResult( description = "Test prompt description", @@ -31,15 +73,30 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { ) } - // Remove the prompt + client.listPrompts() shouldNotBeNull { + prompts shouldContain testPrompt + } + + // when val result = server.removePrompt(testPrompt.name) - // Verify the prompt was removed + // then assertTrue(result, "Prompt should be removed successfully") + + val mcpException = shouldThrow { + client.getPrompt(GetPromptRequest(name = testPrompt.name)) + } + mcpException.code shouldBe ErrorCode.Defined.InternalError.code + mcpException shouldHaveMessage + "MCP error -32603: \"Prompt not found: 'test-prompt-to-remove'\"" + + client.listPrompts() shouldNotBeNull { + prompts.firstOrNull { it.name == testPrompt.name } shouldBe null + } } @Test - fun `removePrompts should remove multiple prompts and send notification`() = runTest { + fun `Should remove multiple prompts and send notification`() = runTest { // Add prompts val testPrompt1 = Prompt("test-prompt-1", "Test Prompt 1", null) val testPrompt2 = Prompt("test-prompt-2", "Test Prompt 2", null) @@ -56,11 +113,19 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { ) } + client.listPrompts() shouldNotBeNull { + prompts shouldHaveSize 2 + } + // Remove the prompts val result = server.removePrompts(listOf(testPrompt1.name, testPrompt2.name)) // Verify the prompts were removed assertEquals(2, result, "Both prompts should be removed") + + client.listPrompts() shouldNotBeNull { + prompts.shouldBeEmpty() + } } @Test @@ -80,21 +145,56 @@ class ServerPromptsTest : AbstractServerFeaturesTest() { assertFalse(promptListChangedNotificationReceived, "No notification should be sent when prompt doesn't exist") } - @Test - fun `removePrompt should throw when prompts capability is not supported`() = runTest { + @Nested + inner class NoPromptsCapabilitiesTests { + // Create server without prompts capability - val serverOptions = ServerOptions( - capabilities = ServerCapabilities(), - ) - val server = Server( + val serverWithoutPrompts = Server( Implementation(name = "test server", version = "1.0"), - serverOptions, + ServerOptions( + capabilities = ServerCapabilities(), + ), ) - // Verify that removing a prompt throws an exception - val exception = assertThrows { - server.removePrompt("test-prompt") + @Test + fun `RemovePrompt should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.removePrompt("test-prompt") + } + assertEquals("Server does not support prompts capability.", exception.message) + } + + @Test + fun `Remove Prompts should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.removePrompts(emptyList()) + } + assertEquals("Server does not support prompts capability.", exception.message) + } + + @Test + fun `Add Prompt should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.addPrompt(name = "test-prompt") { + GetPromptResult( + description = "Test prompt description", + messages = listOf(), + ) + } + } + assertEquals("Server does not support prompts capability.", exception.message) + } + + @Test + fun `Add Prompts should throw when prompts capability is not supported`() = runTest { + // Verify that removing a prompt throws an exception + val exception = assertThrows { + serverWithoutPrompts.addPrompts(emptyList()) + } + assertEquals("Server does not support prompts capability.", exception.message) } - assertEquals("Server does not support prompts capability.", exception.message) } }