diff --git a/Sources/MCP/Base/Messages.swift b/Sources/MCP/Base/Messages.swift index 8214a702..fc8e9857 100644 --- a/Sources/MCP/Base/Messages.swift +++ b/Sources/MCP/Base/Messages.swift @@ -2,7 +2,13 @@ import Foundation private let jsonrpc = "2.0" -public struct Empty: Hashable, Codable, Sendable {} +public protocol NotRequired { + init() +} + +public struct Empty: NotRequired, Hashable, Codable, Sendable { + public init() {} +} // MARK: - @@ -78,14 +84,11 @@ public struct Request: Hashable, Identifiable, Codable, Sendable { try container.encode(jsonrpc, forKey: .jsonrpc) try container.encode(id, forKey: .id) try container.encode(method, forKey: .method) - if M.Parameters.self != Empty.self { - try container.encode(params, forKey: .params) - } else { - // Encode empty object for Empty parameters - try container.encode(Empty(), forKey: .params) - } + try container.encode(params, forKey: .params) } +} +extension Request { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) let version = try container.decode(String.self, forKey: .jsonrpc) @@ -95,15 +98,11 @@ public struct Request: Hashable, Identifiable, Codable, Sendable { } id = try container.decode(ID.self, forKey: .id) method = try container.decode(String.self, forKey: .method) - if M.Parameters.self == Empty.self { - if (try? container.decodeNil(forKey: .params)) != nil { - params = Empty() as! M.Parameters - } else if (try? container.decode(Empty.self, forKey: .params)) != nil { - params = Empty() as! M.Parameters - } else { - // If params field is missing, use Empty - params = Empty() as! M.Parameters - } + + if M.Parameters.self is NotRequired.Type { + params = + (try container.decodeIfPresent(M.Parameters.self, forKey: .params) + ?? (M.Parameters.self as! NotRequired.Type).init() as! M.Parameters) } else { params = try container.decode(M.Parameters.self, forKey: .params) } diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index 8f9fa245..7af0fb51 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -305,7 +305,12 @@ public actor Client { -> (prompts: [Prompt], nextCursor: String?) { _ = try checkCapability(\.prompts, "Prompts") - let request = ListPrompts.request(.init(cursor: cursor)) + let request: Request + if let cursor = cursor { + request = ListPrompts.request(.init(cursor: cursor)) + } else { + request = ListPrompts.request(.init()) + } let result = try await send(request) return (prompts: result.prompts, nextCursor: result.nextCursor) } @@ -323,7 +328,12 @@ public actor Client { resources: [Resource], nextCursor: String? ) { _ = try checkCapability(\.resources, "Resources") - let request = ListResources.request(.init(cursor: cursor)) + let request: Request + if let cursor = cursor { + request = ListResources.request(.init(cursor: cursor)) + } else { + request = ListResources.request(.init()) + } let result = try await send(request) return (resources: result.resources, nextCursor: result.nextCursor) } @@ -338,7 +348,12 @@ public actor Client { public func listTools(cursor: String? = nil) async throws -> [Tool] { _ = try checkCapability(\.tools, "Tools") - let request = ListTools.request(.init(cursor: cursor)) + let request: Request + if let cursor = cursor { + request = ListTools.request(.init(cursor: cursor)) + } else { + request = ListTools.request(.init()) + } let result = try await send(request) return result.tools } diff --git a/Sources/MCP/Server/Prompts.swift b/Sources/MCP/Server/Prompts.swift index 859d816d..2ebd3ae0 100644 --- a/Sources/MCP/Server/Prompts.swift +++ b/Sources/MCP/Server/Prompts.swift @@ -153,10 +153,14 @@ public struct Prompt: Hashable, Codable, Sendable { public enum ListPrompts: Method { public static let name: String = "prompts/list" - public struct Parameters: Hashable, Codable, Sendable { + public struct Parameters: NotRequired, Hashable, Codable, Sendable { public let cursor: String? + + public init() { + self.cursor = nil + } - public init(cursor: String? = nil) { + public init(cursor: String) { self.cursor = cursor } } diff --git a/Sources/MCP/Server/Resources.swift b/Sources/MCP/Server/Resources.swift index 72990823..311ff095 100644 --- a/Sources/MCP/Server/Resources.swift +++ b/Sources/MCP/Server/Resources.swift @@ -99,10 +99,14 @@ public struct Resource: Hashable, Codable, Sendable { public enum ListResources: Method { public static let name: String = "resources/list" - public struct Parameters: Hashable, Codable, Sendable { + public struct Parameters: NotRequired, Hashable, Codable, Sendable { public let cursor: String? - public init(cursor: String? = nil) { + public init() { + self.cursor = nil + } + + public init(cursor: String) { self.cursor = cursor } } diff --git a/Sources/MCP/Server/Tools.swift b/Sources/MCP/Server/Tools.swift index aaa07492..74755574 100644 --- a/Sources/MCP/Server/Tools.swift +++ b/Sources/MCP/Server/Tools.swift @@ -118,10 +118,14 @@ public struct Tool: Hashable, Codable, Sendable { public enum ListTools: Method { public static let name = "tools/list" - public struct Parameters: Hashable, Codable, Sendable { + public struct Parameters: NotRequired, Hashable, Codable, Sendable { public let cursor: String? + + public init() { + self.cursor = nil + } - public init(cursor: String? = nil) { + public init(cursor: String) { self.cursor = cursor } } diff --git a/Tests/MCPTests/PromptTests.swift b/Tests/MCPTests/PromptTests.swift index 30651c3c..083d4636 100644 --- a/Tests/MCPTests/PromptTests.swift +++ b/Tests/MCPTests/PromptTests.swift @@ -142,6 +142,36 @@ struct PromptTests { let emptyParams = ListPrompts.Parameters() #expect(emptyParams.cursor == nil) } + + @Test("ListPrompts request decoding with omitted params") + func testListPromptsRequestDecodingWithOmittedParams() throws { + // Test decoding when params field is omitted + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"prompts/list"} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == ListPrompts.name) + } + + @Test("ListPrompts request decoding with null params") + func testListPromptsRequestDecodingWithNullParams() throws { + // Test decoding when params field is null + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"prompts/list","params":null} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == ListPrompts.name) + } @Test("ListPrompts result validation") func testListPromptsResult() throws { diff --git a/Tests/MCPTests/RequestTests.swift b/Tests/MCPTests/RequestTests.swift index add85ac9..7d4e8dde 100644 --- a/Tests/MCPTests/RequestTests.swift +++ b/Tests/MCPTests/RequestTests.swift @@ -76,4 +76,49 @@ struct RequestTests { #expect(decoded.id == "test-id") #expect(decoded.method == EmptyMethod.name) } + + @Test("NotRequired parameters request decoding - with params") + func testNotRequiredParametersRequestDecodingWithParams() throws { + // Test decoding when params field is present + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"ping","params":{}} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == Ping.name) + } + + @Test("NotRequired parameters request decoding - without params") + func testNotRequiredParametersRequestDecodingWithoutParams() throws { + // Test decoding when params field is missing + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"ping"} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == Ping.name) + } + + @Test("NotRequired parameters request decoding - with null params") + func testNotRequiredParametersRequestDecodingWithNullParams() throws { + // Test decoding when params field is null + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"ping","params":null} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == Ping.name) + } } diff --git a/Tests/MCPTests/ResourceTests.swift b/Tests/MCPTests/ResourceTests.swift index eb1096bf..54036327 100644 --- a/Tests/MCPTests/ResourceTests.swift +++ b/Tests/MCPTests/ResourceTests.swift @@ -86,6 +86,36 @@ struct ResourceTests { let emptyParams = ListResources.Parameters() #expect(emptyParams.cursor == nil) } + + @Test("ListResources request decoding with omitted params") + func testListResourcesRequestDecodingWithOmittedParams() throws { + // Test decoding when params field is omitted + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"resources/list"} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == ListResources.name) + } + + @Test("ListResources request decoding with null params") + func testListResourcesRequestDecodingWithNullParams() throws { + // Test decoding when params field is null + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"resources/list","params":null} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == ListResources.name) + } @Test("ListResources result validation") func testListResourcesResult() throws { diff --git a/Tests/MCPTests/RoundtripTests.swift b/Tests/MCPTests/RoundtripTests.swift index 282794e2..49ef0129 100644 --- a/Tests/MCPTests/RoundtripTests.swift +++ b/Tests/MCPTests/RoundtripTests.swift @@ -118,6 +118,26 @@ struct RoundtripTests { group.cancelAll() } + // Test ping + let pingTask = Task { + try await client.ping() + // Ping doesn't return anything, so just getting here without throwing is success + #expect(true) // Test passed if we reach this point + } + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await Task.sleep(for: .seconds(1)) + pingTask.cancel() + throw CancellationError() + } + group.addTask { + try await pingTask.value + } + try await group.next() + group.cancelAll() + } + let listToolsTask = Task { let result = try await client.listTools() #expect(result.count == 1) diff --git a/Tests/MCPTests/ToolTests.swift b/Tests/MCPTests/ToolTests.swift index a29ae670..a9a03fc2 100644 --- a/Tests/MCPTests/ToolTests.swift +++ b/Tests/MCPTests/ToolTests.swift @@ -111,6 +111,36 @@ struct ToolTests { let emptyParams = ListTools.Parameters() #expect(emptyParams.cursor == nil) } + + @Test("ListTools request decoding with omitted params") + func testListToolsRequestDecodingWithOmittedParams() throws { + // Test decoding when params field is omitted + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"tools/list"} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == ListTools.name) + } + + @Test("ListTools request decoding with null params") + func testListToolsRequestDecodingWithNullParams() throws { + // Test decoding when params field is null + let jsonString = """ + {"jsonrpc":"2.0","id":"test-id","method":"tools/list","params":null} + """ + let data = jsonString.data(using: .utf8)! + + let decoder = JSONDecoder() + let decoded = try decoder.decode(Request.self, from: data) + + #expect(decoded.id == "test-id") + #expect(decoded.method == ListTools.name) + } @Test("ListTools result validation") func testListToolsResult() throws {