Skip to content

Commit 5301174

Browse files
authored
Merge branch 'main' into sean/progres-notification
2 parents 6843ee4 + 63288ea commit 5301174

File tree

3 files changed

+67
-4
lines changed

3 files changed

+67
-4
lines changed

.idea/misc.xml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import kotlinx.serialization.json.JsonElement
1010
import kotlinx.serialization.json.JsonNull
1111
import kotlinx.serialization.json.JsonObject
1212
import kotlinx.serialization.json.JsonPrimitive
13+
import kotlin.coroutines.cancellation.CancellationException
1314

1415
/**
1516
* Options for configuring the MCP client.
@@ -100,7 +101,12 @@ public open class Client(
100101
notification(InitializedNotification())
101102
} catch (error: Throwable) {
102103
close()
104+
if (error !is CancellationException) {
105+
throw IllegalStateException("Error connecting to transport: ${error.message}")
106+
}
107+
103108
throw error
109+
104110
}
105111
}
106112

src/jvmTest/kotlin/client/ClientTest.kt

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult
66
import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject
77
import io.modelcontextprotocol.kotlin.sdk.Implementation
88
import InMemoryTransport
9+
import io.mockk.coEvery
10+
import io.mockk.spyk
911
import io.modelcontextprotocol.kotlin.sdk.InitializeRequest
1012
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
1113
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
@@ -49,13 +51,13 @@ import kotlin.test.fail
4951
class ClientTest {
5052
@Test
5153
fun `should initialize with matching protocol version`() = runTest {
52-
var initialied = false
54+
var initialised = false
5355
val clientTransport = object : AbstractTransport() {
5456
override suspend fun start() {}
5557

5658
override suspend fun send(message: JSONRPCMessage) {
5759
if (message !is JSONRPCRequest) return
58-
initialied = true
60+
initialised = true
5961
val result = InitializeResult(
6062
protocolVersion = LATEST_PROTOCOL_VERSION,
6163
capabilities = ServerCapabilities(),
@@ -90,7 +92,7 @@ class ClientTest {
9092
)
9193

9294
client.connect(clientTransport)
93-
assertTrue(initialied)
95+
assertTrue(initialised)
9496
}
9597

9698
@Test
@@ -189,6 +191,61 @@ class ClientTest {
189191
assertTrue(closed)
190192
}
191193

194+
@Test
195+
fun `should reject due to non cancellation exception`() = runTest {
196+
var closed = false
197+
val clientTransport = object : AbstractTransport() {
198+
override suspend fun start() {}
199+
200+
override suspend fun send(message: JSONRPCMessage) {
201+
if (message !is JSONRPCRequest) return
202+
check(message.method == Method.Defined.Initialize.value)
203+
204+
val result = InitializeResult(
205+
protocolVersion = LATEST_PROTOCOL_VERSION,
206+
capabilities = ServerCapabilities(),
207+
serverInfo = Implementation(
208+
name = "test",
209+
version = "1.0"
210+
)
211+
)
212+
213+
val response = JSONRPCResponse(
214+
id = message.id,
215+
result = result
216+
)
217+
218+
_onMessage.invoke(response)
219+
}
220+
221+
override suspend fun close() {
222+
closed = true
223+
}
224+
}
225+
226+
val mockClient = spyk(
227+
Client(
228+
clientInfo = Implementation(
229+
name = "test client",
230+
version = "1.0"
231+
),
232+
options = ClientOptions()
233+
)
234+
)
235+
236+
coEvery{
237+
mockClient.request<InitializeResult>(any())
238+
} throws IllegalStateException("Test error")
239+
240+
val exception = assertFailsWith<IllegalStateException> {
241+
mockClient.connect(clientTransport)
242+
}
243+
244+
assertEquals("Error connecting to transport: Test error", exception.message)
245+
246+
assertTrue(closed)
247+
}
248+
192249
@Test
193250
fun `should respect server capabilities`() = runTest {
194251
val serverOptions = ServerOptions(

0 commit comments

Comments
 (0)