Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ kotlinx-io = "0.8.0"
ktor = "3.2.3"
logging = "7.0.13"
slf4j = "2.0.17"
kotest = "6.0.4"
kotest = "5.9.1" # for JVM 1.8
awaitility = "4.3.0"
mokksy = "0.6.1"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package io.modelcontextprotocol.kotlin.sdk.shared

import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import kotlinx.coroutines.CompletableDeferred

/**
* Implements [onClose], [onError] and [onMessage] functions of [Transport] providing
* corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation.
*/
@Suppress("PropertyName")
public abstract class AbstractTransport : Transport {
protected var _onClose: (() -> Unit) = {}
private set
protected var _onError: ((Throwable) -> Unit) = {}
private set

// to not skip messages
private val _onMessageInitialized = CompletableDeferred<Unit>()
protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {
_onMessageInitialized.await()
_onMessage.invoke(it)
}
private set

override fun onClose(block: () -> Unit) {
val old = _onClose
_onClose = {
old()
block()
}
}

override fun onError(block: (Throwable) -> Unit) {
val old = _onError
_onError = { e ->
old(e)
block(e)
}
}

override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) {
true -> _onMessage
false -> { _ -> }
}

_onMessage = { message ->
old(message)
block(message)
}

_onMessageInitialized.complete(Unit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ internal val COMPLETED = CompletableDeferred(Unit).also { it.complete(Unit) }
* Implements MCP protocol framing on top of a pluggable transport, including
* features like request/response linking, notifications, and progress.
*/
@Suppress("TooManyFunctions")
public abstract class Protocol(@PublishedApi internal val options: ProtocolOptions?) {
public var transport: Transport? = null
private set
Expand Down Expand Up @@ -190,7 +191,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
/**
* Attaches to the given transport, starts it, and starts listening for messages.
*
* The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward.
* The Protocol object assumes ownership of the Transport,
* replacing any callbacks that have already been set,
* and expects that it is the only user of the Transport instance going forward.
*/
public open suspend fun connect(transport: Transport) {
this.transport = transport
Expand Down Expand Up @@ -237,6 +240,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
logger.trace { "No handler found for notification: ${notification.method}" }
return
}
@Suppress("TooGenericExceptionCaught")
try {
handler(notification)
} catch (cause: Throwable) {
Expand All @@ -252,6 +256,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio

if (handler === null) {
logger.trace { "No handler found for request: ${request.method}" }
@Suppress("TooGenericExceptionCaught")
try {
transport?.send(
JSONRPCResponse(
Expand All @@ -269,6 +274,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
return
}

@Suppress("TooGenericExceptionCaught")
try {
val result = handler(request, RequestHandlerExtra())
logger.trace { "Request handled successfully: ${request.method} (id: ${request.id})" }
Expand Down Expand Up @@ -303,7 +309,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio

private fun onProgress(notification: ProgressNotification) {
logger.trace {
"Received progress notification: token=${notification.params.progressToken}, progress=${notification.params.progress}/${notification.params.total}"
"Received progress notification: token=${notification.params.progressToken}, " +
"progress=${notification.params.progress}/${notification.params.total}"
}
val progress = notification.params.progress
val total = notification.params.total
Expand Down Expand Up @@ -392,7 +399,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
public suspend fun <T : RequestResult> request(request: Request, options: RequestOptions? = null): T {
logger.trace { "Sending request: ${request.method}" }
val result = CompletableDeferred<T>()
val transport = transport ?: throw Error("Not connected")
val transport = checkNotNull(transport) {
"No transport connected"
}

if ([email protected]?.enforceStrictCapabilities == true) {
assertCapabilityForMethod(request.method)
Expand Down Expand Up @@ -420,6 +429,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
return@put
}

@Suppress("TooGenericExceptionCaught")
try {
@Suppress("UNCHECKED_CAST")
result.complete(response!!.result as T)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.shared

import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import kotlinx.coroutines.CompletableDeferred

/**
* Describes the minimal contract for MCP transport that a client or server can communicate over.
Expand Down Expand Up @@ -47,53 +46,3 @@ public interface Transport {
*/
public fun onMessage(block: suspend (JSONRPCMessage) -> Unit)
}

/**
* Implements [onClose], [onError] and [onMessage] functions of [Transport] providing
* corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation.
*/
@Suppress("PropertyName")
public abstract class AbstractTransport : Transport {
protected var _onClose: (() -> Unit) = {}
private set
protected var _onError: ((Throwable) -> Unit) = {}
private set

// to not skip messages
private val _onMessageInitialized = CompletableDeferred<Unit>()
protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = {
_onMessageInitialized.await()
_onMessage.invoke(it)
}
private set

override fun onClose(block: () -> Unit) {
val old = _onClose
_onClose = {
old()
block()
}
}

override fun onError(block: (Throwable) -> Unit) {
val old = _onError
_onError = { e ->
old(e)
block(e)
}
}

override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) {
true -> _onMessage
false -> { _ -> }
}

_onMessage = { message ->
old(message)
block(message)
}

_onMessageInitialized.complete(Unit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public open class Server(
title: String? = null,
outputSchema: Tool.Output? = null,
toolAnnotations: ToolAnnotations? = null,
@Suppress("LocalVariableName") _meta: JsonObject? = null,
@Suppress("LocalVariableName", "FunctionParameterNaming") _meta: JsonObject? = null,
handler: suspend (CallToolRequest) -> CallToolResult,
) {
val tool = Tool(
Expand Down
1 change: 1 addition & 0 deletions kotlin-sdk-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ kotlin {
implementation(dependencies.platform(libs.ktor.bom))
implementation(project(":kotlin-sdk"))
implementation(kotlin("test"))
implementation(libs.kotest.assertions.core)
implementation(libs.kotest.assertions.json)
implementation(libs.kotlin.logging)
implementation(libs.kotlinx.coroutines.test)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin

import io.kotest.matchers.throwable.shouldHaveMessage
import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest
import io.modelcontextprotocol.kotlin.sdk.GetPromptResult
import io.modelcontextprotocol.kotlin.sdk.PromptArgument
Expand Down Expand Up @@ -132,7 +133,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
)
}

// complext prompt
// complex prompt
server.addPrompt(
name = complexPromptName,
description = complexPromptDescription,
Expand All @@ -152,8 +153,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
// validate required arguments
val requiredArgs = listOf("arg1", "arg2", "arg3")
for (argName in requiredArgs) {
if (request.arguments?.get(argName) == null) {
throw IllegalArgumentException("Missing required argument: $argName")
require(request.arguments?.get(argName) != null) {
"Missing required argument: $argName"
}
}

Expand Down Expand Up @@ -665,9 +666,7 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
}
}

val msg = exception.message ?: ""
val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})"

assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt")
exception shouldHaveMessage
"JSONRPCError(code=InternalError, message=Prompt not found: $nonExistentPromptName, data={})"
}
}
Loading
Loading