diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt new file mode 100644 index 00000000..77ba553e --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/AbstractStreamableHttpClientTest.kt @@ -0,0 +1,35 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.ktor.client.HttpClient +import io.ktor.client.engine.apache5.Apache5 +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import io.ktor.client.plugins.sse.SSE +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.TestInstance + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +internal abstract class AbstractStreamableHttpClientTest { + + // start mokksy on random port + protected val mockMcp: MockMcp = MockMcp(verbose = true) + + @AfterEach + fun afterEach() { + mockMcp.checkForUnmatchedRequests() + } + + protected suspend fun connect(client: Client) { + client.connect( + StreamableHttpClientTransport( + url = mockMcp.url, + client = HttpClient(Apache5) { + install(SSE) + install(Logging) { + level = LogLevel.ALL + } + }, + ), + ) + } +} diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt index d6fcaed5..eef77f16 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import dev.mokksy.mokksy.BuildingStep import dev.mokksy.mokksy.Mokksy import dev.mokksy.mokksy.StubConfiguration import io.ktor.http.ContentType @@ -7,7 +8,18 @@ import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.sse.ServerSentEvent import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.RequestId import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.putJsonObject + +const val MCP_SESSION_ID_HEADER = "Mcp-Session-Id" /** * High-level helper for simulating an MCP server over Streaming HTTP transport with Server-Sent Events (SSE), @@ -26,51 +38,188 @@ internal class MockMcp(verbose: Boolean = false) { mokksy.checkForUnmatchedRequests() } - val url = mokksy.baseUrl() + "/mcp" + val url = "${mokksy.baseUrl()}/mcp" @Suppress("LongParameterList") + fun onInitialize( + clientName: String? = null, + sessionId: String, + protocolVersion: String = "2025-03-26", + serverName: String = "Mock MCP Server", + serverVersion: String = "1.0.0", + capabilities: JsonObject = buildJsonObject { + putJsonObject("tools") { + put("listChanged", JsonPrimitive(false)) + } + }, + ) { + val predicates = if (clientName != null) { + arrayOf<(JSONRPCRequest?) -> Boolean>({ + it?.params?.jsonObject + ?.get("clientInfo")?.jsonObject + ?.get("name")?.jsonPrimitive + ?.contentOrNull == clientName + }) + } else { + emptyArray() + } + + handleWithResult( + jsonRpcMethod = "initialize", + sessionId = sessionId, + bodyPredicates = predicates, + // language=json + result = """ + { + "capabilities": $capabilities, + "protocolVersion": "$protocolVersion", + "serverInfo": { + "name": "$serverName", + "version": "$serverVersion" + }, + "_meta": { + "foo": "bar" + } + } + """.trimIndent(), + ) + } + fun onJSONRPCRequest( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + vararg bodyPredicates: (JSONRPCRequest) -> Boolean, + ): BuildingStep = mokksy.method( + configuration = StubConfiguration(removeAfterMatch = true), + httpMethod = httpMethod, + requestType = JSONRPCRequest::class, + ) { + path("/mcp") + expectedSessionId?.let { + containsHeader(MCP_SESSION_ID_HEADER, it) + } + bodyMatchesPredicate( + description = "JSON-RPC version is '2.0'", + predicate = + { + it!!.jsonrpc == "2.0" + }, + ) + bodyMatchesPredicate( + description = "JSON-RPC Method should be '$jsonRpcMethod'", + predicate = + { + it!!.method == jsonRpcMethod + }, + ) + bodyPredicates.forEach { predicate -> + bodyMatchesPredicate(predicate = { predicate.invoke(it!!) }) + } + } + + @Suppress("LongParameterList") + fun handleWithResult( httpMethod: HttpMethod = HttpMethod.Post, jsonRpcMethod: String, expectedSessionId: String? = null, sessionId: String, contentType: ContentType = ContentType.Application.Json, statusCode: HttpStatusCode = HttpStatusCode.OK, - bodyBuilder: () -> String, + vararg bodyPredicates: (JSONRPCRequest) -> Boolean, + result: () -> JsonObject, ) { - mokksy.method( - configuration = StubConfiguration(removeAfterMatch = true), + onJSONRPCRequest( httpMethod = httpMethod, - requestType = JSONRPCRequest::class, - ) { - path("/mcp") - expectedSessionId?.let { - containsHeader("Mcp-Session-Id", it) + jsonRpcMethod = jsonRpcMethod, + expectedSessionId = expectedSessionId, + bodyPredicates = bodyPredicates, + ) respondsWith { + val requestId = when (request.body.id) { + is RequestId.NumberId -> (request.body.id as RequestId.NumberId).value.toString() + is RequestId.StringId -> "\"${(request.body.id as RequestId.StringId).value}\"" } - bodyMatchesPredicates( - { - it!!.method == jsonRpcMethod - }, - { - it!!.jsonrpc == "2.0" - }, - ) - } respondsWith { + val resultObject = result!!.invoke() + // language=json + body = """ + { + "jsonrpc": "2.0", + "id": $requestId, + "result": $resultObject + } + """.trimIndent() + this.contentType = contentType + headers += MCP_SESSION_ID_HEADER to sessionId + httpStatus = statusCode + } + } + + @Suppress("LongParameterList") + fun handleWithResult( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + sessionId: String, + contentType: ContentType = ContentType.Application.Json, + statusCode: HttpStatusCode = HttpStatusCode.OK, + vararg bodyPredicates: (JSONRPCRequest) -> Boolean, + result: String, + ) { + handleWithResult( + httpMethod = httpMethod, + jsonRpcMethod = jsonRpcMethod, + expectedSessionId = expectedSessionId, + sessionId = sessionId, + contentType = contentType, + statusCode = statusCode, + bodyPredicates = bodyPredicates, + result = { + Json.parseToJsonElement(result).jsonObject + }, + ) + } + + @Suppress("LongParameterList") + fun handleJSONRPCRequest( + httpMethod: HttpMethod = HttpMethod.Post, + jsonRpcMethod: String, + expectedSessionId: String? = null, + sessionId: String, + contentType: ContentType = ContentType.Application.Json, + statusCode: HttpStatusCode = HttpStatusCode.OK, + vararg bodyPredicates: (JSONRPCRequest?) -> Boolean, + bodyBuilder: () -> String = { "" }, + ) { + onJSONRPCRequest( + httpMethod = httpMethod, + jsonRpcMethod = jsonRpcMethod, + expectedSessionId = expectedSessionId, + bodyPredicates = bodyPredicates, + ) respondsWith { body = bodyBuilder.invoke() this.contentType = contentType - headers += "Mcp-Session-Id" to sessionId + headers += MCP_SESSION_ID_HEADER to sessionId httpStatus = statusCode } } - fun onSubscribeWithGet(sessionId: String, block: () -> Flow) { - mokksy.get(name = "MCP GETs", requestType = Any::class) { - path("/mcp") - containsHeader("Mcp-Session-Id", sessionId) - containsHeader("Accept", "application/json,text/event-stream") - containsHeader("Cache-Control", "no-store") - } respondsWithSseStream { - headers += "Mcp-Session-Id" to sessionId + fun onSubscribe(httpMethod: HttpMethod = HttpMethod.Post, sessionId: String): BuildingStep = mokksy.method( + httpMethod = httpMethod, + name = "MCP GETs", + requestType = Any::class, + ) { + path("/mcp") + containsHeader(MCP_SESSION_ID_HEADER, sessionId) + containsHeader("Accept", "application/json,text/event-stream") + containsHeader("Cache-Control", "no-store") + } + + fun handleSubscribeWithGet(sessionId: String, block: () -> Flow) { + onSubscribe( + httpMethod = HttpMethod.Get, + sessionId = sessionId, + ) respondsWithSseStream { + headers += MCP_SESSION_ID_HEADER to sessionId this.flow = block.invoke() } } @@ -81,7 +230,7 @@ internal class MockMcp(verbose: Boolean = false) { requestType = JSONRPCRequest::class, ) { path("/mcp") - containsHeader("Mcp-Session-Id", sessionId) + containsHeader(MCP_SESSION_ID_HEADER, sessionId) } respondsWith { body = null } diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt index 6e730c30..7cedfe75 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt @@ -1,11 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client import io.kotest.matchers.collections.shouldContain -import io.ktor.client.HttpClient -import io.ktor.client.engine.apache5.Apache5 -import io.ktor.client.plugins.logging.LogLevel -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.sse.SSE import io.ktor.http.HttpStatusCode import io.ktor.sse.ServerSentEvent import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities @@ -19,7 +14,6 @@ import kotlinx.serialization.json.put import kotlinx.serialization.json.putJsonObject import org.junit.jupiter.api.TestInstance import java.util.UUID -import kotlin.test.AfterTest import kotlin.test.Test import kotlin.time.Duration.Companion.milliseconds @@ -30,21 +24,16 @@ import kotlin.time.Duration.Companion.milliseconds * @author Konstantin Pavlov */ @TestInstance(TestInstance.Lifecycle.PER_CLASS) -class StreamableHttpClientTest { - - // start mokksy on random port - private val mockMcp: MockMcp = MockMcp(verbose = true) - - @AfterTest - fun afterEach() { - mockMcp.checkForUnmatchedRequests() - } +@Suppress("LongMethod") +internal class StreamableHttpClientTest : AbstractStreamableHttpClientTest() { @Test - @Suppress("LongMethod") - fun `test streamableHttpClient`(): Unit = runBlocking { + fun `test streamableHttpClient`() = runBlocking { val client = Client( - clientInfo = Implementation(name = "sample-client", version = "1.0.0"), + clientInfo = Implementation( + name = "client1", + version = "1.0.0", + ), options = ClientOptions( capabilities = ClientCapabilities(), ), @@ -52,44 +41,19 @@ class StreamableHttpClientTest { val sessionId = UUID.randomUUID().toString() - mockMcp.onJSONRPCRequest( - jsonRpcMethod = "initialize", + mockMcp.onInitialize( + clientName = "client1", sessionId = sessionId, - ) { - // language=json - """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "capabilities": { - "tools": { - "listChanged": false - } - }, - "protocolVersion": "2025-03-26", - "serverInfo": { - "name": "Mock MCP Server", - "version": "1.0.0" - }, - "_meta": { - "foo": "bar" - } - } - } - """.trimIndent() - } + ) - mockMcp.onJSONRPCRequest( + mockMcp.handleJSONRPCRequest( jsonRpcMethod = "notifications/initialized", expectedSessionId = sessionId, sessionId = sessionId, statusCode = HttpStatusCode.Accepted, - ) { - "" - } + ) - mockMcp.onSubscribeWithGet(sessionId) { + mockMcp.handleSubscribeWithGet(sessionId) { flow { delay(500.milliseconds) emit( @@ -112,30 +76,14 @@ class StreamableHttpClientTest { } } - client.connect( - StreamableHttpClientTransport( - url = mockMcp.url, - client = HttpClient(Apache5) { - install(SSE) - install(Logging) { - level = LogLevel.ALL - } - }, - ), - ) - // TODO: how to get notifications via Client API? - mockMcp.onJSONRPCRequest( + mockMcp.handleWithResult( jsonRpcMethod = "tools/list", sessionId = sessionId, - ) { // language=json - """ - { - "jsonrpc": "2.0", - "id": 3, - "result": { + result = """ + { "tools": [ { "name": "get_weather", @@ -164,9 +112,10 @@ class StreamableHttpClientTest { } ] } - } - """.trimIndent() - } + """.trimIndent(), + ) + + connect(client) val listToolsResult = client.listTools() diff --git a/kotlin-sdk-client/src/jvmTest/resources/simplelogger.properties b/kotlin-sdk-client/src/jvmTest/resources/simplelogger.properties index 506c4365..2336c133 100644 --- a/kotlin-sdk-client/src/jvmTest/resources/simplelogger.properties +++ b/kotlin-sdk-client/src/jvmTest/resources/simplelogger.properties @@ -1,14 +1,9 @@ # Level of logging for the ROOT logger: ERROR, WARN, INFO, DEBUG, TRACE (default is INFO) org.slf4j.simpleLogger.defaultLogLevel=INFO - - org.slf4j.simpleLogger.showThreadName=true org.slf4j.simpleLogger.showDateTime=false -# Whether to enable stack traces for exceptions (true/false, default is true) -org.slf4j.simpleLogger.showShortLogName=false - # Log level for specific packages or classes org.slf4j.simpleLogger.log.io.ktor.server=DEBUG -org.slf4j.simpleLogger.log.io.modelcontextprotocol=TRACE +org.slf4j.simpleLogger.log.io.modelcontextprotocol=DEBUG org.slf4j.simpleLogger.log.dev.mokksy=DEBUG