diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..5c283ed0 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,37 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 +max_line_length = 120 + +[*.json] +indent_size = 2 + +[{*.yaml,*.yml}] +indent_size = 2 + +[*.{kt,kts}] +ij_kotlin_code_style_defaults = KOTLIN_OFFICIAL + +# Disable wildcard imports entirely +ij_kotlin_name_count_to_use_star_import = 2147483647 +ij_kotlin_name_count_to_use_star_import_for_members = 2147483647 +ij_kotlin_packages_to_use_import_on_demand = unset + +ktlint_code_style = intellij_idea +ktlint_experimental = enabled +ktlint_standard_filename = disabled +ktlint_standard_no-empty-first-line-in-class-body = disabled +ktlint_class_signature_rule_force_multiline_when_parameter_count_greater_or_equal_than = 4 +ktlint_function_signature_rule_force_multiline_when_parameter_count_greater_or_equal_than = 4 +ktlint_standard_chain-method-continuation = disabled +ktlint_ignore_back_ticked_identifier = true +ktlint_standard_multiline-expression-wrapping = disabled +ktlint_standard_when-entry-bracing = disabled + +[*/build/**/*] +ktlint = disabled \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 8672d45f..1b697d4f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,4 +1,12 @@ +plugins { + alias(libs.plugins.ktlint) +} + allprojects { group = "io.modelcontextprotocol" version = "0.6.0" -} \ No newline at end of file +} + +subprojects { + apply(plugin = "org.jlleitschuh.gradle.ktlint") +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 5afb3d7a..32ffbb34 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -3,6 +3,7 @@ kotlin = "2.2.0" dokka = "2.0.0" atomicfu = "0.29.0" +ktlint = "13.0.0" # libraries version serialization = "1.9.0" @@ -57,6 +58,7 @@ ktor-serialization-kotlinx-json = { group = "io.ktor", name = "ktor-serializatio [plugins] kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" } +ktlint = { id = "org.jlleitschuh.gradle.ktlint", version.ref = "ktlint" } # Samples kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } diff --git a/kotlin-sdk-client/build.gradle.kts b/kotlin-sdk-client/build.gradle.kts index ffc83cee..aaf6050c 100644 --- a/kotlin-sdk-client/build.gradle.kts +++ b/kotlin-sdk-client/build.gradle.kts @@ -11,9 +11,15 @@ plugins { } kotlin { - iosArm64(); iosX64(); iosSimulatorArm64() - watchosX64(); watchosArm64(); watchosSimulatorArm64() - tvosX64(); tvosArm64(); tvosSimulatorArm64() + iosArm64() + iosX64() + iosSimulatorArm64() + watchosX64() + watchosArm64() + watchosSimulatorArm64() + tvosX64() + tvosArm64() + tvosSimulatorArm64() js { browser() nodejs() @@ -40,4 +46,4 @@ kotlin { } } } -} \ No newline at end of file +} diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 75d0b221..ea63150f 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -80,10 +80,8 @@ public class ClientOptions( * @param clientInfo Information about the client implementation (name, version). * @param options Configuration options for this client. */ -public open class Client( - private val clientInfo: Implementation, - options: ClientOptions = ClientOptions(), -) : Protocol(options) { +public open class Client(private val clientInfo: Implementation, options: ClientOptions = ClientOptions()) : + Protocol(options) { /** * Retrieves the server's reported capabilities after the initialization process completes. @@ -144,13 +142,13 @@ public open class Client( val message = InitializeRequest( protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = capabilities, - clientInfo = clientInfo + clientInfo = clientInfo, ) val result = request(message) if (!SUPPORTED_PROTOCOL_VERSIONS.contains(result.protocolVersion)) { throw IllegalStateException( - "Server's protocol version is not supported: ${result.protocolVersion}" + "Server's protocol version is not supported: ${result.protocolVersion}", ) } @@ -165,11 +163,9 @@ public open class Client( } throw error - } } - override fun assertCapabilityForMethod(method: Method) { when (method) { Method.Defined.LoggingSetLevel -> { @@ -181,7 +177,7 @@ public open class Client( Method.Defined.PromptsGet, Method.Defined.PromptsList, Method.Defined.CompletionComplete, - -> { + -> { if (serverCapabilities?.prompts == null) { throw IllegalStateException("Server does not support prompts (required for $method)") } @@ -192,20 +188,20 @@ public open class Client( Method.Defined.ResourcesRead, Method.Defined.ResourcesSubscribe, Method.Defined.ResourcesUnsubscribe, - -> { + -> { val resCaps = serverCapabilities?.resources ?: error("Server does not support resources (required for $method)") if (method == Method.Defined.ResourcesSubscribe && resCaps.subscribe != true) { throw IllegalStateException( - "Server does not support resource subscriptions (required for $method)" + "Server does not support resource subscriptions (required for $method)", ) } } Method.Defined.ToolsCall, Method.Defined.ToolsList, - -> { + -> { if (serverCapabilities?.tools == null) { throw IllegalStateException("Server does not support tools (required for $method)") } @@ -213,7 +209,7 @@ public open class Client( Method.Defined.Initialize, Method.Defined.Ping, - -> { + -> { // No specific capability required } @@ -228,7 +224,7 @@ public open class Client( Method.Defined.NotificationsRootsListChanged -> { if (capabilities.roots?.listChanged != true) { throw IllegalStateException( - "Client does not support roots list changed notifications (required for $method)" + "Client does not support roots list changed notifications (required for $method)", ) } } @@ -236,7 +232,7 @@ public open class Client( Method.Defined.NotificationsInitialized, Method.Defined.NotificationsCancelled, Method.Defined.NotificationsProgress, - -> { + -> { // Always allowed } @@ -251,7 +247,7 @@ public open class Client( Method.Defined.SamplingCreateMessage -> { if (capabilities.sampling == null) { throw IllegalStateException( - "Client does not support sampling capability (required for $method)" + "Client does not support sampling capability (required for $method)", ) } } @@ -259,7 +255,7 @@ public open class Client( Method.Defined.RootsList -> { if (capabilities.roots == null) { throw IllegalStateException( - "Client does not support roots capability (required for $method)" + "Client does not support roots capability (required for $method)", ) } } @@ -267,7 +263,7 @@ public open class Client( Method.Defined.ElicitationCreate -> { if (capabilities.elicitation == null) { throw IllegalStateException( - "Client does not support elicitation capability (required for $method)" + "Client does not support elicitation capability (required for $method)", ) } } @@ -280,16 +276,14 @@ public open class Client( } } - /** * Sends a ping request to the server to check connectivity. * * @param options Optional request options. * @throws IllegalStateException If the server does not support the ping method (unlikely). */ - public suspend fun ping(options: RequestOptions? = null): EmptyRequestResult { - return request(PingRequest(), options) - } + public suspend fun ping(options: RequestOptions? = null): EmptyRequestResult = + request(PingRequest(), options) /** * Sends a completion request to the server, typically to generate or complete some content. @@ -299,9 +293,8 @@ public open class Client( * @return The completion result returned by the server, or `null` if none. * @throws IllegalStateException If the server does not support prompts or completion. */ - public suspend fun complete(params: CompleteRequest, options: RequestOptions? = null): CompleteResult { - return request(params, options) - } + public suspend fun complete(params: CompleteRequest, options: RequestOptions? = null): CompleteResult = + request(params, options) /** * Sets the logging level on the server. @@ -310,9 +303,8 @@ public open class Client( * @param options Optional request options. * @throws IllegalStateException If the server does not support logging. */ - public suspend fun setLoggingLevel(level: LoggingLevel, options: RequestOptions? = null): EmptyRequestResult { - return request(SetLevelRequest(level), options) - } + public suspend fun setLoggingLevel(level: LoggingLevel, options: RequestOptions? = null): EmptyRequestResult = + request(SetLevelRequest(level), options) /** * Retrieves a prompt by name from the server. @@ -322,9 +314,8 @@ public open class Client( * @return The requested prompt details, or `null` if not found. * @throws IllegalStateException If the server does not support prompts. */ - public suspend fun getPrompt(request: GetPromptRequest, options: RequestOptions? = null): GetPromptResult { - return request(request, options) - } + public suspend fun getPrompt(request: GetPromptRequest, options: RequestOptions? = null): GetPromptResult = + request(request, options) /** * Lists all available prompts from the server. @@ -337,9 +328,7 @@ public open class Client( public suspend fun listPrompts( request: ListPromptsRequest = ListPromptsRequest(), options: RequestOptions? = null, - ): ListPromptsResult { - return request(request, options) - } + ): ListPromptsResult = request(request, options) /** * Lists all available resources from the server. @@ -352,9 +341,7 @@ public open class Client( public suspend fun listResources( request: ListResourcesRequest = ListResourcesRequest(), options: RequestOptions? = null, - ): ListResourcesResult { - return request(request, options) - } + ): ListResourcesResult = request(request, options) /** * Lists resource templates available on the server. @@ -367,9 +354,7 @@ public open class Client( public suspend fun listResourceTemplates( request: ListResourceTemplatesRequest, options: RequestOptions? = null, - ): ListResourceTemplatesResult { - return request(request, options) - } + ): ListResourceTemplatesResult = request(request, options) /** * Reads a resource from the server by its URI. @@ -382,9 +367,7 @@ public open class Client( public suspend fun readResource( request: ReadResourceRequest, options: RequestOptions? = null, - ): ReadResourceResult { - return request(request, options) - } + ): ReadResourceResult = request(request, options) /** * Subscribes to resource changes on the server. @@ -396,9 +379,7 @@ public open class Client( public suspend fun subscribeResource( request: SubscribeRequest, options: RequestOptions? = null, - ): EmptyRequestResult { - return request(request, options) - } + ): EmptyRequestResult = request(request, options) /** * Unsubscribes from resource changes on the server. @@ -410,9 +391,7 @@ public open class Client( public suspend fun unsubscribeResource( request: UnsubscribeRequest, options: RequestOptions? = null, - ): EmptyRequestResult { - return request(request, options) - } + ): EmptyRequestResult = request(request, options) /** * Calls a tool on the server by name, passing the specified arguments. @@ -443,7 +422,7 @@ public open class Client( val request = CallToolRequest( name = name, - arguments = JsonObject(jsonArguments) + arguments = JsonObject(jsonArguments), ) return callTool(request, compatibility, options) } @@ -461,12 +440,10 @@ public open class Client( request: CallToolRequest, compatibility: Boolean = false, options: RequestOptions? = null, - ): CallToolResultBase? { - return if (compatibility) { - request(request, options) - } else { - request(request, options) - } + ): CallToolResultBase? = if (compatibility) { + request(request, options) + } else { + request(request, options) } /** @@ -480,9 +457,7 @@ public open class Client( public suspend fun listTools( request: ListToolsRequest = ListToolsRequest(), options: RequestOptions? = null, - ): ListToolsResult { - return request(request, options) - } + ): ListToolsResult = request(request, options) /** * Registers a single root. @@ -491,10 +466,7 @@ public open class Client( * @param name A human-readable name for the root. * @throws IllegalStateException If the client does not support roots. */ - public fun addRoot( - uri: String, - name: String, - ) { + public fun addRoot(uri: String, name: String) { if (capabilities.roots == null) { logger.error { "Failed to add root '$name': Client does not support roots capability" } throw IllegalStateException("Client does not support roots capability.") diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt index 2ccc223d..ccc1496e 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/KtorClient.kt @@ -12,7 +12,7 @@ import kotlin.time.Duration * * @param urlString Optional URL of the MCP server. * @param reconnectionTime Optional duration to wait before attempting to reconnect. - * @param requestBuilder Optional lambda to configure the HTTP request. + * @param requestBuilder Optional lambda to configure the HTTP request. * @return A [SSEClientTransport] configured for MCP communication. */ public fun HttpClient.mcpSseTransport( @@ -38,8 +38,8 @@ public suspend fun HttpClient.mcpSse( val client = Client( Implementation( name = IMPLEMENTATION_NAME, - version = LIB_VERSION - ) + version = LIB_VERSION, + ), ) client.connect(transport) return client diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index d30f5288..950f37fa 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -139,6 +139,7 @@ public class SseClientTransport( } "endpoint" -> handleEndpoint(event.data.orEmpty()) + else -> handleMessage(event.data.orEmpty()) } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 8ffbb752..583cec63 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -35,10 +35,7 @@ import kotlin.coroutines.CoroutineContext * @param output The output stream where messages are sent. */ @OptIn(ExperimentalAtomicApi::class) -public class StdioClientTransport( - private val input: Source, - private val output: Sink -) : AbstractTransport() { +public class StdioClientTransport(private val input: Source, private val output: Sink) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val ioCoroutineContext: CoroutineContext = IODispatcher private val scope by lazy { diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 7b365638..2f524475 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -50,10 +50,8 @@ private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" /** * Error class for Streamable HTTP transport errors. */ -public class StreamableHttpError( - public val code: Int? = null, - message: String? = null -) : Exception("Streamable HTTP error: $message") +public class StreamableHttpError(public val code: Int? = null, message: String? = null) : + Exception("Streamable HTTP error: $message") /** * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. @@ -102,15 +100,16 @@ public class StreamableHttpClientTransport( public suspend fun send( message: JSONRPCMessage, resumptionToken: String?, - onResumptionToken: ((String) -> Unit)? = null + onResumptionToken: ((String) -> Unit)? = null, ) { logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } // If we have a resumption token, reconnect the SSE stream with it resumptionToken?.let { token -> startSseSession( - resumptionToken = token, onResumptionToken = onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null + resumptionToken = token, + onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null, ) return } @@ -147,8 +146,9 @@ public class StreamableHttpClientTransport( } ContentType.Text.EventStream -> handleInlineSse( - response, onResumptionToken = onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null + response, + onResumptionToken = onResumptionToken, + replayMessageId = if (message is JSONRPCRequest) message.id else null, ) else -> { @@ -197,7 +197,7 @@ public class StreamableHttpClientTransport( if (!response.status.isSuccess() && response.status != HttpStatusCode.MethodNotAllowed) { val error = StreamableHttpError( response.status.value, - "Failed to terminate session: ${response.status.description}" + "Failed to terminate session: ${response.status.description}", ) logger.error(error) { "Failed to terminate session" } _onError(error) @@ -212,7 +212,7 @@ public class StreamableHttpClientTransport( private suspend fun startSseSession( resumptionToken: String? = null, replayMessageId: RequestId? = null, - onResumptionToken: ((String) -> Unit)? = null + onResumptionToken: ((String) -> Unit)? = null, ) { sseSession?.cancel() sseJob?.cancelAndJoin() @@ -254,7 +254,7 @@ public class StreamableHttpClientTransport( private suspend fun collectSse( session: ClientSSESession, replayMessageId: RequestId?, - onResumptionToken: ((String) -> Unit)? + onResumptionToken: ((String) -> Unit)?, ) { try { session.incoming.collect { event -> @@ -290,7 +290,7 @@ public class StreamableHttpClientTransport( private suspend fun handleInlineSse( response: HttpResponse, replayMessageId: RequestId?, - onResumptionToken: ((String) -> Unit)? + onResumptionToken: ((String) -> Unit)?, ) { logger.trace { "Handling inline SSE from POST response" } val channel = response.bodyAsChannel() diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt index c2454e1f..1a600a3a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -39,4 +39,4 @@ public suspend fun HttpClient.mcpStreamableHttp( val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) client.connect(transport) return client -} \ No newline at end of file +} diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt index 9d70d6c0..77062ab1 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/WebSocketMcpKtorClientExtensions.kt @@ -33,8 +33,8 @@ public suspend fun HttpClient.mcpWebSocket( val client = Client( Implementation( name = IMPLEMENTATION_NAME, - version = LIB_VERSION - ) + version = LIB_VERSION, + ), ) client.connect(transport) return client diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt index c286eaab..12d20905 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransportTest.kt @@ -47,7 +47,7 @@ class StreamableHttpClientTransportTest { val message = JSONRPCRequest( id = RequestId.StringId("test-id"), method = "test", - params = buildJsonObject { } + params = buildJsonObject { }, ) val transport = createTransport { request -> @@ -61,7 +61,7 @@ class StreamableHttpClientTransportTest { respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } @@ -76,26 +76,30 @@ class StreamableHttpClientTransportTest { id = RequestId.StringId("test-id"), method = "initialize", params = buildJsonObject { - put("clientInfo", buildJsonObject { - put("name", JsonPrimitive("test-client")) - put("version", JsonPrimitive("1.0")) - }) + put( + "clientInfo", + buildJsonObject { + put("name", JsonPrimitive("test-client")) + put("version", JsonPrimitive("1.0")) + }, + ) put("protocolVersion", JsonPrimitive("2025-06-18")) - } + }, ) val transport = createTransport { request -> when (val msg = McpJson.decodeFromString((request.body as TextContent).text)) { is JSONRPCRequest if msg.method == "initialize" -> respond( - content = "", status = HttpStatusCode.OK, - headers = headersOf("mcp-session-id", "test-session-id") + content = "", + status = HttpStatusCode.OK, + headers = headersOf("mcp-session-id", "test-session-id"), ) is JSONRPCNotification if msg.method == "test" -> { assertEquals("test-session-id", request.headers["mcp-session-id"]) respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } @@ -122,7 +126,7 @@ class StreamableHttpClientTransportTest { assertEquals("test-session-id", request.headers["mcp-session-id"]) respond( content = "", - status = HttpStatusCode.OK + status = HttpStatusCode.OK, ) } @@ -141,7 +145,7 @@ class StreamableHttpClientTransportTest { assertEquals(HttpMethod.Delete, request.method) respond( content = "", - status = HttpStatusCode.MethodNotAllowed + status = HttpStatusCode.MethodNotAllowed, ) } @@ -160,7 +164,7 @@ class StreamableHttpClientTransportTest { assertEquals("2025-06-18", request.headers["mcp-protocol-version"]) respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } transport.protocolVersion = "2025-06-18" @@ -170,7 +174,8 @@ class StreamableHttpClientTransportTest { transport.close() } - @Ignore //Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support + // Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support + @Ignore @Test fun testNotificationSchemaE2E() = runTest { val receivedMessages = mutableListOf() @@ -182,7 +187,7 @@ class StreamableHttpClientTransportTest { respond( content = "", status = HttpStatusCode.Accepted, - headers = headersOf("mcp-session-id", "notification-test-session") + headers = headersOf("mcp-session-id", "notification-test-session"), ) } @@ -193,7 +198,9 @@ class StreamableHttpClientTransportTest { // Server sends various notifications appendLine("event: message") appendLine("id: 1") - appendLine("""data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/progress","params":{"progressToken":"upload-123","progress":50,"total":100}}""", + ) appendLine() appendLine("event: message") @@ -210,8 +217,9 @@ class StreamableHttpClientTransportTest { content = ByteReadChannel(sseContent), status = HttpStatusCode.OK, headers = headersOf( - HttpHeaders.ContentType, ContentType.Text.EventStream.toString() - ) + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), ) } @@ -219,7 +227,7 @@ class StreamableHttpClientTransportTest { HttpMethod.Post -> { respond( content = "", - status = HttpStatusCode.Accepted + status = HttpStatusCode.Accepted, ) } @@ -238,11 +246,14 @@ class StreamableHttpClientTransportTest { method = "notifications/initialized", params = buildJsonObject { put("protocolVersion", JsonPrimitive("1.0")) - put("capabilities", buildJsonObject { - put("tools", JsonPrimitive(true)) - put("resources", JsonPrimitive(true)) - }) - } + put( + "capabilities", + buildJsonObject { + put("tools", JsonPrimitive(true)) + put("resources", JsonPrimitive(true)) + }, + ) + }, ) transport.send(initializedNotification) @@ -274,25 +285,28 @@ class StreamableHttpClientTransportTest { params = buildJsonObject { put("progressToken", JsonPrimitive("download-456")) put("progress", JsonPrimitive(75)) - } + }, ), JSONRPCNotification( method = "notifications/cancelled", params = buildJsonObject { put("requestId", JsonPrimitive("req-789")) put("reason", JsonPrimitive("user_cancelled")) - } + }, ), JSONRPCNotification( method = "notifications/message", params = buildJsonObject { put("level", JsonPrimitive("info")) put("message", JsonPrimitive("Operation completed")) - put("data", buildJsonObject { - put("duration", JsonPrimitive(1234)) - }) - } - ) + put( + "data", + buildJsonObject { + put("duration", JsonPrimitive(1234)) + }, + ) + }, + ), ) // Send all client notifications @@ -305,7 +319,8 @@ class StreamableHttpClientTransportTest { transport.close() } - @Ignore // Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support + // Engine doesn't support SSECapability: https://youtrack.jetbrains.com/issue/KTOR-8177/MockEngine-Add-SSE-support + @Ignore @Test fun testNotificationWithResumptionToken() = runTest { var resumptionTokenReceived: String? = null @@ -320,15 +335,18 @@ class StreamableHttpClientTransportTest { val sseContent = buildString { appendLine("event: message") appendLine("id: resume-100") - appendLine("""data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"${lastEventIdSent}"}}""") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/resumed","params":{"fromToken":"$lastEventIdSent"}}""", + ) appendLine() } respond( content = ByteReadChannel(sseContent), status = HttpStatusCode.OK, headers = headersOf( - HttpHeaders.ContentType, ContentType.Text.EventStream.toString() - ) + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), ) } @@ -344,12 +362,12 @@ class StreamableHttpClientTransportTest { method = "notifications/test", params = buildJsonObject { put("data", JsonPrimitive("test-data")) - } + }, ), resumptionToken = "previous-token-99", onResumptionToken = { token -> resumptionTokenReceived = token - } + }, ) // Wait for response diff --git a/kotlin-sdk-core/build.gradle.kts b/kotlin-sdk-core/build.gradle.kts index ee8c3477..033bf55e 100644 --- a/kotlin-sdk-core/build.gradle.kts +++ b/kotlin-sdk-core/build.gradle.kts @@ -11,9 +11,15 @@ plugins { } kotlin { - iosArm64(); iosX64(); iosSimulatorArm64() - watchosX64(); watchosArm64(); watchosSimulatorArm64() - tvosX64(); tvosArm64(); tvosSimulatorArm64() + iosArm64() + iosX64() + iosSimulatorArm64() + watchosX64() + watchosArm64() + watchosSimulatorArm64() + tvosX64() + tvosArm64() + tvosSimulatorArm64() js { browser() nodejs() @@ -42,4 +48,4 @@ kotlin { } } } -} \ No newline at end of file +} diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index b5f15751..580c3982 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -109,7 +109,7 @@ public data class RequestOptions( /** * Extra data given to request handlers. */ -public class RequestHandlerExtra() +public class RequestHandlerExtra internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) } @@ -117,15 +117,20 @@ internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) } * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ -public abstract class Protocol( - @PublishedApi internal val options: ProtocolOptions?, -) { +public abstract class Protocol(@PublishedApi internal val options: ProtocolOptions?) { public var transport: Transport? = null private set - private val _requestHandlers: AtomicRef RequestResult?>> = + private val _requestHandlers: + AtomicRef RequestResult?>> = atomic(persistentMapOf()) - public val requestHandlers: Map RequestResult?> + public val requestHandlers: Map< + String, + suspend ( + request: JSONRPCRequest, + extra: RequestHandlerExtra, + ) -> RequestResult?, + > get() = _requestHandlers.value private val _notificationHandlers = @@ -133,7 +138,8 @@ public abstract class Protocol( public val notificationHandlers: Map Unit> get() = _notificationHandlers.value - private val _responseHandlers: AtomicRef Unit>> = + private val _responseHandlers: + AtomicRef Unit>> = atomic(persistentMapOf()) public val responseHandlers: Map Unit> get() = _responseHandlers.value @@ -161,7 +167,9 @@ public abstract class Protocol( /** * A handler to invoke for any request types that do not have their own handler installed. */ - public var fallbackRequestHandler: (suspend (request: JSONRPCRequest, extra: RequestHandlerExtra) -> RequestResult?)? = + public var fallbackRequestHandler: ( + suspend (request: JSONRPCRequest, extra: RequestHandlerExtra) -> RequestResult? + )? = null /** @@ -251,8 +259,8 @@ public abstract class Protocol( error = JSONRPCError( ErrorCode.Defined.MethodNotFound, message = "Server does not support ${request.method}", - ) - ) + ), + ), ) } catch (cause: Throwable) { LOGGER.error(cause) { "Error sending method not found response" } @@ -268,10 +276,9 @@ public abstract class Protocol( transport?.send( JSONRPCResponse( id = request.id, - result = result - ) + result = result, + ), ) - } catch (cause: Throwable) { LOGGER.error(cause) { "Error handling request: ${request.method} (id: ${request.id})" } @@ -281,19 +288,23 @@ public abstract class Protocol( id = request.id, error = JSONRPCError( code = ErrorCode.Defined.InternalError, - message = cause.message ?: "Internal error" - ) - ) + message = cause.message ?: "Internal error", + ), + ), ) } catch (sendError: Throwable) { - LOGGER.error(sendError) { "Failed to send error response for request: ${request.method} (id: ${request.id})" } + LOGGER.error(sendError) { + "Failed to send error response for request: ${request.method} (id: ${request.id})" + } // Optionally implement fallback behavior here } } } private fun onProgress(notification: ProgressNotification) { - LOGGER.trace { "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" } + LOGGER.trace { + "Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}" + } val progress = notification.params.progress val total = notification.params.total val message = notification.params.message @@ -378,10 +389,7 @@ public abstract class Protocol( * * Do not use this method to emit notifications! Use notification() instead. */ - public suspend fun request( - request: Request, - options: RequestOptions? = null, - ): T { + public suspend fun request(request: Request, options: RequestOptions? = null): T { LOGGER.trace { "Sending request: ${request.method}" } val result = CompletableDeferred() val transport = this@Protocol.transport ?: throw Error("Not connected") @@ -427,14 +435,14 @@ public abstract class Protocol( val notification = CancelledNotification( params = CancelledNotification.Params( - requestId = messageId, - reason = reason.message ?: "Unknown" - ) + requestId = messageId, + reason = reason.message ?: "Unknown", + ), ) val serialized = JSONRPCNotification( notification.method.value, - params = McpJson.encodeToJsonElement(notification) + params = McpJson.encodeToJsonElement(notification), ) transport.send(serialized) @@ -454,7 +462,7 @@ public abstract class Protocol( McpError( ErrorCode.Defined.RequestTimeout.code, "Request timed out", - JsonObject(mutableMapOf("timeout" to JsonPrimitive(timeout.inWholeMilliseconds))) + JsonObject(mutableMapOf("timeout" to JsonPrimitive(timeout.inWholeMilliseconds))), ), ) result.cancel(cause) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt index c235e65b..10d91c14 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt @@ -20,6 +20,7 @@ public class ReadBuffer { var lfIndex = buffer.indexOf('\n'.code.toByte()) val line = when (lfIndex) { -1L -> return null + 0L -> { buffer.skip(1) return null @@ -44,11 +45,6 @@ public class ReadBuffer { } } -internal fun deserializeMessage(line: String): JSONRPCMessage { - return McpJson.decodeFromString(line) -} - -public fun serializeMessage(message: JSONRPCMessage): String { - return McpJson.encodeToString(message) + "\n" -} +internal fun deserializeMessage(line: String): JSONRPCMessage = McpJson.decodeFromString(line) +public fun serializeMessage(message: JSONRPCMessage): String = McpJson.encodeToString(message) + "\n" diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index 29e7b866..4a936768 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -43,7 +43,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { error( "WebSocketClientTransport already started! " + - "If using Client class, note that connect() calls start() automatically.", + "If using Client class, note that connect() calls start() automatically.", ) } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt index 8918f5c3..00aa8048 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.kt @@ -168,12 +168,10 @@ public sealed interface Notification { * * @return The JSON-RPC notification representation. */ -public fun Notification.toJSON(): JSONRPCNotification { - return JSONRPCNotification( - method = method.value, - params = McpJson.encodeToJsonElement(params), - ) -} +public fun Notification.toJSON(): JSONRPCNotification = JSONRPCNotification( + method = method.value, + params = McpJson.encodeToJsonElement(params), +) /** * Decodes a JSON-RPC notification into a protocol-specific [Notification]. @@ -200,9 +198,9 @@ public sealed interface RequestResult : WithMeta * @param _meta Additional metadata for the response. Defaults to an empty JSON object. */ @Serializable -public data class EmptyRequestResult( - override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, ClientResult +public data class EmptyRequestResult(override val _meta: JsonObject = EmptyJsonObject) : + ServerResult, + ClientResult /** * A uniquely identifying ID for a request in JSON-RPC. @@ -282,7 +280,6 @@ public sealed interface ErrorCode { MethodNotFound(-32601), InvalidParams(-32602), InternalError(-32603), - ; } @Serializable @@ -293,11 +290,8 @@ public sealed interface ErrorCode { * A response to a request that indicates an error occurred. */ @Serializable -public data class JSONRPCError( - val code: ErrorCode, - val message: String, - val data: JsonObject = EmptyJsonObject, -) : JSONRPCMessage +public data class JSONRPCError(val code: ErrorCode, val message: String, val data: JsonObject = EmptyJsonObject) : + JSONRPCMessage /** * Base interface for notification parameters with optional metadata. @@ -316,9 +310,9 @@ public sealed interface NotificationParams : WithMeta * A client MUST NOT attempt to cancel its `initialize` request. */ @Serializable -public data class CancelledNotification( - override val params: Params, -) : ClientNotification, ServerNotification { +public data class CancelledNotification(override val params: Params) : + ClientNotification, + ServerNotification { override val method: Method = Method.Defined.NotificationsCancelled @Serializable @@ -342,10 +336,7 @@ public data class CancelledNotification( * Describes the name and version of an MCP implementation. */ @Serializable -public data class Implementation( - val name: String, - val version: String, -) +public data class Implementation(val name: String, val version: String) /** * Capabilities a client may support. @@ -383,7 +374,7 @@ public data class ClientCapabilities( /** * Represents a request sent by the client. */ -//@Serializable(with = ClientRequestPolymorphicSerializer::class) +// @Serializable(with = ClientRequestPolymorphicSerializer::class) public interface ClientRequest : Request /** @@ -401,7 +392,7 @@ public sealed interface ClientResult : RequestResult /** * Represents a request sent by the server. */ -//@Serializable(with = ServerRequestPolymorphicSerializer::class) +// @Serializable(with = ServerRequestPolymorphicSerializer::class) public sealed interface ServerRequest : Request /** @@ -423,8 +414,12 @@ public sealed interface ServerResult : RequestResult */ @Serializable public data class UnknownMethodRequestOrNotification( - override val method: Method, override val params: NotificationParams? = null, -) : ClientNotification, ClientRequest, ServerNotification, ServerRequest + override val method: Method, + override val params: NotificationParams? = null, +) : ClientNotification, + ClientRequest, + ServerNotification, + ServerRequest /** * This request is sent from the client to the server when it first connects, asking it to begin initialization. @@ -435,7 +430,8 @@ public data class InitializeRequest( val capabilities: ClientCapabilities, val clientInfo: Implementation, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.Initialize } @@ -521,15 +517,11 @@ public data class InitializeResult( * This notification is sent from the client to the server after initialization has finished. */ @Serializable -public data class InitializedNotification( - override val params: Params = Params(), -) : ClientNotification { +public data class InitializedNotification(override val params: Params = Params()) : ClientNotification { override val method: Method = Method.Defined.NotificationsInitialized @Serializable - public data class Params( - override val _meta: JsonObject = EmptyJsonObject, - ) : NotificationParams + public data class Params(override val _meta: JsonObject = EmptyJsonObject) : NotificationParams } /* Ping */ @@ -538,7 +530,9 @@ public data class InitializedNotification( * The receiver must promptly respond, or else it may be disconnected. */ @Serializable -public class PingRequest : ServerRequest, ClientRequest { +public class PingRequest : + ServerRequest, + ClientRequest { override val method: Method = Method.Defined.Ping } @@ -592,9 +586,9 @@ public open class Progress( * An out-of-band notification used to inform the receiver of a progress update for a long-running request. */ @Serializable -public data class ProgressNotification( - override val params: Params, -) : ClientNotification, ServerNotification { +public data class ProgressNotification(override val params: Params) : + ClientNotification, + ServerNotification { override val method: Method = Method.Defined.NotificationsProgress @Serializable @@ -618,7 +612,8 @@ public data class ProgressNotification( */ override val message: String? = null, override val _meta: JsonObject = EmptyJsonObject, - ) : NotificationParams, ProgressBase + ) : NotificationParams, + ProgressBase } /* Pagination */ @@ -626,7 +621,9 @@ public data class ProgressNotification( * Represents a request supporting pagination. */ @Serializable -public sealed interface PaginatedRequest : Request, WithMeta { +public sealed interface PaginatedRequest : + Request, + WithMeta { /** * The cursor indicating the pagination position. */ @@ -669,11 +666,8 @@ public sealed interface ResourceContents { * @property text The text of the item. This must only be set if the item can actually be represented as text (not binary data). */ @Serializable -public data class TextResourceContents( - val text: String, - override val uri: String, - override val mimeType: String?, -) : ResourceContents +public data class TextResourceContents(val text: String, override val uri: String, override val mimeType: String?) : + ResourceContents /** * Represents the binary contents of a resource encoded as a base64 string. @@ -681,20 +675,14 @@ public data class TextResourceContents( * @property blob A base64-encoded string representing the binary data of the item. */ @Serializable -public data class BlobResourceContents( - val blob: String, - override val uri: String, - override val mimeType: String?, -) : ResourceContents +public data class BlobResourceContents(val blob: String, override val uri: String, override val mimeType: String?) : + ResourceContents /** * Represents resource contents with unknown or unspecified data. */ @Serializable -public data class UnknownResourceContents( - override val uri: String, - override val mimeType: String?, -) : ResourceContents +public data class UnknownResourceContents(override val uri: String, override val mimeType: String?) : ResourceContents /** * A known resource that the server is capable of reading. @@ -758,8 +746,9 @@ public data class ResourceTemplate( @Serializable public data class ListResourcesRequest( override val cursor: Cursor? = null, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.ResourcesList } @@ -771,7 +760,8 @@ public class ListResourcesResult( public val resources: List, override val nextCursor: Cursor? = null, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * Sent from the client to request a list of resource templates the server has. @@ -779,8 +769,9 @@ public class ListResourcesResult( @Serializable public data class ListResourceTemplatesRequest( override val cursor: Cursor?, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.ResourcesTemplatesList } @@ -792,16 +783,16 @@ public class ListResourceTemplatesResult( public val resourceTemplates: List, override val nextCursor: Cursor? = null, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * Sent from the client to the server to read a specific resource URI. */ @Serializable -public data class ReadResourceRequest( - val uri: String, - override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +public data class ReadResourceRequest(val uri: String, override val _meta: JsonObject = EmptyJsonObject) : + ClientRequest, + WithMeta { override val method: Method = Method.Defined.ResourcesRead } @@ -820,15 +811,11 @@ public class ReadResourceResult( * Servers may issue this without any previous subscription from the client. */ @Serializable -public data class ResourceListChangedNotification( - override val params: Params = Params(), -) : ServerNotification { +public data class ResourceListChangedNotification(override val params: Params = Params()) : ServerNotification { override val method: Method = Method.Defined.NotificationsResourcesListChanged @Serializable - public data class Params( - override val _meta: JsonObject = EmptyJsonObject, - ) : NotificationParams + public data class Params(override val _meta: JsonObject = EmptyJsonObject) : NotificationParams } /** @@ -841,7 +828,8 @@ public data class SubscribeRequest( */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.ResourcesSubscribe } @@ -855,7 +843,8 @@ public data class UnsubscribeRequest( */ val uri: String, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.ResourcesUnsubscribe } @@ -863,9 +852,7 @@ public data class UnsubscribeRequest( * A notification from the server to the client, informing it that a resource has changed and may need to be read again. This should only be sent if the client previously sent a resources/subscribe request. */ @Serializable -public data class ResourceUpdatedNotification( - override val params: Params, -) : ServerNotification { +public data class ResourceUpdatedNotification(override val params: Params) : ServerNotification { override val method: Method = Method.Defined.NotificationsResourcesUpdated @Serializable @@ -923,8 +910,9 @@ public class Prompt( @Serializable public data class ListPromptsRequest( override val cursor: Cursor? = null, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.PromptsList } @@ -936,7 +924,8 @@ public class ListPromptsResult( public val prompts: List, override val nextCursor: Cursor? = null, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * Used by the client to get a prompt provided by the server. @@ -954,7 +943,8 @@ public data class GetPromptRequest( val arguments: Map?, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.PromptsGet } @@ -1033,22 +1023,17 @@ public data class AudioContent( } } - /** * Unknown content provided to or from an LLM. */ @Serializable -public data class UnknownContent( - override val type: String, -) : PromptMessageContentMultimodal +public data class UnknownContent(override val type: String) : PromptMessageContentMultimodal /** * The contents of a resource, embedded into a prompt or tool call result. */ @Serializable -public data class EmbeddedResource( - val resource: ResourceContents, -) : PromptMessageContent { +public data class EmbeddedResource(val resource: ResourceContents) : PromptMessageContent { override val type: String = TYPE public companion object { @@ -1062,17 +1047,15 @@ public data class EmbeddedResource( @Suppress("EnumEntryName") @Serializable public enum class Role { - user, assistant, + user, + assistant, } /** * Describes a message returned as part of a prompt. */ @Serializable -public data class PromptMessage( - val role: Role, - val content: PromptMessageContent, -) +public data class PromptMessage(val role: Role, val content: PromptMessageContent) /** * The server's response to a prompts/get request from the client. @@ -1092,15 +1075,11 @@ public class GetPromptResult( * Servers may issue this without any previous subscription from the client. */ @Serializable -public data class PromptListChangedNotification( - override val params: Params = Params(), -) : ServerNotification { +public data class PromptListChangedNotification(override val params: Params = Params()) : ServerNotification { override val method: Method = Method.Defined.NotificationsPromptsListChanged @Serializable - public data class Params( - override val _meta: JsonObject = EmptyJsonObject, - ) : NotificationParams + public data class Params(override val _meta: JsonObject = EmptyJsonObject) : NotificationParams } /* Tools */ @@ -1155,7 +1134,6 @@ public data class ToolAnnotations( val openWorldHint: Boolean? = true, ) - /** * Definition for a tool the client can call. */ @@ -1187,20 +1165,14 @@ public data class Tool( val annotations: ToolAnnotations?, ) { @Serializable - public data class Input( - val properties: JsonObject = EmptyJsonObject, - val required: List? = null, - ) { + public data class Input(val properties: JsonObject = EmptyJsonObject, val required: List? = null) { @OptIn(ExperimentalSerializationApi::class) @EncodeDefault val type: String = "object" } @Serializable - public data class Output( - val properties: JsonObject = EmptyJsonObject, - val required: List? = null, - ) { + public data class Output(val properties: JsonObject = EmptyJsonObject, val required: List? = null) { @OptIn(ExperimentalSerializationApi::class) @EncodeDefault val type: String = "object" @@ -1213,8 +1185,9 @@ public data class Tool( @Serializable public data class ListToolsRequest( override val cursor: Cursor? = null, - override val _meta: JsonObject = EmptyJsonObject -) : ClientRequest, PaginatedRequest { + override val _meta: JsonObject = EmptyJsonObject, +) : ClientRequest, + PaginatedRequest { override val method: Method = Method.Defined.ToolsList } @@ -1226,7 +1199,8 @@ public class ListToolsResult( public val tools: List, override val nextCursor: Cursor?, override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult, PaginatedResult +) : ServerResult, + PaginatedResult /** * The server's response to a tool call. @@ -1269,7 +1243,8 @@ public data class CallToolRequest( val name: String, val arguments: JsonObject = EmptyJsonObject, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.ToolsCall } @@ -1278,15 +1253,11 @@ public data class CallToolRequest( * Servers may issue this without any previous subscription from the client. */ @Serializable -public data class ToolListChangedNotification( - override val params: Params = Params(), -) : ServerNotification { +public data class ToolListChangedNotification(override val params: Params = Params()) : ServerNotification { override val method: Method = Method.Defined.NotificationsToolsListChanged @Serializable - public data class Params( - override val _meta: JsonObject = EmptyJsonObject, - ) : NotificationParams + public data class Params(override val _meta: JsonObject = EmptyJsonObject) : NotificationParams } /* Logging */ @@ -1304,7 +1275,6 @@ public enum class LoggingLevel { critical, alert, emergency, - ; } /** @@ -1313,9 +1283,7 @@ public enum class LoggingLevel { * the server MAY decide which messages to send automatically. */ @Serializable -public data class LoggingMessageNotification( - override val params: Params, -) : ServerNotification { +public data class LoggingMessageNotification(override val params: Params) : ServerNotification { override val method: Method = Method.Defined.NotificationsMessage @Serializable @@ -1345,7 +1313,8 @@ public data class LoggingMessageNotification( */ val level: LoggingLevel, override val _meta: JsonObject = EmptyJsonObject, - ) : ClientRequest, WithMeta { + ) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.LoggingSetLevel } } @@ -1404,10 +1373,7 @@ public class ModelPreferences( * Describes a message issued to or received from an LLM API. */ @Serializable -public data class SamplingMessage( - val role: Role, - val content: PromptMessageContentMultimodal, -) +public data class SamplingMessage(val role: Role, val content: PromptMessageContentMultimodal) /** * A request from the server to sample an LLM via the client. @@ -1441,7 +1407,8 @@ public data class CreateMessageRequest( */ val modelPreferences: ModelPreferences?, override val _meta: JsonObject = EmptyJsonObject, -) : ServerRequest, WithMeta { +) : ServerRequest, + WithMeta { override val method: Method = Method.Defined.SamplingCreateMessage @Serializable @@ -1536,9 +1503,7 @@ public data class PromptReference( * Identifies a prompt. */ @Serializable -public data class UnknownReference( - override val type: String, -) : Reference +public data class UnknownReference(override val type: String) : Reference /** * A request from the client to the server to ask for completion options. @@ -1551,7 +1516,8 @@ public data class CompleteRequest( */ val argument: Argument, override val _meta: JsonObject = EmptyJsonObject, -) : ClientRequest, WithMeta { +) : ClientRequest, + WithMeta { override val method: Method = Method.Defined.CompletionComplete @Serializable @@ -1571,10 +1537,8 @@ public data class CompleteRequest( * The server's response to a completion/complete request */ @Serializable -public data class CompleteResult( - val completion: Completion, - override val _meta: JsonObject = EmptyJsonObject, -) : ServerResult { +public data class CompleteResult(val completion: Completion, override val _meta: JsonObject = EmptyJsonObject) : + ServerResult { @Suppress("CanBeParameter") @Serializable public class Completion( @@ -1626,7 +1590,9 @@ public data class Root( * Sent from the server to request a list of root URIs from the client. */ @Serializable -public class ListRootsRequest(override val _meta: JsonObject = EmptyJsonObject) : ServerRequest, WithMeta { +public class ListRootsRequest(override val _meta: JsonObject = EmptyJsonObject) : + ServerRequest, + WithMeta { override val method: Method = Method.Defined.RootsList } @@ -1634,24 +1600,18 @@ public class ListRootsRequest(override val _meta: JsonObject = EmptyJsonObject) * The client's response to a roots/list request from the server. */ @Serializable -public class ListRootsResult( - public val roots: List, - override val _meta: JsonObject = EmptyJsonObject, -) : ClientResult +public class ListRootsResult(public val roots: List, override val _meta: JsonObject = EmptyJsonObject) : + ClientResult /** * A notification from the client to the server, informing it that the list of roots has changed. */ @Serializable -public data class RootsListChangedNotification( - override val params: Params = Params(), -) : ClientNotification { +public data class RootsListChangedNotification(override val params: Params = Params()) : ClientNotification { override val method: Method = Method.Defined.NotificationsRootsListChanged @Serializable - public data class Params( - override val _meta: JsonObject = EmptyJsonObject, - ) : NotificationParams + public data class Params(override val _meta: JsonObject = EmptyJsonObject) : NotificationParams } /** @@ -1662,7 +1622,8 @@ public data class CreateElicitationRequest( public val message: String, public val requestedSchema: RequestedSchema, override val _meta: JsonObject = EmptyJsonObject, -) : ServerRequest, WithMeta { +) : ServerRequest, + WithMeta { override val method: Method = Method.Defined.ElicitationCreate @Serializable @@ -1705,5 +1666,5 @@ public data class CreateElicitationResult( */ public class McpError(public val code: Int, message: String, public val data: JsonObject = EmptyJsonObject) : Exception() { - override val message: String = "MCP error ${code}: $message" + override val message: String = "MCP error $code: $message" } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt index 570b376a..eaa0018e 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types.util.kt @@ -73,38 +73,35 @@ internal object StopReasonSerializer : KSerializer { } internal object ReferencePolymorphicSerializer : JsonContentPolymorphicSerializer(Reference::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return when (element.jsonObject.getValue("type").jsonPrimitive.content) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.jsonObject.getValue("type").jsonPrimitive.content) { ResourceReference.TYPE -> ResourceReference.serializer() PromptReference.TYPE -> PromptReference.serializer() else -> UnknownReference.serializer() } - } } internal object PromptMessageContentPolymorphicSerializer : JsonContentPolymorphicSerializer(PromptMessageContent::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return when (element.jsonObject.getValue("type").jsonPrimitive.content) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.jsonObject.getValue("type").jsonPrimitive.content) { ImageContent.TYPE -> ImageContent.serializer() TextContent.TYPE -> TextContent.serializer() EmbeddedResource.TYPE -> EmbeddedResource.serializer() AudioContent.TYPE -> AudioContent.serializer() else -> UnknownContent.serializer() } - } } internal object PromptMessageContentMultimodalPolymorphicSerializer : JsonContentPolymorphicSerializer(PromptMessageContentMultimodal::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return when (element.jsonObject.getValue("type").jsonPrimitive.content) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.jsonObject.getValue("type").jsonPrimitive.content) { ImageContent.TYPE -> ImageContent.serializer() TextContent.TYPE -> TextContent.serializer() AudioContent.TYPE -> AudioContent.serializer() else -> UnknownContent.serializer() } - } } internal object ResourceContentsPolymorphicSerializer : @@ -125,54 +122,48 @@ internal fun selectRequestDeserializer(method: String): DeserializationStrategy< return CustomRequest.serializer() } -internal fun selectClientRequestDeserializer(method: String): DeserializationStrategy? { - return when (method) { - Method.Defined.Ping.value -> PingRequest.serializer() - Method.Defined.Initialize.value -> InitializeRequest.serializer() - Method.Defined.CompletionComplete.value -> CompleteRequest.serializer() - Method.Defined.LoggingSetLevel.value -> SetLevelRequest.serializer() - Method.Defined.PromptsGet.value -> GetPromptRequest.serializer() - Method.Defined.PromptsList.value -> ListPromptsRequest.serializer() - Method.Defined.ResourcesList.value -> ListResourcesRequest.serializer() - Method.Defined.ResourcesTemplatesList.value -> ListResourceTemplatesRequest.serializer() - Method.Defined.ResourcesRead.value -> ReadResourceRequest.serializer() - Method.Defined.ResourcesSubscribe.value -> SubscribeRequest.serializer() - Method.Defined.ResourcesUnsubscribe.value -> UnsubscribeRequest.serializer() - Method.Defined.ToolsCall.value -> CallToolRequest.serializer() - Method.Defined.ToolsList.value -> ListToolsRequest.serializer() - else -> null - } +internal fun selectClientRequestDeserializer(method: String): DeserializationStrategy? = when (method) { + Method.Defined.Ping.value -> PingRequest.serializer() + Method.Defined.Initialize.value -> InitializeRequest.serializer() + Method.Defined.CompletionComplete.value -> CompleteRequest.serializer() + Method.Defined.LoggingSetLevel.value -> SetLevelRequest.serializer() + Method.Defined.PromptsGet.value -> GetPromptRequest.serializer() + Method.Defined.PromptsList.value -> ListPromptsRequest.serializer() + Method.Defined.ResourcesList.value -> ListResourcesRequest.serializer() + Method.Defined.ResourcesTemplatesList.value -> ListResourceTemplatesRequest.serializer() + Method.Defined.ResourcesRead.value -> ReadResourceRequest.serializer() + Method.Defined.ResourcesSubscribe.value -> SubscribeRequest.serializer() + Method.Defined.ResourcesUnsubscribe.value -> UnsubscribeRequest.serializer() + Method.Defined.ToolsCall.value -> CallToolRequest.serializer() + Method.Defined.ToolsList.value -> ListToolsRequest.serializer() + else -> null } -private fun selectClientNotificationDeserializer(element: JsonElement): DeserializationStrategy? { - return when (element.jsonObject.getValue("method").jsonPrimitive.content) { +private fun selectClientNotificationDeserializer(element: JsonElement): DeserializationStrategy? = + when (element.jsonObject.getValue("method").jsonPrimitive.content) { Method.Defined.NotificationsCancelled.value -> CancelledNotification.serializer() Method.Defined.NotificationsProgress.value -> ProgressNotification.serializer() Method.Defined.NotificationsInitialized.value -> InitializedNotification.serializer() Method.Defined.NotificationsRootsListChanged.value -> RootsListChangedNotification.serializer() else -> null } -} internal object ClientNotificationPolymorphicSerializer : JsonContentPolymorphicSerializer(ClientNotification::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientNotificationDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientNotificationDeserializer(element) ?: UnknownMethodRequestOrNotification.serializer() - } } -internal fun selectServerRequestDeserializer(method: String): DeserializationStrategy? { - return when (method) { - Method.Defined.Ping.value -> PingRequest.serializer() - Method.Defined.SamplingCreateMessage.value -> CreateMessageRequest.serializer() - Method.Defined.RootsList.value -> ListRootsRequest.serializer() - else -> null - } +internal fun selectServerRequestDeserializer(method: String): DeserializationStrategy? = when (method) { + Method.Defined.Ping.value -> PingRequest.serializer() + Method.Defined.SamplingCreateMessage.value -> CreateMessageRequest.serializer() + Method.Defined.RootsList.value -> ListRootsRequest.serializer() + else -> null } -internal fun selectServerNotificationDeserializer(element: JsonElement): DeserializationStrategy? { - return when (element.jsonObject.getValue("method").jsonPrimitive.content) { +internal fun selectServerNotificationDeserializer(element: JsonElement): DeserializationStrategy? = + when (element.jsonObject.getValue("method").jsonPrimitive.content) { Method.Defined.NotificationsCancelled.value -> CancelledNotification.serializer() Method.Defined.NotificationsProgress.value -> ProgressNotification.serializer() Method.Defined.NotificationsMessage.value -> LoggingMessageNotification.serializer() @@ -182,23 +173,20 @@ internal fun selectServerNotificationDeserializer(element: JsonElement): Deseria Method.Defined.NotificationsPromptsListChanged.value -> PromptListChangedNotification.serializer() else -> null } -} internal object ServerNotificationPolymorphicSerializer : JsonContentPolymorphicSerializer(ServerNotification::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectServerNotificationDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectServerNotificationDeserializer(element) ?: UnknownMethodRequestOrNotification.serializer() - } } internal object NotificationPolymorphicSerializer : JsonContentPolymorphicSerializer(Notification::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientNotificationDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientNotificationDeserializer(element) ?: selectServerNotificationDeserializer(element) ?: UnknownMethodRequestOrNotification.serializer() - } } internal object RequestPolymorphicSerializer : @@ -243,27 +231,24 @@ private fun selectClientResultDeserializer(element: JsonElement): Deserializatio internal object ServerResultPolymorphicSerializer : JsonContentPolymorphicSerializer(ServerResult::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectServerResultDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectServerResultDeserializer(element) ?: EmptyRequestResult.serializer() - } } internal object ClientResultPolymorphicSerializer : JsonContentPolymorphicSerializer(ClientResult::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientResultDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientResultDeserializer(element) ?: EmptyRequestResult.serializer() - } } internal object RequestResultPolymorphicSerializer : JsonContentPolymorphicSerializer(RequestResult::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - return selectClientResultDeserializer(element) + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + selectClientResultDeserializer(element) ?: selectServerResultDeserializer(element) ?: EmptyRequestResult.serializer() - } } internal object JSONRPCMessagePolymorphicSerializer : @@ -315,7 +300,7 @@ public fun CallToolResult.Companion.ok(content: String, meta: JsonObject = Empty CallToolResult( content = listOf(TextContent(content)), isError = false, - _meta = meta + _meta = meta, ) /** @@ -325,5 +310,5 @@ public fun CallToolResult.Companion.error(content: String, meta: JsonObject = Em CallToolResult( content = listOf(TextContent(content)), isError = true, - _meta = meta - ) \ No newline at end of file + _meta = meta, + ) diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt index 5b2d8e2b..247388f7 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/AudioContentSerializationTest.kt @@ -17,7 +17,7 @@ class AudioContentSerializationTest { private val audioContent = AudioContent( data = "base64-encoded-audio-data", - mimeType = "audio/wav" + mimeType = "audio/wav", ) @Test @@ -30,4 +30,4 @@ class AudioContentSerializationTest { val content = McpJson.decodeFromString(audioContentJson) assertEquals(expected = audioContent, actual = content) } -} \ No newline at end of file +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/CallToolResultUtilsTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/CallToolResultUtilsTest.kt index 4d5ad2dd..d331b01e 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/CallToolResultUtilsTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/CallToolResultUtilsTest.kt @@ -61,4 +61,3 @@ class CallToolResultUtilsTest { assertEquals(meta, result._meta) } } - diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt index 0e1f704a..0664351f 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/ToolSerializationTest.kt @@ -55,30 +55,42 @@ class ToolSerializationTest { annotations = null, inputSchema = Tool.Input( properties = buildJsonObject { - put("location", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("The city and state, e.g. San Francisco, CA")) - }) + put( + "location", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("The city and state, e.g. San Francisco, CA")) + }, + ) }, - required = listOf("location") + required = listOf("location"), ), outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) //region Serialize @@ -120,23 +132,35 @@ class ToolSerializationTest { name = "get_weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) - val expectedJson = createWeatherToolJson(name = "get_weather", outputSchema = """ + val expectedJson = + createWeatherToolJson( + name = "get_weather", + outputSchema = """ { "type": "object", "properties": { @@ -155,7 +179,8 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val actualJson = McpJson.encodeToString(weatherTool) @@ -169,21 +194,30 @@ class ToolSerializationTest { title = "Get weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) val expectedJson = createWeatherToolJson( name = "get_weather", @@ -207,7 +241,8 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val actualJson = McpJson.encodeToString(weatherTool) @@ -245,7 +280,10 @@ class ToolSerializationTest { @Test fun `should deserialize get_weather tool with outputSchema optional property specified`() { - val toolJson = createWeatherToolJson(name = "get_weather", outputSchema = """ + val toolJson = + createWeatherToolJson( + name = "get_weather", + outputSchema = """ { "type": "object", "properties": { @@ -264,27 +302,37 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val expectedTool = createWeatherTool( name = "get_weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) val actualTool = McpJson.decodeFromString(toolJson) @@ -316,28 +364,38 @@ class ToolSerializationTest { }, "required": ["temperature", "conditions", "humidity"] } - """.trimIndent()) + """.trimIndent(), + ) val expectedTool = createWeatherTool( name = "get_weather", title = "Get weather", outputSchema = Tool.Output( properties = buildJsonObject { - put("temperature", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Temperature in celsius")) - }) - put("conditions", buildJsonObject { - put("type", JsonPrimitive("string")) - put("description", JsonPrimitive("Weather conditions description")) - }) - put("humidity", buildJsonObject { - put("type", JsonPrimitive("number")) - put("description", JsonPrimitive("Humidity percentage")) - }) + put( + "temperature", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Temperature in celsius")) + }, + ) + put( + "conditions", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Weather conditions description")) + }, + ) + put( + "humidity", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("Humidity percentage")) + }, + ) }, - required = listOf("temperature", "conditions", "humidity") - ) + required = listOf("temperature", "conditions", "humidity"), + ), ) val actualTool = McpJson.decodeFromString(toolJson) @@ -352,9 +410,8 @@ class ToolSerializationTest { private fun createWeatherToolJson( name: String = "get_weather", title: String? = null, - outputSchema: String? = null + outputSchema: String? = null, ): String { - val stringBuilder = StringBuilder() stringBuilder @@ -371,7 +428,8 @@ class ToolSerializationTest { .appendLine(",") .append(" \"description\": \"Get the current weather in a given location\"") .appendLine(",") - .append(""" + .append( + """ "inputSchema": { "type": "object", "properties": { @@ -382,46 +440,49 @@ class ToolSerializationTest { }, "required": ["location"] } - """.trimIndent()) + """.trimIndent(), + ) if (outputSchema != null) { stringBuilder .appendLine(",") - .append(""" + .append( + """ "outputSchema": $outputSchema - """.trimIndent()) + """.trimIndent(), + ) } stringBuilder .appendLine() .appendLine("}") - return stringBuilder.toString().trimIndent() } private fun createWeatherTool( name: String = "get_weather", title: String? = null, - outputSchema: Tool.Output? = null - ): Tool { - return Tool( - name = name, - title = title, - description = "Get the current weather in a given location", - annotations = null, - inputSchema = Tool.Input( - properties = buildJsonObject { - put("location", buildJsonObject { + outputSchema: Tool.Output? = null, + ): Tool = Tool( + name = name, + title = title, + description = "Get the current weather in a given location", + annotations = null, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "location", + buildJsonObject { put("type", JsonPrimitive("string")) put("description", JsonPrimitive("The city and state, e.g. San Francisco, CA")) - }) - }, - required = listOf("location") - ), - outputSchema = outputSchema - ) - } + }, + ) + }, + required = listOf("location"), + ), + outputSchema = outputSchema, + ) //endregion Private Methods } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesTest.kt index 819db328..7db08048 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesTest.kt @@ -104,7 +104,7 @@ class TypesTest { fun `should validate image content`() { val imageContent = ImageContent( data = "aGVsbG8=", // base64 encoded "hello" - mimeType = "image/png" + mimeType = "image/png", ) assertEquals("image", imageContent.type) @@ -116,7 +116,7 @@ class TypesTest { fun `should serialize and deserialize image content correctly`() { val imageContent = ImageContent( data = "dGVzdA==", // base64 encoded "test" - mimeType = "image/jpeg" + mimeType = "image/jpeg", ) val json = McpJson.encodeToString(imageContent) @@ -132,7 +132,7 @@ class TypesTest { fun `should validate audio content`() { val audioContent = AudioContent( data = "aGVsbG8=", // base64 encoded "hello" - mimeType = "audio/mp3" + mimeType = "audio/mp3", ) assertEquals("audio", audioContent.type) @@ -144,7 +144,7 @@ class TypesTest { fun `should serialize and deserialize audio content correctly`() { val audioContent = AudioContent( data = "YXVkaW8=", // base64 encoded "audio" - mimeType = "audio/wav" + mimeType = "audio/wav", ) val json = McpJson.encodeToString(audioContent) @@ -161,7 +161,7 @@ class TypesTest { val resource = TextResourceContents( text = "File contents", uri = "file:///path/to/file.txt", - mimeType = "text/plain" + mimeType = "text/plain", ) val embeddedResource = EmbeddedResource(resource = resource) @@ -174,7 +174,7 @@ class TypesTest { val resource = BlobResourceContents( blob = "YmluYXJ5ZGF0YQ==", uri = "file:///path/to/binary.dat", - mimeType = "application/octet-stream" + mimeType = "application/octet-stream", ) val embeddedResource = EmbeddedResource(resource = resource) @@ -206,7 +206,7 @@ class TypesTest { val textContent = TextContent(text = "Hello, assistant!") val promptMessage = PromptMessage( role = Role.user, - content = textContent + content = textContent, ) assertEquals(Role.user, promptMessage.role) @@ -219,12 +219,12 @@ class TypesTest { val resource = TextResourceContents( text = "Primary application entry point", uri = "file:///project/src/main.rs", - mimeType = "text/x-rust" + mimeType = "text/x-rust", ) val embeddedResource = EmbeddedResource(resource = resource) val promptMessage = PromptMessage( role = Role.assistant, - content = embeddedResource + content = embeddedResource, ) assertEquals(Role.assistant, promptMessage.role) @@ -240,11 +240,11 @@ class TypesTest { fun `should serialize and deserialize prompt message correctly`() { val imageContent = ImageContent( data = "aW1hZ2VkYXRh", // base64 encoded "imagedata" - mimeType = "image/png" + mimeType = "image/png", ) val promptMessage = PromptMessage( role = Role.assistant, - content = imageContent + content = imageContent, ) val json = McpJson.encodeToString(promptMessage) @@ -267,17 +267,17 @@ class TypesTest { resource = TextResourceContents( text = "fn main() {}", uri = "file:///project/src/main.rs", - mimeType = "text/x-rust" - ) + mimeType = "text/x-rust", + ), ), EmbeddedResource( resource = TextResourceContents( text = "pub mod lib;", uri = "file:///project/src/lib.rs", - mimeType = "text/x-rust" - ) - ) - ) + mimeType = "text/x-rust", + ), + ), + ), ) assertEquals(3, toolResult.content.size) @@ -300,9 +300,9 @@ class TypesTest { val toolResult = CallToolResult( content = listOf( TextContent(text = "Operation completed"), - ImageContent(data = "aW1hZ2U=", mimeType = "image/png") + ImageContent(data = "aW1hZ2U=", mimeType = "image/png"), ), - isError = false + isError = false, ) val json = McpJson.encodeToString(toolResult) @@ -319,7 +319,7 @@ class TypesTest { fun `should validate CompleteRequest with prompt reference`() { val request = CompleteRequest( ref = PromptReference(name = "greeting"), - argument = CompleteRequest.Argument(name = "name", value = "A") + argument = CompleteRequest.Argument(name = "name", value = "A"), ) assertEquals("completion/complete", request.method.value) @@ -334,7 +334,7 @@ class TypesTest { fun `should validate CompleteRequest with resource reference`() { val request = CompleteRequest( ref = ResourceReference(uri = "github://repos/{owner}/{repo}"), - argument = CompleteRequest.Argument(name = "repo", value = "t") + argument = CompleteRequest.Argument(name = "repo", value = "t"), ) assertEquals("completion/complete", request.method.value) @@ -349,7 +349,7 @@ class TypesTest { fun `should serialize and deserialize CompleteRequest correctly`() { val request = CompleteRequest( ref = PromptReference(name = "test"), - argument = CompleteRequest.Argument(name = "arg", value = "") + argument = CompleteRequest.Argument(name = "arg", value = ""), ) val json = McpJson.encodeToString(request) @@ -367,7 +367,7 @@ class TypesTest { fun `should validate CompleteRequest with complex URIs`() { val request = CompleteRequest( ref = ResourceReference(uri = "api://v1/{tenant}/{resource}/{id}"), - argument = CompleteRequest.Argument(name = "id", value = "123") + argument = CompleteRequest.Argument(name = "id", value = "123"), ) val resourceRef = request.ref as ResourceReference @@ -375,4 +375,4 @@ class TypesTest { assertEquals("id", request.argument.name) assertEquals("123", request.argument.value) } -} \ No newline at end of file +} diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesUtilTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesUtilTest.kt index 444e4180..345c3054 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesUtilTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/TypesUtilTest.kt @@ -249,4 +249,4 @@ class TypesUtilTest { assertEquals(false, result.isError) assertEquals(meta, result._meta) } -} \ No newline at end of file +} diff --git a/kotlin-sdk-server/build.gradle.kts b/kotlin-sdk-server/build.gradle.kts index d107debf..0148b784 100644 --- a/kotlin-sdk-server/build.gradle.kts +++ b/kotlin-sdk-server/build.gradle.kts @@ -25,4 +25,4 @@ kotlin { } } } -} \ No newline at end of file +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 260ef967..57bae05f 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -94,9 +94,7 @@ internal fun ServerSSESession.mcpSseTransport( return transport } -internal suspend fun RoutingContext.mcpPostEndpoint( - transports: ConcurrentMap, -) { +internal suspend fun RoutingContext.mcpPostEndpoint(transports: ConcurrentMap) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index 83defa2e..8b5c6be2 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -28,10 +28,8 @@ public typealias SSEServerTransport = SseServerTransport * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. */ @OptIn(ExperimentalAtomicApi::class) -public class SseServerTransport( - private val endpoint: String, - private val session: ServerSSESession, -) : AbstractTransport() { +public class SseServerTransport(private val endpoint: String, private val session: ServerSSESession) : + AbstractTransport() { private val initialized: AtomicBoolean = AtomicBoolean(false) @OptIn(ExperimentalUuidApi::class) @@ -44,13 +42,15 @@ public class SseServerTransport( */ override suspend fun start() { if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { - error("SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.") + error( + "SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.", + ) } // Send the endpoint event session.send( event = "endpoint", - data = "${endpoint.encodeURLPath()}?$SESSION_ID_PARAM=${sessionId}", + data = "${endpoint.encodeURLPath()}?$SESSION_ID_PARAM=$sessionId", ) try { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index f0655fd9..ac71b5fe 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -64,10 +64,8 @@ private val logger = KotlinLogging.logger {} * @property capabilities The capabilities this server supports. * @property enforceStrictCapabilities Whether to strictly enforce capabilities when interacting with clients. */ -public class ServerOptions( - public val capabilities: ServerCapabilities, - enforceStrictCapabilities: Boolean = true, -) : ProtocolOptions(enforceStrictCapabilities = enforceStrictCapabilities) +public class ServerOptions(public val capabilities: ServerCapabilities, enforceStrictCapabilities: Boolean = true) : + ProtocolOptions(enforceStrictCapabilities = enforceStrictCapabilities) /** * An MCP server on top of a pluggable transport. @@ -79,11 +77,11 @@ public class ServerOptions( * @param serverInfo Information about this server implementation (name, version). * @param options Configuration options for the server. */ -public open class Server( - private val serverInfo: Implementation, - options: ServerOptions, -) : Protocol(options) { +public open class Server(private val serverInfo: Implementation, options: ServerOptions) : Protocol(options) { + @Suppress("ktlint:standard:backing-property-naming") private var _onInitialized: (() -> Unit) = {} + + @Suppress("ktlint:standard:backing-property-naming") private var _onClose: () -> Unit = {} /** @@ -221,7 +219,7 @@ public open class Server( title: String? = null, outputSchema: Tool.Output? = null, toolAnnotations: ToolAnnotations? = null, - handler: suspend (CallToolRequest) -> CallToolResult + handler: suspend (CallToolRequest) -> CallToolResult, ) { val tool = Tool(name, title, description, inputSchema, outputSchema, toolAnnotations) addTool(tool, handler) @@ -325,7 +323,7 @@ public open class Server( name: String, description: String? = null, arguments: List? = null, - promptProvider: suspend (GetPromptRequest) -> GetPromptResult + promptProvider: suspend (GetPromptRequest) -> GetPromptResult, ) { val prompt = Prompt(name = name, description = description, arguments = arguments) addPrompt(prompt, promptProvider) @@ -416,7 +414,7 @@ public open class Server( name: String, description: String, mimeType: String = "text/html", - readHandler: suspend (ReadResourceRequest) -> ReadResourceResult + readHandler: suspend (ReadResourceRequest) -> ReadResourceResult, ) { if (capabilities.resources == null) { logger.error { "Failed to add resource '$name': Server does not support resources capability" } @@ -426,7 +424,7 @@ public open class Server( _resources.update { current -> current.put( uri, - RegisteredResource(Resource(uri, name, description, mimeType), readHandler) + RegisteredResource(Resource(uri, name, description, mimeType), readHandler), ) } } @@ -507,9 +505,7 @@ public open class Server( * @return The result of the ping request. * @throws IllegalStateException If for some reason the method is not supported or the connection is closed. */ - public suspend fun ping(): EmptyRequestResult { - return request(PingRequest()) - } + public suspend fun ping(): EmptyRequestResult = request(PingRequest()) /** * Creates a message using the server's sampling capability. @@ -521,7 +517,7 @@ public open class Server( */ public suspend fun createMessage( params: CreateMessageRequest, - options: RequestOptions? = null + options: RequestOptions? = null, ): CreateMessageResult { logger.debug { "Creating message with params: $params" } return request(params, options) @@ -537,7 +533,7 @@ public open class Server( */ public suspend fun listRoots( params: JsonObject = EmptyJsonObject, - options: RequestOptions? = null + options: RequestOptions? = null, ): ListRootsResult { logger.debug { "Listing roots with params: $params" } return request(ListRootsRequest(params), options) @@ -546,7 +542,7 @@ public open class Server( public suspend fun createElicitation( message: String, requestedSchema: RequestedSchema, - options: RequestOptions? = null + options: RequestOptions? = null, ): CreateElicitationResult { logger.debug { "Creating elicitation with message: $message" } return request(CreateElicitationRequest(message, requestedSchema), options) @@ -607,14 +603,16 @@ public open class Server( val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) { requestedVersion } else { - logger.warn { "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" } + logger.warn { + "Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" + } LATEST_PROTOCOL_VERSION } return InitializeResult( protocolVersion = protocolVersion, capabilities = capabilities, - serverInfo = serverInfo + serverInfo = serverInfo, ) } @@ -723,26 +721,34 @@ public open class Server( } "notifications/resources/updated", - "notifications/resources/list_changed" -> { + "notifications/resources/list_changed", + -> { if (capabilities.resources == null) { - throw IllegalStateException("Server does not support notifying about resources (required for ${method.value})") + throw IllegalStateException( + "Server does not support notifying about resources (required for ${method.value})", + ) } } "notifications/tools/list_changed" -> { if (capabilities.tools == null) { - throw IllegalStateException("Server does not support notifying of tool list changes (required for ${method.value})") + throw IllegalStateException( + "Server does not support notifying of tool list changes (required for ${method.value})", + ) } } "notifications/prompts/list_changed" -> { if (capabilities.prompts == null) { - throw IllegalStateException("Server does not support notifying of prompt list changes (required for ${method.value})") + throw IllegalStateException( + "Server does not support notifying of prompt list changes (required for ${method.value})", + ) } } "notifications/cancelled", - "notifications/progress" -> { + "notifications/progress", + -> { // Always allowed } } @@ -772,7 +778,8 @@ public open class Server( } "prompts/get", - "prompts/list" -> { + "prompts/list", + -> { if (capabilities.prompts == null) { throw IllegalStateException("Server does not support prompts (required for $method)") } @@ -780,14 +787,16 @@ public open class Server( "resources/list", "resources/templates/list", - "resources/read" -> { + "resources/read", + -> { if (capabilities.resources == null) { throw IllegalStateException("Server does not support resources (required for $method)") } } "tools/call", - "tools/list" -> { + "tools/list", + -> { if (capabilities.tools == null) { throw IllegalStateException("Server does not support tools (required for $method)") } @@ -806,10 +815,7 @@ public open class Server( * @property tool The tool definition. * @property handler A suspend function to handle the tool call requests. */ -public data class RegisteredTool( - val tool: Tool, - val handler: suspend (CallToolRequest) -> CallToolResult -) +public data class RegisteredTool(val tool: Tool, val handler: suspend (CallToolRequest) -> CallToolResult) /** * A wrapper class representing a registered prompt on the server. @@ -819,7 +825,7 @@ public data class RegisteredTool( */ public data class RegisteredPrompt( val prompt: Prompt, - val messageProvider: suspend (GetPromptRequest) -> GetPromptResult + val messageProvider: suspend (GetPromptRequest) -> GetPromptResult, ) /** @@ -830,5 +836,5 @@ public data class RegisteredPrompt( */ public data class RegisteredResource( val resource: Resource, - val readHandler: suspend (ReadResourceRequest) -> ReadResourceResult + val readHandler: suspend (ReadResourceRequest) -> ReadResourceResult, ) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index c515ddac..0fa18b8f 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -31,10 +31,7 @@ import kotlin.coroutines.CoroutineContext * Reads from System.in and writes to System.out. */ @OptIn(ExperimentalAtomicApi::class) -public class StdioServerTransport( - private val inputStream: Source, - outputStream: Sink -) : AbstractTransport() { +public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() { private val logger = KotlinLogging.logger {} private val readBuffer = ReadBuffer() diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt index 9301749b..a3d2fd34 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensions.kt @@ -14,10 +14,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ -public fun Route.mcpWebSocket( - options: ServerOptions? = null, - handler: suspend Server.() -> Unit = {}, -) { +public fun Route.mcpWebSocket(options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}) { webSocket { createMcpServer(this, options, handler) } @@ -30,11 +27,7 @@ public fun Route.mcpWebSocket( * @param options Optional server configuration settings for the MCP server. * @param handler A suspend function that defines the server's behavior. */ -public fun Route.mcpWebSocket( - path: String, - options: ServerOptions? = null, - handler: suspend Server.() -> Unit = {}, -) { +public fun Route.mcpWebSocket(path: String, options: ServerOptions? = null, handler: suspend Server.() -> Unit = {}) { webSocket(path) { createMcpServer(this, options, handler) } @@ -45,9 +38,7 @@ public fun Route.mcpWebSocket( * * @param handler A suspend function that defines the behavior of the transport layer. */ -public fun Route.mcpWebSocketTransport( - handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, -) { +public fun Route.mcpWebSocketTransport(handler: suspend WebSocketMcpServerTransport.() -> Unit = {}) { webSocket { val transport = createMcpTransport(this) transport.start() @@ -62,10 +53,7 @@ public fun Route.mcpWebSocketTransport( * @param path The URL path at which to register the WebSocket route. * @param handler A suspend function that defines the behavior of the transport layer. */ -public fun Route.mcpWebSocketTransport( - path: String, - handler: suspend WebSocketMcpServerTransport.() -> Unit = {}, -) { +public fun Route.mcpWebSocketTransport(path: String, handler: suspend WebSocketMcpServerTransport.() -> Unit = {}) { webSocket(path) { val transport = createMcpTransport(this) transport.start() @@ -74,7 +62,6 @@ public fun Route.mcpWebSocketTransport( } } - private suspend fun Route.createMcpServer( session: WebSocketServerSession, options: ServerOptions?, @@ -85,14 +72,14 @@ private suspend fun Route.createMcpServer( val server = Server( serverInfo = Implementation( name = IMPLEMENTATION_NAME, - version = LIB_VERSION + version = LIB_VERSION, ), options = options ?: ServerOptions( capabilities = ServerCapabilities( prompts = ServerCapabilities.Prompts(listChanged = null), resources = ServerCapabilities.Resources(subscribe = null, listChanged = null), tools = ServerCapabilities.Tools(listChanged = null), - ) + ), ), ) @@ -101,8 +88,5 @@ private suspend fun Route.createMcpServer( server.close() } -private fun createMcpTransport( - session: WebSocketServerSession, -): WebSocketMcpServerTransport { - return WebSocketMcpServerTransport(session) -} +private fun createMcpTransport(session: WebSocketServerSession): WebSocketMcpServerTransport = + WebSocketMcpServerTransport(session) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt index 45cb4df9..877fda58 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt @@ -10,9 +10,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport * * @property session The WebSocket server session used for communication. */ -public class WebSocketMcpServerTransport( - override val session: WebSocketServerSession, -) : WebSocketMcpTransport() { +public class WebSocketMcpServerTransport(override val session: WebSocketServerSession) : WebSocketMcpTransport() { override suspend fun initializeSession() { val subprotocol = session.call.request.headers[HttpHeaders.SecWebSocketProtocol] if (subprotocol != MCP_SUBPROTOCOL) { diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index 7a5cabbd..ca7e35b3 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -13,4 +13,4 @@ kotlin { } } } -} \ No newline at end of file +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index 26330ce1..0294620b 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -66,13 +66,13 @@ class ClientTest { capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) @@ -85,13 +85,13 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject - ) - ) + sampling = EmptyJsonObject, + ), + ), ) client.connect(clientTransport) @@ -100,7 +100,7 @@ class ClientTest { @Test fun `should initialize with supported older protocol version`() = runTest { - val OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1] + val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1] val clientTransport = object : AbstractTransport() { override suspend fun start() {} @@ -109,17 +109,17 @@ class ClientTest { check(message.method == Method.Defined.Initialize.value) val result = InitializeResult( - protocolVersion = OLD_VERSION, + protocolVersion = oldVersion, capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) } @@ -131,19 +131,19 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject - ) - ) + sampling = EmptyJsonObject, + ), + ), ) client.connect(clientTransport) assertEquals( Implementation("test", "1.0"), - client.serverVersion + client.serverVersion, ) } @@ -162,13 +162,13 @@ class ClientTest { capabilities = ServerCapabilities(), serverInfo = Implementation( name = "test", - version = "1.0" - ) + version = "1.0", + ), ) val response = JSONRPCResponse( id = message.id, - result = result + result = result, ) _onMessage.invoke(response) @@ -182,9 +182,9 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), - options = ClientOptions() + options = ClientOptions(), ) assertFailsWith("Server's protocol version is not supported: invalid-version") { @@ -214,9 +214,9 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), - options = ClientOptions() + options = ClientOptions(), ) val exception = assertFailsWith { @@ -233,12 +233,12 @@ class ClientTest { val serverOptions = ServerOptions( capabilities = ServerCapabilities( resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) server.setRequestHandler(Method.Defined.Initialize) { _, _ -> @@ -246,9 +246,9 @@ class ClientTest { protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) + tools = ServerCapabilities.Tools(null), ), - serverInfo = Implementation(name = "test", version = "1.0") + serverInfo = Implementation(name = "test", version = "1.0"), ) } @@ -266,7 +266,7 @@ class ClientTest { clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ) + ), ) listOf( @@ -275,7 +275,7 @@ class ClientTest { }, launch { server.connect(serverTransport) - } + }, ).joinAll() // Server supports resources and tools, but not prompts @@ -299,16 +299,16 @@ class ClientTest { fun `should respect client notification capabilities`() = runTest { val server = Server( Implementation(name = "test server", version = "1.0"), - ServerOptions(capabilities = ServerCapabilities()) + ServerOptions(capabilities = ServerCapabilities()), ) val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(listChanged = true) - ) - ) + roots = ClientCapabilities.Roots(listChanged = true), + ), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -321,7 +321,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() // This should not throw because the client supports roots.listChanged @@ -333,7 +333,7 @@ class ClientTest { options = ClientOptions( capabilities = ClientCapabilities(), // enforceStrictCapabilities = true // TODO() - ) + ), ) clientWithoutCapability.connect(clientTransport) @@ -354,16 +354,16 @@ class ClientTest { ServerOptions( capabilities = ServerCapabilities( logging = EmptyJsonObject, - resources = ServerCapabilities.Resources(listChanged = true, subscribe = null) - ) - ) + resources = ServerCapabilities.Resources(listChanged = true, subscribe = null), + ), + ), ) val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -376,7 +376,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() // These should not throw @@ -389,9 +389,9 @@ class ClientTest { LoggingMessageNotification( params = LoggingMessageNotification.Params( level = LoggingLevel.info, - data = jsonObject - ) - ) + data = jsonObject, + ), + ), ) server.sendResourceListChanged() @@ -410,10 +410,10 @@ class ClientTest { capabilities = ServerCapabilities( resources = ServerCapabilities.Resources( listChanged = null, - subscribe = null - ) - ) - ) + subscribe = null, + ), + ), + ), ) val def = CompletableDeferred() @@ -435,7 +435,7 @@ class ClientTest { val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions(capabilities = ClientCapabilities()) + options = ClientOptions(capabilities = ClientCapabilities()), ) listOf( @@ -446,7 +446,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() val defCancel = CompletableDeferred() @@ -472,10 +472,10 @@ class ClientTest { capabilities = ServerCapabilities( resources = ServerCapabilities.Resources( listChanged = null, - subscribe = null - ) - ) - ) + subscribe = null, + ), + ), + ), ) server.setRequestHandler(Method.Defined.ResourcesList) { _, _ -> @@ -495,7 +495,7 @@ class ClientTest { val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), - options = ClientOptions(capabilities = ClientCapabilities()) + options = ClientOptions(capabilities = ClientCapabilities()), ) listOf( @@ -506,7 +506,7 @@ class ClientTest { launch { server.connect(serverTransport) println("Server connected") - } + }, ).joinAll() // Request with 1 msec timeout should fail immediately @@ -523,13 +523,13 @@ class ClientTest { val client = Client( clientInfo = Implementation( name = "test client", - version = "1.0" + version = "1.0", ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject - ) - ) + sampling = EmptyJsonObject, + ), + ), ) client.setRequestHandler(Method.Defined.SamplingCreateMessage) { _, _ -> @@ -537,8 +537,8 @@ class ClientTest { model = "test-model", role = Role.assistant, content = TextContent( - text = "Test response" - ) + text = "Test response", + ), ) } @@ -551,12 +551,12 @@ class ClientTest { fun `JSONRPCRequest with ToolsList method and default params returns list of tools`() = runTest { val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) server.setRequestHandler(Method.Defined.Initialize) { _, _ -> @@ -564,9 +564,9 @@ class ClientTest { protocolVersion = LATEST_PROTOCOL_VERSION, capabilities = ServerCapabilities( resources = ServerCapabilities.Resources(null, null), - tools = ServerCapabilities.Tools(null) + tools = ServerCapabilities.Tools(null), ), - serverInfo = Implementation(name = "test", version = "1.0") + serverInfo = Implementation(name = "test", version = "1.0"), ) } val serverListToolsResult = ListToolsResult( @@ -577,9 +577,10 @@ class ClientTest { description = "testTool description", annotations = null, inputSchema = Tool.Input(), - outputSchema = null - ) - ), nextCursor = null + outputSchema = null, + ), + ), + nextCursor = null, ) server.setRequestHandler(Method.Defined.ToolsList) { _, _ -> @@ -592,7 +593,7 @@ class ClientTest { clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities(sampling = EmptyJsonObject), - ) + ), ) var receivedMessage: JSONRPCMessage? = null @@ -606,14 +607,14 @@ class ClientTest { }, launch { server.connect(serverTransport) - } + }, ).joinAll() val serverCapabilities = client.serverCapabilities assertEquals(ServerCapabilities.Tools(null), serverCapabilities?.tools) val request = JSONRPCRequest( - method = Method.Defined.ToolsList.value + method = Method.Defined.ToolsList.value, ) clientTransport.send(request) @@ -631,13 +632,13 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(null) - ) - ) + roots = ClientCapabilities.Roots(null), + ), + ), ) val clientRoots = listOf( - Root(uri = "file:///test-root", name = "testRoot") + Root(uri = "file:///test-root", name = "testRoot"), ) client.addRoots(clientRoots) @@ -647,13 +648,13 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() val clientCapabilities = server.clientCapabilities @@ -669,8 +670,8 @@ class ClientTest { val client = Client( Implementation(name = "test client", version = "1.0"), ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) // Verify that adding a root throws an exception @@ -685,8 +686,8 @@ class ClientTest { val client = Client( Implementation(name = "test client", version = "1.0"), ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) // Verify that removing a root throws an exception @@ -702,9 +703,9 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(null) - ) - ) + roots = ClientCapabilities.Roots(null), + ), + ), ) // Add some roots @@ -712,7 +713,7 @@ class ClientTest { listOf( Root(uri = "file:///test-root1", name = "testRoot1"), Root(uri = "file:///test-root2", name = "testRoot2"), - ) + ), ) // Remove a root @@ -728,9 +729,9 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(null) - ) - ) + roots = ClientCapabilities.Roots(null), + ), + ), ) // Add some roots @@ -738,12 +739,12 @@ class ClientTest { listOf( Root(uri = "file:///test-root1", name = "testRoot1"), Root(uri = "file:///test-root2", name = "testRoot2"), - ) + ), ) // Remove multiple roots val result = client.removeRoots( - listOf("file:///test-root1", "file:///test-root2") + listOf("file:///test-root1", "file:///test-root2"), ) // Verify the root was removed @@ -756,9 +757,9 @@ class ClientTest { Implementation(name = "test client", version = "1.0"), ClientOptions( capabilities = ClientCapabilities( - roots = ClientCapabilities.Roots(listChanged = true) - ) - ) + roots = ClientCapabilities.Roots(listChanged = true), + ), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -766,8 +767,8 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) // Track notifications @@ -779,14 +780,14 @@ class ClientTest { listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() client.sendRootsListChanged() assertTrue( rootListChangedNotificationReceived, - "Notification should be sent when sendRootsListChanged is called" + "Notification should be sent when sendRootsListChanged is called", ) } @@ -795,8 +796,8 @@ class ClientTest { val client = Client( Implementation(name = "test client", version = "1.0"), ClientOptions( - capabilities = ClientCapabilities() - ) + capabilities = ClientCapabilities(), + ), ) val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() @@ -804,13 +805,13 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() // Verify that creating an elicitation throws an exception @@ -823,13 +824,13 @@ class ClientTest { put("type", "string") } }, - required = listOf("name") - ) + required = listOf("name"), + ), ) } assertEquals( "Client does not support elicitation (required for elicitation/create)", - exception.message + exception.message, ) } @@ -840,8 +841,8 @@ class ClientTest { ClientOptions( capabilities = ClientCapabilities( elicitation = EmptyJsonObject, - ) - ) + ), + ), ) val elicitationMessage = "Please provide your GitHub username" @@ -851,7 +852,7 @@ class ClientTest { put("type", "string") } }, - required = listOf("name") + required = listOf("name"), ) val elicitationResultAction = CreateElicitationResult.Action.accept @@ -865,7 +866,7 @@ class ClientTest { CreateElicitationResult( action = elicitationResultAction, - content = elicitationResultContent + content = elicitationResultContent, ) } @@ -874,18 +875,18 @@ class ClientTest { val server = Server( serverInfo = Implementation(name = "test server", version = "1.0"), options = ServerOptions( - capabilities = ServerCapabilities() - ) + capabilities = ServerCapabilities(), + ), ) listOf( launch { client.connect(clientTransport) }, - launch { server.connect(serverTransport) } + launch { server.connect(serverTransport) }, ).joinAll() val result = server.createElicitation( message = elicitationMessage, - requestedSchema = requestedSchema + requestedSchema = requestedSchema, ) assertEquals(elicitationResultAction, result.action) diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index 23ddadf1..4b49e1e8 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -30,9 +30,9 @@ class SseTransportTest : BaseTransportTest() { mcpServer = Server( serverInfo = Implementation( name = "test-server", - version = "1.0" + version = "1.0", ), - options = ServerOptions(ServerCapabilities()) + options = ServerOptions(ServerCapabilities()), ) } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt index 24e5261e..3e7cdb2a 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt @@ -88,7 +88,7 @@ class InMemoryTransportTest { assertFailsWith { clientTransport.send( - InitializedNotification().toJSON() + InitializedNotification().toJSON(), ) } } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 72052518..2d1f25c2 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -43,9 +43,9 @@ class SseIntegrationTest { } } - private suspend fun initClient(port: Int): Client { - return HttpClient(ClientCIO) { install(ClientSSE) }.mcpSse("http://$URL:$port") - } + private suspend fun initClient(port: Int): Client = HttpClient(ClientCIO) { + install(ClientSSE) + }.mcpSse("http://$URL:$port") private suspend fun initServer(): EmbeddedServer { val server = Server( @@ -60,4 +60,4 @@ class SseIntegrationTest { } }.startSuspend(wait = false) } -} \ No newline at end of file +} diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index 6c69a135..acb2f278 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -33,7 +33,8 @@ abstract class BaseTransportTest { } val messages = listOf( - PingRequest().toJSON(), InitializedNotification().toJSON() + PingRequest().toJSON(), + InitializedNotification().toJSON(), ) val readMessages = mutableListOf() diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt index 86dc706a..fd567c75 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt @@ -44,4 +44,4 @@ class InMemoryTransport : AbstractTransport() { other._onMessage.invoke(message) } -} \ No newline at end of file +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt index cc3c58ff..562601aa 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt @@ -17,7 +17,7 @@ class ClientIntegrationTest { return StdioClientTransport( socket.inputStream.asSource().buffered(), - socket.outputStream.asSink().buffered() + socket.outputStream.asSink().buffered(), ) } @@ -34,10 +34,8 @@ class ClientIntegrationTest { val response: ListToolsResult = client.listTools() println(response.tools) - } finally { transport.close() } } - } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt index a8007a28..e8a163d3 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt @@ -19,7 +19,7 @@ class StdioClientTransportTest : BaseTransportTest() { val client = StdioClientTransport( input = input, - output = output + output = output, ) testClientOpenClose(client) @@ -37,7 +37,7 @@ class StdioClientTransportTest : BaseTransportTest() { val client = StdioClientTransport( input = input, - output = output + output = output, ) testClientRead(client) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt index bc6ae014..77175b47 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerTest.kt @@ -31,12 +31,12 @@ class ServerTest { // Create server with tools capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add a tool @@ -66,12 +66,12 @@ class ServerTest { // Create server with tools capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Setup client @@ -103,11 +103,11 @@ class ServerTest { fun `removeTool should throw when tools capability is not supported`() = runTest { // Create server without tools capability val serverOptions = ServerOptions( - capabilities = ServerCapabilities() + capabilities = ServerCapabilities(), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Verify that removing a tool throws an exception @@ -122,12 +122,12 @@ class ServerTest { // Create server with tools capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(null) - ) + tools = ServerCapabilities.Tools(null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add tools @@ -160,12 +160,12 @@ class ServerTest { // Create server with prompts capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = false) - ) + prompts = ServerCapabilities.Prompts(listChanged = false), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add a prompt @@ -173,7 +173,7 @@ class ServerTest { server.addPrompt(testPrompt) { GetPromptResult( description = "Test prompt description", - messages = listOf() + messages = listOf(), ) } @@ -199,12 +199,12 @@ class ServerTest { // Create server with prompts capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = false) - ) + prompts = ServerCapabilities.Prompts(listChanged = false), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add prompts @@ -213,13 +213,13 @@ class ServerTest { server.addPrompt(testPrompt1) { GetPromptResult( description = "Test prompt description 1", - messages = listOf() + messages = listOf(), ) } server.addPrompt(testPrompt2) { GetPromptResult( description = "Test prompt description 2", - messages = listOf() + messages = listOf(), ) } @@ -245,12 +245,12 @@ class ServerTest { // Create server with resources capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null) - ) + resources = ServerCapabilities.Resources(null, null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add a resource @@ -259,7 +259,7 @@ class ServerTest { uri = testResourceUri, name = "Test Resource", description = "A test resource", - mimeType = "text/plain" + mimeType = "text/plain", ) { ReadResourceResult( contents = listOf( @@ -268,7 +268,7 @@ class ServerTest { uri = testResourceUri, mimeType = "text/plain", ), - ) + ), ) } @@ -294,12 +294,12 @@ class ServerTest { // Create server with resources capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null) - ) + resources = ServerCapabilities.Resources(null, null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Add resources @@ -309,7 +309,7 @@ class ServerTest { uri = testResourceUri1, name = "Test Resource 1", description = "A test resource 1", - mimeType = "text/plain" + mimeType = "text/plain", ) { ReadResourceResult( contents = listOf( @@ -318,14 +318,14 @@ class ServerTest { uri = testResourceUri1, mimeType = "text/plain", ), - ) + ), ) } server.addResource( uri = testResourceUri2, name = "Test Resource 2", description = "A test resource 2", - mimeType = "text/plain" + mimeType = "text/plain", ) { ReadResourceResult( contents = listOf( @@ -334,7 +334,7 @@ class ServerTest { uri = testResourceUri2, mimeType = "text/plain", ), - ) + ), ) } @@ -360,12 +360,12 @@ class ServerTest { // Create server with prompts capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts(listChanged = false) - ) + prompts = ServerCapabilities.Prompts(listChanged = false), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Setup client @@ -397,11 +397,11 @@ class ServerTest { fun `removePrompt should throw when prompts capability is not supported`() = runTest { // Create server without prompts capability val serverOptions = ServerOptions( - capabilities = ServerCapabilities() + capabilities = ServerCapabilities(), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Verify that removing a prompt throws an exception @@ -416,12 +416,12 @@ class ServerTest { // Create server with resources capability val serverOptions = ServerOptions( capabilities = ServerCapabilities( - resources = ServerCapabilities.Resources(null, null) - ) + resources = ServerCapabilities.Resources(null, null), + ), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Setup client @@ -432,7 +432,9 @@ class ServerTest { // Track notifications var resourceListChangedNotificationReceived = false - client.setNotificationHandler(Method.Defined.NotificationsResourcesListChanged) { + client.setNotificationHandler( + Method.Defined.NotificationsResourcesListChanged, + ) { resourceListChangedNotificationReceived = true CompletableDeferred(Unit) } @@ -448,7 +450,7 @@ class ServerTest { assertFalse(result, "Removing non-existent resource should return false") assertFalse( resourceListChangedNotificationReceived, - "No notification should be sent when resource doesn't exist" + "No notification should be sent when resource doesn't exist", ) } @@ -456,11 +458,11 @@ class ServerTest { fun `removeResource should throw when resources capability is not supported`() = runTest { // Create server without resources capability val serverOptions = ServerOptions( - capabilities = ServerCapabilities() + capabilities = ServerCapabilities(), ) val server = Server( Implementation(name = "test server", version = "1.0"), - serverOptions + serverOptions, ) // Verify that removing a resource throws an exception