Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions Sources/MCP/Base/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@ import Foundation

private let jsonrpc = "2.0"

public struct Empty: Hashable, Codable, Sendable {}
public protocol NotRequired {
init()
}

public struct Empty: NotRequired, Hashable, Codable, Sendable {
public init() {}
}

// MARK: -

Expand Down Expand Up @@ -78,14 +84,11 @@ public struct Request<M: Method>: Hashable, Identifiable, Codable, Sendable {
try container.encode(jsonrpc, forKey: .jsonrpc)
try container.encode(id, forKey: .id)
try container.encode(method, forKey: .method)
if M.Parameters.self != Empty.self {
try container.encode(params, forKey: .params)
} else {
// Encode empty object for Empty parameters
try container.encode(Empty(), forKey: .params)
}
try container.encode(params, forKey: .params)
}
}

extension Request {
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let version = try container.decode(String.self, forKey: .jsonrpc)
Expand All @@ -95,15 +98,11 @@ public struct Request<M: Method>: Hashable, Identifiable, Codable, Sendable {
}
id = try container.decode(ID.self, forKey: .id)
method = try container.decode(String.self, forKey: .method)
if M.Parameters.self == Empty.self {
if (try? container.decodeNil(forKey: .params)) != nil {
params = Empty() as! M.Parameters
} else if (try? container.decode(Empty.self, forKey: .params)) != nil {
params = Empty() as! M.Parameters
} else {
// If params field is missing, use Empty
params = Empty() as! M.Parameters
}

if M.Parameters.self is NotRequired.Type {
params =
(try container.decodeIfPresent(M.Parameters.self, forKey: .params)
?? (M.Parameters.self as! NotRequired.Type).init() as! M.Parameters)
} else {
params = try container.decode(M.Parameters.self, forKey: .params)
}
Expand Down
21 changes: 18 additions & 3 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,12 @@ public actor Client {
-> (prompts: [Prompt], nextCursor: String?)
{
_ = try checkCapability(\.prompts, "Prompts")
let request = ListPrompts.request(.init(cursor: cursor))
let request: Request<ListPrompts>
if let cursor = cursor {
request = ListPrompts.request(.init(cursor: cursor))
} else {
request = ListPrompts.request(.init())
}
let result = try await send(request)
return (prompts: result.prompts, nextCursor: result.nextCursor)
}
Expand All @@ -323,7 +328,12 @@ public actor Client {
resources: [Resource], nextCursor: String?
) {
_ = try checkCapability(\.resources, "Resources")
let request = ListResources.request(.init(cursor: cursor))
let request: Request<ListResources>
if let cursor = cursor {
request = ListResources.request(.init(cursor: cursor))
} else {
request = ListResources.request(.init())
}
let result = try await send(request)
return (resources: result.resources, nextCursor: result.nextCursor)
}
Expand All @@ -338,7 +348,12 @@ public actor Client {

public func listTools(cursor: String? = nil) async throws -> [Tool] {
_ = try checkCapability(\.tools, "Tools")
let request = ListTools.request(.init(cursor: cursor))
let request: Request<ListTools>
if let cursor = cursor {
request = ListTools.request(.init(cursor: cursor))
} else {
request = ListTools.request(.init())
}
let result = try await send(request)
return result.tools
}
Expand Down
8 changes: 6 additions & 2 deletions Sources/MCP/Server/Prompts.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,14 @@ public struct Prompt: Hashable, Codable, Sendable {
public enum ListPrompts: Method {
public static let name: String = "prompts/list"

public struct Parameters: Hashable, Codable, Sendable {
public struct Parameters: NotRequired, Hashable, Codable, Sendable {
public let cursor: String?

public init() {
self.cursor = nil
}

public init(cursor: String? = nil) {
public init(cursor: String) {
self.cursor = cursor
}
}
Expand Down
8 changes: 6 additions & 2 deletions Sources/MCP/Server/Resources.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,14 @@ public struct Resource: Hashable, Codable, Sendable {
public enum ListResources: Method {
public static let name: String = "resources/list"

public struct Parameters: Hashable, Codable, Sendable {
public struct Parameters: NotRequired, Hashable, Codable, Sendable {
public let cursor: String?

public init(cursor: String? = nil) {
public init() {
self.cursor = nil
}

public init(cursor: String) {
self.cursor = cursor
}
}
Expand Down
8 changes: 6 additions & 2 deletions Sources/MCP/Server/Tools.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,14 @@ public struct Tool: Hashable, Codable, Sendable {
public enum ListTools: Method {
public static let name = "tools/list"

public struct Parameters: Hashable, Codable, Sendable {
public struct Parameters: NotRequired, Hashable, Codable, Sendable {
public let cursor: String?

public init() {
self.cursor = nil
}

public init(cursor: String? = nil) {
public init(cursor: String) {
self.cursor = cursor
}
}
Expand Down
30 changes: 30 additions & 0 deletions Tests/MCPTests/PromptTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,36 @@ struct PromptTests {
let emptyParams = ListPrompts.Parameters()
#expect(emptyParams.cursor == nil)
}

@Test("ListPrompts request decoding with omitted params")
func testListPromptsRequestDecodingWithOmittedParams() throws {
// Test decoding when params field is omitted
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"prompts/list"}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<ListPrompts>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == ListPrompts.name)
}

@Test("ListPrompts request decoding with null params")
func testListPromptsRequestDecodingWithNullParams() throws {
// Test decoding when params field is null
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"prompts/list","params":null}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<ListPrompts>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == ListPrompts.name)
}

@Test("ListPrompts result validation")
func testListPromptsResult() throws {
Expand Down
45 changes: 45 additions & 0 deletions Tests/MCPTests/RequestTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,49 @@ struct RequestTests {
#expect(decoded.id == "test-id")
#expect(decoded.method == EmptyMethod.name)
}

@Test("NotRequired parameters request decoding - with params")
func testNotRequiredParametersRequestDecodingWithParams() throws {
// Test decoding when params field is present
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"ping","params":{}}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<Ping>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == Ping.name)
}

@Test("NotRequired parameters request decoding - without params")
func testNotRequiredParametersRequestDecodingWithoutParams() throws {
// Test decoding when params field is missing
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"ping"}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<Ping>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == Ping.name)
}

@Test("NotRequired parameters request decoding - with null params")
func testNotRequiredParametersRequestDecodingWithNullParams() throws {
// Test decoding when params field is null
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"ping","params":null}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<Ping>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == Ping.name)
}
}
30 changes: 30 additions & 0 deletions Tests/MCPTests/ResourceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,36 @@ struct ResourceTests {
let emptyParams = ListResources.Parameters()
#expect(emptyParams.cursor == nil)
}

@Test("ListResources request decoding with omitted params")
func testListResourcesRequestDecodingWithOmittedParams() throws {
// Test decoding when params field is omitted
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"resources/list"}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<ListResources>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == ListResources.name)
}

@Test("ListResources request decoding with null params")
func testListResourcesRequestDecodingWithNullParams() throws {
// Test decoding when params field is null
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"resources/list","params":null}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<ListResources>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == ListResources.name)
}

@Test("ListResources result validation")
func testListResourcesResult() throws {
Expand Down
20 changes: 20 additions & 0 deletions Tests/MCPTests/RoundtripTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ struct RoundtripTests {
group.cancelAll()
}

// Test ping
let pingTask = Task {
try await client.ping()
// Ping doesn't return anything, so just getting here without throwing is success
#expect(true) // Test passed if we reach this point
}

try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await Task.sleep(for: .seconds(1))
pingTask.cancel()
throw CancellationError()
}
group.addTask {
try await pingTask.value
}
try await group.next()
group.cancelAll()
}

let listToolsTask = Task {
let result = try await client.listTools()
#expect(result.count == 1)
Expand Down
30 changes: 30 additions & 0 deletions Tests/MCPTests/ToolTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,36 @@ struct ToolTests {
let emptyParams = ListTools.Parameters()
#expect(emptyParams.cursor == nil)
}

@Test("ListTools request decoding with omitted params")
func testListToolsRequestDecodingWithOmittedParams() throws {
// Test decoding when params field is omitted
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"tools/list"}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<ListTools>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == ListTools.name)
}

@Test("ListTools request decoding with null params")
func testListToolsRequestDecodingWithNullParams() throws {
// Test decoding when params field is null
let jsonString = """
{"jsonrpc":"2.0","id":"test-id","method":"tools/list","params":null}
"""
let data = jsonString.data(using: .utf8)!

let decoder = JSONDecoder()
let decoded = try decoder.decode(Request<ListTools>.self, from: data)

#expect(decoded.id == "test-id")
#expect(decoded.method == ListTools.name)
}

@Test("ListTools result validation")
func testListToolsResult() throws {
Expand Down