Skip to content

Commit c62b537

Browse files
committed
Add notifications test
Signed-off-by: Sergey Karpov <[email protected]>
1 parent acc50d3 commit c62b537

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ class TypeScriptClientKotlinServerTest : TypeScriptTestBase() {
5757
)
5858
}
5959

60+
@Test
61+
@Timeout(30, unit = TimeUnit.SECONDS)
62+
fun testNotifications() {
63+
val name = "NotifUser"
64+
val command = "npx tsx myClient.ts $serverUrl multi-greet $name"
65+
val output = executeCommand(command, tsClientDir)
66+
67+
assertTrue(
68+
output.contains("Multiple greetings") || output.contains("greeting"),
69+
"Tool response should contain greeting message",
70+
)
71+
// verify that the server sent 3 notifications
72+
assertTrue(
73+
output.contains("\"notificationCount\": 3") || output.contains("notificationCount: 3"),
74+
"Structured content should indicate that 3 notifications were emitted by the server.\nOutput:\n$output",
75+
)
76+
}
77+
6078
@Test
6179
@Timeout(30, unit = TimeUnit.SECONDS)
6280
fun testToolCallWithSessionManagement() {

kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import io.ktor.server.request.receiveText
1212
import io.ktor.server.response.header
1313
import io.ktor.server.response.respond
1414
import io.ktor.server.response.respondText
15+
import io.ktor.server.response.respondTextWriter
1516
import io.ktor.server.routing.delete
17+
import io.ktor.server.routing.get
1618
import io.ktor.server.routing.post
1719
import io.ktor.server.routing.routing
1820
import io.modelcontextprotocol.kotlin.sdk.CallToolResult
@@ -66,6 +68,20 @@ class KotlinServerForTypeScriptClient {
6668

6769
server = embeddedServer(CIO, port = port) {
6870
routing {
71+
get("/mcp") {
72+
val sessionId = call.request.header("mcp-session-id")
73+
if (sessionId == null) {
74+
call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header")
75+
return@get
76+
}
77+
val transport = serverTransports[sessionId]
78+
if (transport == null) {
79+
call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id")
80+
return@get
81+
}
82+
transport.stream(call)
83+
}
84+
6985
post("/mcp") {
7086
val sessionId = call.request.header("mcp-session-id")
7187
val requestBody = call.receiveText()
@@ -235,6 +251,32 @@ class KotlinServerForTypeScriptClient {
235251
) { request ->
236252
val name = (request.arguments["name"] as? JsonPrimitive)?.content ?: "World"
237253

254+
server.sendToolListChanged()
255+
server.sendLoggingMessage(
256+
io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification(
257+
io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params(
258+
level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info,
259+
data = JsonPrimitive("Preparing greeting for $name")
260+
)
261+
)
262+
)
263+
server.sendLoggingMessage(
264+
io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification(
265+
io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params(
266+
level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info,
267+
data = JsonPrimitive("Halfway there for $name")
268+
)
269+
)
270+
)
271+
server.sendLoggingMessage(
272+
io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification(
273+
io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification.Params(
274+
level = io.modelcontextprotocol.kotlin.sdk.LoggingLevel.info,
275+
data = JsonPrimitive("Done sending greetings to $name")
276+
)
277+
)
278+
)
279+
238280
CallToolResult(
239281
content = listOf(TextContent("Multiple greetings sent to $name!")),
240282
structuredContent = buildJsonObject {
@@ -297,6 +339,30 @@ class HttpServerTransport(private val sessionId: String) : AbstractTransport() {
297339
private val pendingResponses = ConcurrentHashMap<String, CompletableDeferred<JSONRPCMessage>>()
298340
private val messageQueue = Channel<JSONRPCMessage>(Channel.UNLIMITED)
299341

342+
suspend fun stream(call: ApplicationCall) {
343+
logger.debug { "Starting SSE stream for session: $sessionId" }
344+
call.response.header("Cache-Control", "no-cache")
345+
call.response.header("Connection", "keep-alive")
346+
call.respondTextWriter(ContentType.Text.EventStream) {
347+
try {
348+
while (true) {
349+
val result = messageQueue.receiveCatching()
350+
val msg = result.getOrNull() ?: break
351+
val json = McpJson.encodeToString(msg)
352+
write("event: message\n")
353+
write("data: ")
354+
write(json)
355+
write("\n\n")
356+
flush()
357+
}
358+
} catch (e: Exception) {
359+
logger.warn(e) { "SSE stream terminated for session: $sessionId" }
360+
} finally {
361+
logger.debug { "SSE stream closed for session: $sessionId" }
362+
}
363+
}
364+
}
365+
300366
suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) {
301367
try {
302368
logger.info { "Handling request body: $requestBody" }

kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ async function main() {
4444
await client.connect(transport, {protocolVersion: PROTOCOL_VERSION});
4545
console.log('Connected to server');
4646

47+
try {
48+
if (typeof (client as any).on === 'function') {
49+
(client as any).on('notification', (n: any) => {
50+
try {
51+
const method = (n && (n.method || (n.params && n.params.method))) || 'unknown';
52+
console.log('Notification:', method, JSON.stringify(n));
53+
} catch {
54+
console.log('Notification: <unparsable>');
55+
}
56+
});
57+
}
58+
} catch {
59+
}
60+
4761
const toolsResult = await client.listTools();
4862
const tools = toolsResult.tools;
4963
console.log('Available utils:', tools.map((t: { name: any; }) => t.name).join(', '));
@@ -105,4 +119,4 @@ main().catch(error => {
105119
console.error('Unhandled error:', error);
106120
// @ts-ignore
107121
process.exit(1);
108-
});
122+
});

0 commit comments

Comments
 (0)