Skip to content

Commit 4f0decc

Browse files
committed
Improve serializers, add utils method and maps
1 parent 25238ec commit 4f0decc

File tree

1 file changed

+92
-79
lines changed
  • kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types

1 file changed

+92
-79
lines changed

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt

Lines changed: 92 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,16 @@ private val logger = KotlinLogging.logger {}
2929
private fun JsonElement.getMethodOrNull(): String? = jsonObject["method"]?.jsonPrimitive?.content
3030

3131
/**
32-
* Extracts the method field from a JSON element.
33-
* Throws [SerializationException] if the method field is not present.
32+
* Safely extracts the type field from a JSON element.
33+
* Returns null if the type field is not present.
3434
*/
35-
private fun JsonElement.getMethod(): String =
36-
getMethodOrNull() ?: throw SerializationException("Missing required 'method' field in notification")
35+
private fun JsonElement.getTypeOrNull(): String? = jsonObject["type"]?.jsonPrimitive?.content
36+
37+
/**
38+
* Extracts the type field from a JSON element.
39+
* Throws [SerializationException] if the type field is not present.
40+
*/
41+
private fun JsonElement.getType(): String = requireNotNull(getTypeOrNull()) { "Missing required 'type' field" }
3742

3843
// ============================================================================
3944
// Method Serializer
@@ -72,10 +77,10 @@ internal object MethodSerializer : KSerializer<Method> {
7277
*/
7378
internal object ReferencePolymorphicSerializer : JsonContentPolymorphicSerializer<Reference>(Reference::class) {
7479
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<Reference> =
75-
when (element.jsonObject.getValue("type").jsonPrimitive.content) {
80+
when (element.getType()) {
7681
ReferenceType.Prompt.value -> PromptReference.serializer()
7782
ReferenceType.ResourceTemplate.value -> ResourceTemplateReference.serializer()
78-
else -> error("Unknown reference type")
83+
else -> throw SerializationException("Unknown reference type: ${element.getTypeOrNull()}")
7984
}
8085
}
8186

@@ -90,13 +95,13 @@ internal object ReferencePolymorphicSerializer : JsonContentPolymorphicSerialize
9095
internal object ContentBlockPolymorphicSerializer :
9196
JsonContentPolymorphicSerializer<ContentBlock>(ContentBlock::class) {
9297
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<ContentBlock> =
93-
when (element.jsonObject.getValue("type").jsonPrimitive.content) {
98+
when (element.getType()) {
9499
ContentTypes.TEXT.value -> TextContent.serializer()
95100
ContentTypes.IMAGE.value -> ImageContent.serializer()
96101
ContentTypes.AUDIO.value -> AudioContent.serializer()
97102
ContentTypes.RESOURCE_LINK.value -> ResourceLink.serializer()
98103
ContentTypes.EMBEDDED_RESOURCE.value -> EmbeddedResource.serializer()
99-
else -> error("Unknown content block type")
104+
else -> throw SerializationException("Unknown content block type: ${element.getTypeOrNull()}")
100105
}
101106
}
102107

@@ -107,11 +112,11 @@ internal object ContentBlockPolymorphicSerializer :
107112
internal object MediaContentPolymorphicSerializer :
108113
JsonContentPolymorphicSerializer<MediaContent>(MediaContent::class) {
109114
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<MediaContent> =
110-
when (element.jsonObject.getValue("type").jsonPrimitive.content) {
115+
when (element.getType()) {
111116
ContentTypes.TEXT.value -> TextContent.serializer()
112117
ContentTypes.IMAGE.value -> ImageContent.serializer()
113118
ContentTypes.AUDIO.value -> AudioContent.serializer()
114-
else -> error("Unknown media content type")
119+
else -> throw SerializationException("Unknown media content type: ${element.getTypeOrNull()}")
115120
}
116121
}
117122

@@ -128,8 +133,8 @@ internal object ResourceContentsPolymorphicSerializer :
128133
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<ResourceContents> {
129134
val jsonObject = element.jsonObject
130135
return when {
131-
jsonObject.contains("text") -> TextResourceContents.serializer()
132-
jsonObject.contains("blob") -> BlobResourceContents.serializer()
136+
"text" in jsonObject -> TextResourceContents.serializer()
137+
"blob" in jsonObject -> BlobResourceContents.serializer()
133138
else -> UnknownResourceContents.serializer()
134139
}
135140
}
@@ -139,38 +144,46 @@ internal object ResourceContentsPolymorphicSerializer :
139144
// Request Serializers
140145
// ============================================================================
141146

147+
private val clientRequestDeserializers: Map<String, DeserializationStrategy<ClientRequest>> by lazy {
148+
mapOf(
149+
Method.Defined.CompletionComplete.value to CompleteRequest.serializer(),
150+
Method.Defined.Initialize.value to InitializeRequest.serializer(),
151+
Method.Defined.Ping.value to PingRequest.serializer(),
152+
Method.Defined.LoggingSetLevel.value to SetLevelRequest.serializer(),
153+
Method.Defined.PromptsGet.value to GetPromptRequest.serializer(),
154+
Method.Defined.PromptsList.value to ListPromptsRequest.serializer(),
155+
Method.Defined.ResourcesList.value to ListResourcesRequest.serializer(),
156+
Method.Defined.ResourcesRead.value to ReadResourceRequest.serializer(),
157+
Method.Defined.ResourcesSubscribe.value to SubscribeRequest.serializer(),
158+
Method.Defined.ResourcesUnsubscribe.value to UnsubscribeRequest.serializer(),
159+
Method.Defined.ResourcesTemplatesList.value to ListResourceTemplatesRequest.serializer(),
160+
Method.Defined.ToolsCall.value to CallToolRequest.serializer(),
161+
Method.Defined.ToolsList.value to ListToolsRequest.serializer(),
162+
)
163+
}
164+
142165
/**
143166
* Selects the appropriate deserializer for client requests based on the method name.
144167
* Returns null if the method is not a known client request method.
145168
*/
146-
internal fun selectClientRequestDeserializer(method: String): DeserializationStrategy<ClientRequest>? = when (method) {
147-
Method.Defined.CompletionComplete.value -> CompleteRequest.serializer()
148-
Method.Defined.Initialize.value -> InitializeRequest.serializer()
149-
Method.Defined.Ping.value -> PingRequest.serializer()
150-
Method.Defined.LoggingSetLevel.value -> SetLevelRequest.serializer()
151-
Method.Defined.PromptsGet.value -> GetPromptRequest.serializer()
152-
Method.Defined.PromptsList.value -> ListPromptsRequest.serializer()
153-
Method.Defined.ResourcesList.value -> ListResourcesRequest.serializer()
154-
Method.Defined.ResourcesRead.value -> ReadResourceRequest.serializer()
155-
Method.Defined.ResourcesSubscribe.value -> SubscribeRequest.serializer()
156-
Method.Defined.ResourcesUnsubscribe.value -> UnsubscribeRequest.serializer()
157-
Method.Defined.ResourcesTemplatesList.value -> ListResourceTemplatesRequest.serializer()
158-
Method.Defined.ToolsCall.value -> CallToolRequest.serializer()
159-
Method.Defined.ToolsList.value -> ListToolsRequest.serializer()
160-
else -> null
169+
internal fun selectClientRequestDeserializer(method: String): DeserializationStrategy<ClientRequest>? =
170+
clientRequestDeserializers[method]
171+
172+
private val serverRequestDeserializers: Map<String, DeserializationStrategy<ServerRequest>> by lazy {
173+
mapOf(
174+
Method.Defined.ElicitationCreate.value to ElicitRequest.serializer(),
175+
Method.Defined.Ping.value to PingRequest.serializer(),
176+
Method.Defined.RootsList.value to ListRootsRequest.serializer(),
177+
Method.Defined.SamplingCreateMessage.value to CreateMessageRequest.serializer(),
178+
)
161179
}
162180

163181
/**
164182
* Selects the appropriate deserializer for server requests based on the method name.
165183
* Returns null if the method is not a known server request method.
166184
*/
167-
internal fun selectServerRequestDeserializer(method: String): DeserializationStrategy<ServerRequest>? = when (method) {
168-
Method.Defined.ElicitationCreate.value -> ElicitRequest.serializer()
169-
Method.Defined.Ping.value -> PingRequest.serializer()
170-
Method.Defined.RootsList.value -> ListRootsRequest.serializer()
171-
Method.Defined.SamplingCreateMessage.value -> CreateMessageRequest.serializer()
172-
else -> null
173-
}
185+
internal fun selectServerRequestDeserializer(method: String): DeserializationStrategy<ServerRequest>? =
186+
serverRequestDeserializers[method]
174187

175188
/**
176189
* Polymorphic serializer for [Request] types.
@@ -179,8 +192,8 @@ internal fun selectServerRequestDeserializer(method: String): DeserializationStr
179192
internal object RequestPolymorphicSerializer : JsonContentPolymorphicSerializer<Request>(Request::class) {
180193
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<Request> {
181194
val method = element.getMethodOrNull() ?: run {
182-
logger.error { "No method in $element" }
183-
error("No method in $element")
195+
logger.error { "Missing 'method' for Request: $element" }
196+
throw SerializationException("Missing 'method' for Request: $element")
184197
}
185198

186199
return selectClientRequestDeserializer(method)
@@ -193,40 +206,40 @@ internal object RequestPolymorphicSerializer : JsonContentPolymorphicSerializer<
193206
// Notification Serializers
194207
// ============================================================================
195208

209+
private val clientNotificationDeserializers: Map<String, DeserializationStrategy<ClientNotification>> by lazy {
210+
mapOf(
211+
Method.Defined.NotificationsCancelled.value to CancelledNotification.serializer(),
212+
Method.Defined.NotificationsProgress.value to ProgressNotification.serializer(),
213+
Method.Defined.NotificationsInitialized.value to InitializedNotification.serializer(),
214+
Method.Defined.NotificationsRootsListChanged.value to RootsListChangedNotification.serializer(),
215+
)
216+
}
217+
196218
/**
197219
* Selects the appropriate deserializer for client notifications based on the method name.
198220
* Returns null if the method is not a known client notification method.
199221
*/
200-
private fun selectClientNotificationDeserializer(element: JsonElement): DeserializationStrategy<ClientNotification>? {
201-
val method = element.getMethodOrNull() ?: return null
202-
203-
return when (method) {
204-
Method.Defined.NotificationsCancelled.value -> CancelledNotification.serializer()
205-
Method.Defined.NotificationsProgress.value -> ProgressNotification.serializer()
206-
Method.Defined.NotificationsInitialized.value -> InitializedNotification.serializer()
207-
Method.Defined.NotificationsRootsListChanged.value -> RootsListChangedNotification.serializer()
208-
else -> null
209-
}
222+
private fun selectClientNotificationDeserializer(element: JsonElement): DeserializationStrategy<ClientNotification>? =
223+
element.getMethodOrNull()?.let(clientNotificationDeserializers::get)
224+
225+
private val serverNotificationDeserializers: Map<String, DeserializationStrategy<ServerNotification>> by lazy {
226+
mapOf(
227+
Method.Defined.NotificationsCancelled.value to CancelledNotification.serializer(),
228+
Method.Defined.NotificationsProgress.value to ProgressNotification.serializer(),
229+
Method.Defined.NotificationsMessage.value to LoggingMessageNotification.serializer(),
230+
Method.Defined.NotificationsResourcesUpdated.value to ResourceUpdatedNotification.serializer(),
231+
Method.Defined.NotificationsResourcesListChanged.value to ResourceListChangedNotification.serializer(),
232+
Method.Defined.NotificationsToolsListChanged.value to ToolListChangedNotification.serializer(),
233+
Method.Defined.NotificationsPromptsListChanged.value to PromptListChangedNotification.serializer(),
234+
)
210235
}
211236

212237
/**
213238
* Selects the appropriate deserializer for server notifications based on the method name.
214239
* Returns null if the method is not a known server notification method.
215240
*/
216-
internal fun selectServerNotificationDeserializer(element: JsonElement): DeserializationStrategy<ServerNotification>? {
217-
val method = element.getMethodOrNull() ?: return null
218-
219-
return when (method) {
220-
Method.Defined.NotificationsCancelled.value -> CancelledNotification.serializer()
221-
Method.Defined.NotificationsProgress.value -> ProgressNotification.serializer()
222-
Method.Defined.NotificationsMessage.value -> LoggingMessageNotification.serializer()
223-
Method.Defined.NotificationsResourcesUpdated.value -> ResourceUpdatedNotification.serializer()
224-
Method.Defined.NotificationsResourcesListChanged.value -> ResourceListChangedNotification.serializer()
225-
Method.Defined.NotificationsToolsListChanged.value -> ToolListChangedNotification.serializer()
226-
Method.Defined.NotificationsPromptsListChanged.value -> PromptListChangedNotification.serializer()
227-
else -> null
228-
}
229-
}
241+
internal fun selectServerNotificationDeserializer(element: JsonElement): DeserializationStrategy<ServerNotification>? =
242+
element.getMethodOrNull()?.let(serverNotificationDeserializers::get)
230243

231244
/**
232245
* Polymorphic serializer for [Notification] types.
@@ -273,7 +286,7 @@ internal object ServerNotificationPolymorphicSerializer :
273286
private fun selectEmptyResult(element: JsonElement): DeserializationStrategy<EmptyResult>? {
274287
val jsonObject = element.jsonObject
275288
return when {
276-
jsonObject.isEmpty() || (jsonObject.size == 1 && jsonObject.contains("_meta")) -> EmptyResult.serializer()
289+
jsonObject.isEmpty() || (jsonObject.size == 1 && "_meta" in jsonObject) -> EmptyResult.serializer()
277290
else -> null
278291
}
279292
}
@@ -285,9 +298,9 @@ private fun selectEmptyResult(element: JsonElement): DeserializationStrategy<Emp
285298
private fun selectClientResultDeserializer(element: JsonElement): DeserializationStrategy<ClientResult>? {
286299
val jsonObject = element.jsonObject
287300
return when {
288-
jsonObject.contains("model") && jsonObject.contains("role") -> CreateMessageResult.serializer()
289-
jsonObject.contains("roots") -> ListRootsResult.serializer()
290-
jsonObject.contains("action") -> ElicitResult.serializer()
301+
"model" in jsonObject && "role" in jsonObject -> CreateMessageResult.serializer()
302+
"roots" in jsonObject -> ListRootsResult.serializer()
303+
"action" in jsonObject -> ElicitResult.serializer()
291304
else -> null
292305
}
293306
}
@@ -299,15 +312,15 @@ private fun selectClientResultDeserializer(element: JsonElement): Deserializatio
299312
private fun selectServerResultDeserializer(element: JsonElement): DeserializationStrategy<ServerResult>? {
300313
val jsonObject = element.jsonObject
301314
return when {
302-
jsonObject.contains("protocolVersion") && jsonObject.contains("capabilities") -> InitializeResult.serializer()
303-
jsonObject.contains("completion") -> CompleteResult.serializer()
304-
jsonObject.contains("tools") -> ListToolsResult.serializer()
305-
jsonObject.contains("resources") -> ListResourcesResult.serializer()
306-
jsonObject.contains("resourceTemplates") -> ListResourceTemplatesResult.serializer()
307-
jsonObject.contains("prompts") -> ListPromptsResult.serializer()
308-
jsonObject.contains("messages") -> GetPromptResult.serializer()
309-
jsonObject.contains("contents") -> ReadResourceResult.serializer()
310-
jsonObject.contains("content") -> CallToolResult.serializer()
315+
"protocolVersion" in jsonObject && "capabilities" in jsonObject -> InitializeResult.serializer()
316+
"completion" in jsonObject -> CompleteResult.serializer()
317+
"tools" in jsonObject -> ListToolsResult.serializer()
318+
"resources" in jsonObject -> ListResourcesResult.serializer()
319+
"resourceTemplates" in jsonObject -> ListResourceTemplatesResult.serializer()
320+
"prompts" in jsonObject -> ListPromptsResult.serializer()
321+
"messages" in jsonObject -> GetPromptResult.serializer()
322+
"contents" in jsonObject -> ReadResourceResult.serializer()
323+
"content" in jsonObject -> CallToolResult.serializer()
311324
else -> null
312325
}
313326
}
@@ -367,11 +380,11 @@ internal object JSONRPCMessagePolymorphicSerializer :
367380
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<JSONRPCMessage> {
368381
val jsonObject = element.jsonObject
369382
return when {
370-
jsonObject.contains("error") -> JSONRPCError.serializer()
371-
jsonObject.contains("result") -> JSONRPCResponse.serializer()
372-
jsonObject.contains("method") && jsonObject.contains("id") -> JSONRPCRequest.serializer()
373-
jsonObject.contains("method") -> JSONRPCNotification.serializer()
374-
else -> error("Invalid JSONRPCMessage type")
383+
"error" in jsonObject -> JSONRPCError.serializer()
384+
"result" in jsonObject -> JSONRPCResponse.serializer()
385+
"method" in jsonObject && "id" in jsonObject -> JSONRPCRequest.serializer()
386+
"method" in jsonObject -> JSONRPCNotification.serializer()
387+
else -> throw SerializationException("Invalid JSONRPCMessage type: ${jsonObject.keys}")
375388
}
376389
}
377390
}
@@ -384,6 +397,6 @@ internal object RequestIdPolymorphicSerializer : JsonContentPolymorphicSerialize
384397
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<RequestId> = when (element) {
385398
is JsonPrimitive if (element.isString) -> RequestId.StringId.serializer()
386399
is JsonPrimitive if (element.longOrNull != null) -> RequestId.NumberId.serializer()
387-
else -> error("Invalid RequestId type")
400+
else -> throw SerializationException("Invalid RequestId type: $element")
388401
}
389402
}

0 commit comments

Comments
 (0)