Skip to content

Commit cd26c91

Browse files
committed
add simple tests for StreamableClient and fix send
1 parent 5e4851d commit cd26c91

File tree

3 files changed

+240
-35
lines changed

3 files changed

+240
-35
lines changed

build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ kotlin {
268268

269269
jvmTest {
270270
dependencies {
271+
implementation(libs.ktor.client.mock)
271272
implementation(libs.mockk)
272273
implementation(libs.slf4j.simple)
273274
}

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -125,43 +125,39 @@ public class StreamableHttpClientTransport(
125125

126126
response.headers[MCP_SESSION_ID_HEADER]?.let { sessionId = it }
127127

128-
if (message is JSONRPCNotification || message is JSONRPCResponse) {
129-
if (response.status != HttpStatusCode.Accepted) {
130-
val text = response.bodyAsText()
131-
val err = StreamableHttpError(response.status.value, text)
132-
logger.error(err) { "Client POST request failed." }
133-
_onError(err)
134-
throw err
128+
if (response.status == HttpStatusCode.Accepted) {
129+
if (message is JSONRPCNotification && message.method == "notifications/initialized") {
130+
startSseSession(onResumptionToken = onResumptionToken)
135131
}
136132
return
137133
}
138134

139-
when {
140-
!response.status.isSuccess() -> {
141-
val text = response.bodyAsText()
142-
val err = StreamableHttpError(response.status.value, text)
143-
logger.error(err) { "Client POST request failed." }
144-
_onError(err)
145-
throw err
146-
}
147-
148-
response.contentType()?.match(ContentType.Application.Json) ?: false ->
149-
response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json ->
150-
runCatching { McpJson.decodeFromString<JSONRPCMessage>(json) }
151-
.onSuccess { _onMessage(it) }
152-
.onFailure(_onError)
153-
}
154-
155-
response.contentType()?.match(ContentType.Text.EventStream) ?: false ->
156-
handleInlineSse(
157-
response, onResumptionToken = onResumptionToken,
158-
replayMessageId = if (message is JSONRPCRequest) message.id else null
159-
)
135+
if (!response.status.isSuccess()) {
136+
val error = StreamableHttpError(response.status.value, response.bodyAsText())
137+
_onError(error)
138+
throw error
160139
}
161140

162-
// If client just sent InitializedNotification, open SSE stream
163-
if (message is JSONRPCNotification && message.method == "notifications/initialized" && sseSession == null) {
164-
startSseSession()
141+
when (response.contentType()?.withoutParameters()) {
142+
ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json ->
143+
runCatching { McpJson.decodeFromString<JSONRPCMessage>(json) }
144+
.onSuccess { _onMessage(it) }
145+
.onFailure(_onError)
146+
}
147+
148+
ContentType.Text.EventStream -> handleInlineSse(
149+
response, onResumptionToken = onResumptionToken,
150+
replayMessageId = if (message is JSONRPCRequest) message.id else null
151+
)
152+
else -> {
153+
val body = response.bodyAsText()
154+
if (response.contentType() == null && body.isBlank()) return
155+
156+
val ct = response.contentType()?.toString() ?: "<none>"
157+
val error = StreamableHttpError(-1, "Unexpected content type: $$ct")
158+
_onError(error)
159+
throw error
160+
}
165161
}
166162
}
167163

@@ -297,7 +293,6 @@ public class StreamableHttpClientTransport(
297293
) {
298294
logger.trace { "Handling inline SSE from POST response" }
299295
val channel = response.bodyAsChannel()
300-
val reader = channel
301296

302297
val sb = StringBuilder()
303298
var id: String? = null
@@ -325,16 +320,16 @@ public class StreamableHttpClientTransport(
325320
sb.clear()
326321
}
327322

328-
while (!reader.isClosedForRead) {
329-
val line = reader.readUTF8Line() ?: break
323+
while (!channel.isClosedForRead) {
324+
val line = channel.readUTF8Line() ?: break
330325
if (line.isEmpty()) {
331326
dispatch(sb.toString())
332327
continue
333328
}
334329
when {
335330
line.startsWith("id:") -> id = line.substringAfter("id:").trim()
336331
line.startsWith("event:") -> eventName = line.substringAfter("event:").trim()
337-
line.startsWith("data:") -> sb.appendLine(line.substringAfter("data:").trim())
332+
line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim())
338333
}
339334
}
340335
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
package client
2+
3+
import io.ktor.client.HttpClient
4+
import io.ktor.client.engine.mock.MockEngine
5+
import io.ktor.client.engine.mock.respond
6+
import io.ktor.client.plugins.sse.SSE
7+
import io.ktor.http.ContentType
8+
import io.ktor.http.HttpMethod
9+
import io.ktor.http.HttpStatusCode
10+
import io.ktor.http.content.TextContent
11+
import io.ktor.utils.io.ByteReadChannel
12+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
13+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification
14+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest
15+
import io.modelcontextprotocol.kotlin.sdk.RequestId
16+
import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport
17+
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
18+
import kotlinx.coroutines.test.runTest
19+
import kotlinx.serialization.json.buildJsonObject
20+
import org.junit.jupiter.api.assertDoesNotThrow
21+
import kotlin.test.AfterTest
22+
import kotlin.test.BeforeTest
23+
import kotlin.test.Test
24+
import kotlin.test.assertEquals
25+
import kotlin.test.assertNull
26+
import kotlin.time.Duration.Companion.seconds
27+
28+
class StreamableHttpClientTransportTest {
29+
private lateinit var mockEngine: MockEngine
30+
private lateinit var httpClient: HttpClient
31+
private lateinit var transport: StreamableHttpClientTransport
32+
33+
@BeforeTest
34+
fun setup() {
35+
mockEngine = MockEngine {
36+
respond(
37+
ByteReadChannel(""),
38+
status = HttpStatusCode.OK,
39+
)
40+
}
41+
42+
httpClient = HttpClient(mockEngine) {
43+
install(SSE) {
44+
reconnectionTime = 1.seconds
45+
}
46+
}
47+
48+
transport = StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp")
49+
}
50+
51+
@AfterTest
52+
fun teardown() {
53+
httpClient.close()
54+
}
55+
56+
@Test
57+
fun testSendJsonRpcMessage() = runTest {
58+
val message = JSONRPCRequest(
59+
id = RequestId.StringId("test-id"),
60+
method = "test",
61+
params = buildJsonObject { }
62+
)
63+
64+
mockEngine.config.addHandler { request ->
65+
assertEquals(HttpMethod.Post, request.method)
66+
assertEquals("http://localhost:8080/mcp", request.url.toString())
67+
assertEquals(ContentType.Application.Json, request.body.contentType)
68+
69+
val body = (request.body as TextContent).text
70+
val decodedMessage = McpJson.decodeFromString<JSONRPCMessage>(body)
71+
assertEquals(message, decodedMessage)
72+
73+
respond(
74+
content = "",
75+
status = HttpStatusCode.Accepted
76+
)
77+
}
78+
79+
transport.start()
80+
transport.send(message)
81+
}
82+
83+
// @Test
84+
// fun testStoreSessionId() = runTest {
85+
// val initMessage = JSONRPCRequest(
86+
// id = RequestId.StringId("test-id"),
87+
// method = "initialize",
88+
// params = buildJsonObject {
89+
// put("clientInfo", buildJsonObject {
90+
// put("name", JsonPrimitive("test-client"))
91+
// put("version", JsonPrimitive("1.0"))
92+
// })
93+
// put("protocolVersion", JsonPrimitive("2025-06-18"))
94+
// }
95+
// )
96+
//
97+
// mockEngine.config.addHandler { request ->
98+
// respond(
99+
// content = "", status = HttpStatusCode.OK,
100+
// headers = headersOf("mcp-session-id", "test-session-id")
101+
// )
102+
// }
103+
//
104+
// transport.start()
105+
// transport.send(initMessage)
106+
//
107+
// assertEquals("test-session-id", transport.sessionId)
108+
//
109+
// // Send another message and verify session ID is included
110+
// mockEngine.config.addHandler { request ->
111+
// assertEquals("test-session-id", request.headers["mcp-session-id"])
112+
// respond(
113+
// content = "",
114+
// status = HttpStatusCode.Accepted
115+
// )
116+
// }
117+
//
118+
// transport.send(JSONRPCNotification(method = "test"))
119+
// }
120+
121+
@Test
122+
fun testTerminateSession() = runTest {
123+
// transport.sessionId = "test-session-id"
124+
125+
mockEngine.config.addHandler { request ->
126+
assertEquals(HttpMethod.Delete, request.method)
127+
assertEquals("test-session-id", request.headers["mcp-session-id"])
128+
respond(
129+
content = "",
130+
status = HttpStatusCode.OK
131+
)
132+
}
133+
134+
transport.start()
135+
transport.terminateSession()
136+
137+
assertNull(transport.sessionId)
138+
}
139+
140+
@Test
141+
fun testTerminateSessionHandle405() = runTest {
142+
// transport.sessionId = "test-session-id"
143+
144+
mockEngine.config.addHandler { request ->
145+
assertEquals(HttpMethod.Delete, request.method)
146+
respond(
147+
content = "",
148+
status = HttpStatusCode.MethodNotAllowed
149+
)
150+
}
151+
152+
transport.start()
153+
// Should not throw for 405
154+
assertDoesNotThrow {
155+
transport.terminateSession()
156+
}
157+
158+
// Session ID should still be cleared
159+
assertNull(transport.sessionId)
160+
}
161+
162+
@Test
163+
fun testProtocolVersionHeader() = runTest {
164+
transport.protocolVersion = "2025-06-18"
165+
166+
mockEngine.config.addHandler { request ->
167+
assertEquals("2025-06-18", request.headers["mcp-protocol-version"])
168+
respond(
169+
content = "",
170+
status = HttpStatusCode.Accepted
171+
)
172+
}
173+
174+
transport.start()
175+
transport.send(JSONRPCNotification(method = "test"))
176+
}
177+
178+
@Test
179+
fun testHandle405ForSSE() = runTest {
180+
mockEngine.config.addHandler { request ->
181+
if (request.method == HttpMethod.Get) {
182+
respond(
183+
content = "",
184+
status = HttpStatusCode.MethodNotAllowed
185+
)
186+
} else {
187+
respond(
188+
content = "",
189+
status = HttpStatusCode.Accepted
190+
)
191+
}
192+
}
193+
194+
transport.start()
195+
196+
// Start SSE session - should handle 405 gracefully
197+
val initNotification = JSONRPCNotification(
198+
method = "notifications/initialized",
199+
)
200+
201+
// Should not throw
202+
assertDoesNotThrow {
203+
transport.send(initNotification)
204+
}
205+
206+
// Transport should still work after 405
207+
transport.send(JSONRPCNotification(method = "test"))
208+
}
209+
}

0 commit comments

Comments
 (0)