Skip to content

Commit 9ce7ee7

Browse files
committed
Refactor and enhance prompt handling logic, update Kotest version, and modularize AbstractTransport
- Extracted `AbstractTransport` to a dedicated file for modularity. - Improved prompt validation with `require` statement and updated related tests. - Downgraded Kotest dependency to 5.9.1 to ensure compatibility with JVM 1.8. - Fixed typos and improved test cases for better clarity. - Minor Kotlin fixes and suppressions to improve code quality.
1 parent 07f661a commit 9ce7ee7

File tree

8 files changed

+203
-79
lines changed

8 files changed

+203
-79
lines changed

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ kotlinx-io = "0.8.0"
1818
ktor = "3.2.3"
1919
logging = "7.0.13"
2020
slf4j = "2.0.17"
21-
kotest = "6.0.4"
21+
kotest = "5.9.1" # for JVM 1.8
2222
awaitility = "4.3.0"
2323
mokksy = "0.6.1"
2424

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package io.modelcontextprotocol.kotlin.sdk.shared
2+
3+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
4+
import kotlinx.coroutines.CompletableDeferred
5+
6+
/**
7+
* Implements [onClose], [onError] and [onMessage] functions of [Transport] providing
8+
* corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation.
9+
*/
10+
@Suppress("PropertyName")
11+
public abstract class AbstractTransport : Transport {
12+
protected var _onClose: (() -> Unit) = {}
13+
private set
14+
protected var _onError: ((Throwable) -> Unit) = {}
15+
private set
16+
17+
// to not skip messages
18+
private val _onMessageInitialized = CompletableDeferred<Unit>()
19+
protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {
20+
_onMessageInitialized.await()
21+
_onMessage.invoke(it)
22+
}
23+
private set
24+
25+
override fun onClose(block: () -> Unit) {
26+
val old = _onClose
27+
_onClose = {
28+
old()
29+
block()
30+
}
31+
}
32+
33+
override fun onError(block: (Throwable) -> Unit) {
34+
val old = _onError
35+
_onError = { e ->
36+
old(e)
37+
block(e)
38+
}
39+
}
40+
41+
override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
42+
val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) {
43+
true -> _onMessage
44+
false -> { _ -> }
45+
}
46+
47+
_onMessage = { message ->
48+
old(message)
49+
block(message)
50+
}
51+
52+
_onMessageInitialized.complete(Unit)
53+
}
54+
}

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) }
116116
* Implements MCP protocol framing on top of a pluggable transport, including
117117
* features like request/response linking, notifications, and progress.
118118
*/
119+
@Suppress("TooManyFunctions")
119120
public abstract class Protocol(@PublishedApi internal val options: ProtocolOptions?) {
120121
public var transport: Transport? = null
121122
private set
@@ -190,7 +191,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
190191
/**
191192
* Attaches to the given transport, starts it, and starts listening for messages.
192193
*
193-
* The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward.
194+
* The Protocol object assumes ownership of the Transport,
195+
* replacing any callbacks that have already been set,
196+
* and expects that it is the only user of the Transport instance going forward.
194197
*/
195198
public open suspend fun connect(transport: Transport) {
196199
this.transport = transport
@@ -237,6 +240,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
237240
logger.trace { "No handler found for notification: ${notification.method}" }
238241
return
239242
}
243+
@Suppress("TooGenericExceptionCaught")
240244
try {
241245
handler(notification)
242246
} catch (cause: Throwable) {
@@ -252,6 +256,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
252256

253257
if (handler === null) {
254258
logger.trace { "No handler found for request: ${request.method}" }
259+
@Suppress("TooGenericExceptionCaught")
255260
try {
256261
transport?.send(
257262
JSONRPCResponse(
@@ -269,6 +274,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
269274
return
270275
}
271276

277+
@Suppress("TooGenericExceptionCaught")
272278
try {
273279
val result = handler(request, RequestHandlerExtra())
274280
logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
@@ -303,7 +309,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
303309

304310
private fun onProgress(notification: ProgressNotification) {
305311
logger.trace {
306-
"Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}"
312+
"Received progress notification: token=${notification.params.progressToken}, " +
313+
"progress=${notification.params.progress}/${notification.params.total}"
307314
}
308315
val progress = notification.params.progress
309316
val total = notification.params.total
@@ -392,7 +399,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
392399
public suspend fun <T : RequestResult> request(request: Request, options: RequestOptions? = null): T {
393400
logger.trace { "Sending request: ${request.method}" }
394401
val result = CompletableDeferred<T>()
395-
val transport = transport ?: throw Error("Not connected")
402+
val transport = checkNotNull(transport) {
403+
"No transport connected"
404+
}
396405

397406
if (this@Protocol.options?.enforceStrictCapabilities == true) {
398407
assertCapabilityForMethod(request.method)
@@ -420,6 +429,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
420429
return@put
421430
}
422431

432+
@Suppress("TooGenericExceptionCaught")
423433
try {
424434
@Suppress("UNCHECKED_CAST")
425435
result.complete(response!!.result as T)
Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

33
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
4-
import kotlinx.coroutines.CompletableDeferred
54

65
/**
76
* Describes the minimal contract for MCP transport that a client or server can communicate over.
@@ -47,53 +46,3 @@ public interface Transport {
4746
*/
4847
public fun onMessage(block: suspend (JSONRPCMessage) -> Unit)
4948
}
50-
51-
/**
52-
* Implements [onClose], [onError] and [onMessage] functions of [Transport] providing
53-
* corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation.
54-
*/
55-
@Suppress("PropertyName")
56-
public abstract class AbstractTransport : Transport {
57-
protected var _onClose: (() -> Unit) = {}
58-
private set
59-
protected var _onError: ((Throwable) -> Unit) = {}
60-
private set
61-
62-
// to not skip messages
63-
private val _onMessageInitialized = CompletableDeferred<Unit>()
64-
protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {
65-
_onMessageInitialized.await()
66-
_onMessage.invoke(it)
67-
}
68-
private set
69-
70-
override fun onClose(block: () -> Unit) {
71-
val old = _onClose
72-
_onClose = {
73-
old()
74-
block()
75-
}
76-
}
77-
78-
override fun onError(block: (Throwable) -> Unit) {
79-
val old = _onError
80-
_onError = { e ->
81-
old(e)
82-
block(e)
83-
}
84-
}
85-
86-
override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
87-
val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) {
88-
true -> _onMessage
89-
false -> { _ -> }
90-
}
91-
92-
_onMessage = { message ->
93-
old(message)
94-
block(message)
95-
}
96-
97-
_onMessageInitialized.complete(Unit)
98-
}
99-
}

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ public open class Server(
257257
title: String? = null,
258258
outputSchema: Tool.Output? = null,
259259
toolAnnotations: ToolAnnotations? = null,
260-
@Suppress("LocalVariableName") _meta: JsonObject? = null,
260+
@Suppress("LocalVariableName", "FunctionParameterNaming") _meta: JsonObject? = null,
261261
handler: suspend (CallToolRequest) -> CallToolResult,
262262
) {
263263
val tool = Tool(

kotlin-sdk-test/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ kotlin {
1414
implementation(dependencies.platform(libs.ktor.bom))
1515
implementation(project(":kotlin-sdk"))
1616
implementation(kotlin("test"))
17+
implementation(libs.kotest.assertions.core)
1718
implementation(libs.kotest.assertions.json)
1819
implementation(libs.kotlin.logging)
1920
implementation(libs.kotlinx.coroutines.test)

kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin
22

3+
import io.kotest.matchers.throwable.shouldHaveMessage
34
import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest
45
import io.modelcontextprotocol.kotlin.sdk.GetPromptResult
56
import io.modelcontextprotocol.kotlin.sdk.PromptArgument
@@ -132,7 +133,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
132133
)
133134
}
134135

135-
// complext prompt
136+
// complex prompt
136137
server.addPrompt(
137138
name = complexPromptName,
138139
description = complexPromptDescription,
@@ -152,8 +153,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
152153
// validate required arguments
153154
val requiredArgs = listOf("arg1", "arg2", "arg3")
154155
for (argName in requiredArgs) {
155-
if (request.arguments?.get(argName) == null) {
156-
throw IllegalArgumentException("Missing required argument: $argName")
156+
require(request.arguments?.get(argName) != null) {
157+
"Missing required argument: $argName"
157158
}
158159
}
159160

@@ -665,9 +666,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
665666
}
666667
}
667668

668-
val msg = exception.message ?: ""
669-
val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})"
670-
671-
assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt")
669+
exception shouldHaveMessage
670+
"JSONRPCError(code=InternalError, message=Prompt not found: $nonExistentPromptName, data={})"
672671
}
673672
}

0 commit comments

Comments
 (0)