diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 56fd1caf..8eb8957a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -315,7 +315,7 @@ public open class Client(private val clientInfo: Implementation, options: Client * @throws IllegalStateException If the server does not support logging. */ public suspend fun setLoggingLevel(level: LoggingLevel, options: RequestOptions? = null): EmptyRequestResult = - request(SetLevelRequest(level), options) + request(SetLevelRequest(level), options) /** * Retrieves a prompt by name from the server. diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt index 20524ad8..887b57bd 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -16,6 +16,7 @@ import io.modelcontextprotocol.kotlin.sdk.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.LATEST_PROTOCOL_VERSION import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest import io.modelcontextprotocol.kotlin.sdk.ListRootsResult +import io.modelcontextprotocol.kotlin.sdk.LoggingLevel import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.Method import io.modelcontextprotocol.kotlin.sdk.Method.Defined @@ -27,6 +28,8 @@ import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic import kotlinx.coroutines.CompletableDeferred import kotlinx.serialization.json.JsonObject @@ -43,22 +46,6 @@ public open class ServerSession( @Suppress("ktlint:standard:backing-property-naming") private var _onClose: () -> Unit = {} - init { - // Core protocol handlers - setRequestHandler(Method.Defined.Initialize) { request, _ -> - handleInitialize(request) - } - setNotificationHandler(Method.Defined.NotificationsInitialized) { - _onInitialized() - CompletableDeferred(Unit) - } - } - - /** - * The capabilities supported by the server, related to the session. - */ - private val serverCapabilities = options.capabilities - /** * The client's reported capabilities after initialization. */ @@ -71,6 +58,37 @@ public open class ServerSession( public var clientVersion: Implementation? = null private set + /** + * The capabilities supported by the server, related to the session. + */ + private val serverCapabilities = options.capabilities + + /** + * The current logging level set by the client. + * When null, all messages are sent (no filtering). + */ + private val currentLoggingLevel: AtomicRef = atomic(null) + + init { + // Core protocol handlers + setRequestHandler(Defined.Initialize) { request, _ -> + handleInitialize(request) + } + setNotificationHandler(Defined.NotificationsInitialized) { + _onInitialized() + CompletableDeferred(Unit) + } + + // Logging level handler + if (options.capabilities.logging != null) { + setRequestHandler(Defined.LoggingSetLevel) { request, _ -> + currentLoggingLevel.value = request.level + logger.debug { "Logging level set to: ${request.level}" } + EmptyRequestResult() + } + } + } + /** * Registers a callback to be invoked when the server has completed initialization. */ @@ -160,12 +178,20 @@ public open class ServerSession( /** * Sends a logging message notification to the client. + * Messages are filtered based on the current logging level set by the client. + * If no logging level is set, all messages are sent. * * @param notification The logging message notification. */ public suspend fun sendLoggingMessage(notification: LoggingMessageNotification) { - logger.trace { "Sending logging message: ${notification.params.data}" } - notification(notification) + if (serverCapabilities.logging != null) { + if (isMessageAccepted(notification.params.level)) { + logger.trace { "Sending logging message: ${notification.params.data}" } + notification(notification) + } else { + logger.trace { "Filtering out logging message with level ${notification.params.level}" } + } + } } /** @@ -318,6 +344,7 @@ public open class ServerSession( Defined.LoggingSetLevel -> { if (serverCapabilities.logging == null) { + logger.error { "Server does not support logging (required for $method)" } throw IllegalStateException("Server does not support logging (required for $method)") } } @@ -381,4 +408,24 @@ public open class ServerSession( instructions = instructions, ) } + + /** + * Checks if a message with the given level should be ignored based on the current logging level. + * + * @param level The level of the message to check. + * @return true if the message should be ignored (filtered out), false otherwise. + */ + private fun isMessageIgnored(level: LoggingLevel): Boolean { + val current = currentLoggingLevel.value ?: return false // If no level is set, don't filter + + return level.ordinal < current.ordinal + } + + /** + * Checks if a message with the given level should be accepted based on the current logging level. + * + * @param level The level of the message to check. + * @return true if the message should be accepted (not filtered out), false otherwise. + */ + private fun isMessageAccepted(level: LoggingLevel): Boolean = !isMessageIgnored(level) } diff --git a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index b563ba05..ad2a9586 100644 --- a/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -886,6 +886,86 @@ class ClientTest { ) } + @Test + fun `should handle logging setLevel request`() = runTest { + val server = Server( + Implementation(name = "test server", version = "1.0"), + ServerOptions( + capabilities = ServerCapabilities( + logging = EmptyJsonObject, + ), + ), + ) + + val client = Client( + clientInfo = Implementation(name = "test client", version = "1.0"), + options = ClientOptions( + capabilities = ClientCapabilities(), + ), + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val receivedMessages = mutableListOf() + client.setNotificationHandler(Method.Defined.NotificationsMessage) { notification -> + receivedMessages.add(notification) + CompletableDeferred(Unit) + } + + val serverSessionResult = CompletableDeferred() + + listOf( + launch { + client.connect(clientTransport) + println("Client connected") + }, + launch { + serverSessionResult.complete(server.connect(serverTransport)) + println("Server connected") + }, + ).joinAll() + + val serverSession = serverSessionResult.await() + + // Set logging level to warning + val minLevel = LoggingLevel.warning + val result = client.setLoggingLevel(minLevel) + assertEquals(EmptyJsonObject, result._meta) + + // Send messages of different levels + val testMessages = listOf( + LoggingLevel.debug to "Debug - should be filtered", + LoggingLevel.info to "Info - should be filtered", + LoggingLevel.warning to "Warning - should pass", + LoggingLevel.error to "Error - should pass", + ) + + testMessages.forEach { (level, message) -> + serverSession.sendLoggingMessage( + LoggingMessageNotification( + params = LoggingMessageNotification.Params( + level = level, + data = buildJsonObject { put("message", message) }, + ), + ), + ) + } + + delay(100) + + // Only warning and error should be received + assertEquals(2, receivedMessages.size, "Should receive only 2 messages (warning and error)") + + // Verify all received messages have severity >= minLevel + receivedMessages.forEach { message -> + val messageSeverity = message.params.level.ordinal + assertTrue( + messageSeverity >= minLevel.ordinal, + "Received message with level ${message.params.level} should have severity >= $minLevel", + ) + } + } + @Test fun `should handle server elicitation`() = runTest { val client = Client(