diff --git a/Sources/MCP/Base/Transports.swift b/Sources/MCP/Base/Transports.swift index edfa042d..b00827a0 100644 --- a/Sources/MCP/Base/Transports.swift +++ b/Sources/MCP/Base/Transports.swift @@ -396,13 +396,7 @@ public actor StdioTransport: Transport { if !receiveContinuationResumed { receiveContinuationResumed = true if let error = error { - if let nwError = error as? NWError { - continuation.resume(throwing: MCP.Error.transportError(nwError)) - } else { - continuation.resume( - throwing: MCP.Error.internalError("Receive error: \(error)") - ) - } + continuation.resume(throwing: MCP.Error.transportError(error)) } else if let content = content { continuation.resume(returning: content) } else { diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index dc45f4ee..453a8fdc 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -87,12 +87,48 @@ public actor Client { /// The task for the message handling loop private var task: Task? + /// An error indicating a type mismatch when decoding a pending request + private struct TypeMismatchError: Swift.Error {} + /// A pending request with a continuation for the result private struct PendingRequest { let continuation: CheckedContinuation } + + /// A type-erased pending request + private struct AnyPendingRequest { + private let _resume: (Result) -> Void + + init(_ request: PendingRequest) { + _resume = { result in + switch result { + case .success(let value): + if let typedValue = value as? T { + request.continuation.resume(returning: typedValue) + } else if let value = value as? Value, + let data = try? JSONEncoder().encode(value), + let decoded = try? JSONDecoder().decode(T.self, from: data) + { + request.continuation.resume(returning: decoded) + } else { + request.continuation.resume(throwing: TypeMismatchError()) + } + case .failure(let error): + request.continuation.resume(throwing: error) + } + } + } + func resume(returning value: Any) { + _resume(.success(value)) + } + + func resume(throwing error: Swift.Error) { + _resume(.failure(error)) + } + } + /// A dictionary of type-erased pending requests, keyed by request ID - private var pendingRequests: [ID: Any] = [:] + private var pendingRequests: [ID: AnyPendingRequest] = [:] public init( name: String, @@ -129,8 +165,10 @@ public actor Client { // Attempt to decode string data as AnyResponse or AnyMessage let decoder = JSONDecoder() - if let response = try? decoder.decode(AnyResponse.self, from: data) { - await handleResponse(response, for: response) + if let response = try? decoder.decode(AnyResponse.self, from: data), + let request = pendingRequests[response.id] + { + await handleResponse(response, for: request) } else if let message = try? decoder.decode(AnyMessage.self, from: data) { await handleMessage(message) } else { @@ -158,11 +196,7 @@ public actor Client { public func disconnect() async { // Cancel all pending requests for (id, request) in pendingRequests { - // We know this cast is safe because we only store PendingRequest values - if let typedRequest = request as? PendingRequest { - typedRequest.continuation.resume( - throwing: Error.internalError("Client disconnected")) - } + request.resume(throwing: Error.internalError("Client disconnected")) pendingRequests.removeValue(forKey: id) } @@ -220,12 +254,12 @@ public actor Client { } } - private func addPendingRequest( + private func addPendingRequest( id: ID, continuation: CheckedContinuation, type: T.Type ) { - pendingRequests[id] = PendingRequest(continuation: continuation) + pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation)) } private func removePendingRequest(id: ID) { @@ -320,19 +354,18 @@ public actor Client { // MARK: - - private func handleResponse(_ response: Response, for request: Any) async { + private func handleResponse(_ response: Response, for request: AnyPendingRequest) + async + { await logger?.debug( "Processing response", metadata: ["id": "\(response.id)"]) - // We know this cast is safe because we only store PendingRequest values - guard let typedRequest = request as? PendingRequest else { return } - switch response.result { case .success(let value): - typedRequest.continuation.resume(returning: value) + request.resume(returning: value) case .failure(let error): - typedRequest.continuation.resume(throwing: error) + request.resume(throwing: error) } removePendingRequest(id: response.id) diff --git a/Tests/MCPTests/RoundtripTests.swift b/Tests/MCPTests/RoundtripTests.swift index 82a97b05..ef314de4 100644 --- a/Tests/MCPTests/RoundtripTests.swift +++ b/Tests/MCPTests/RoundtripTests.swift @@ -8,10 +8,9 @@ import Testing @Suite("Roundtrip Tests") struct RoundtripTests { @Test( - "Initialize roundtrip", .timeLimit(.minutes(1)) ) - func testInitializeRoundtrip() async throws { + func testRoundtrip() async throws { let (clientToServerRead, clientToServerWrite) = try FileDescriptor.pipe() let (serverToClientRead, serverToClientWrite) = try FileDescriptor.pipe() @@ -36,32 +35,83 @@ struct RoundtripTests { version: "1.0.0", capabilities: .init(prompts: .init(), tools: .init()) ) + await server.withMethodHandler(ListTools.self) { _ in + return ListTools.Result(tools: [ + Tool( + name: "add", + description: "Adds two numbers together", + inputSchema: [ + "a": ["type": "integer", "description": "The first number"], + "a": ["type": "integer", "description": "The second number"], + ]) + ]) + } + await server.withMethodHandler(CallTool.self) { request in + guard request.name == "add" else { + return CallTool.Result(content: [.text("Invalid tool name")], isError: true) + } + + guard let a = request.arguments?["a"]?.intValue, + let b = request.arguments?["b"]?.intValue + else { + return CallTool.Result( + content: [.text("Did not receive valid arguments")], isError: true) + } + + return CallTool.Result(content: [.text("\(a + b)")], isError: false) + } + let client = Client(name: "TestClient", version: "1.0") try await server.start(transport: serverTransport) try await client.connect(transport: clientTransport) - // let initTask = Task { - // let result = try await client.initialize() + let initTask = Task { + let result = try await client.initialize() + + #expect(result.serverInfo.name == "TestServer") + #expect(result.serverInfo.version == "1.0.0") + #expect(result.capabilities.prompts != nil) + #expect(result.capabilities.tools != nil) + #expect(result.protocolVersion == Version.latest) + } + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await Task.sleep(for: .seconds(1)) + initTask.cancel() + throw CancellationError() + } + group.addTask { + try await initTask.value + } + try await group.next() + group.cancelAll() + } + + let listToolsTask = Task { + let result = try await client.listTools() + #expect(result.count == 1) + #expect(result[0].name == "add") + } + + let callToolTask = Task { + let result = try await client.callTool(name: "add", arguments: ["a": 1, "b": 2]) + #expect(result.isError == false) + #expect(result.content == [.text("3")]) + } - // #expect(result.serverInfo.name == "TestServer") - // #expect(result.serverInfo.version == "1.0.0") - // #expect(result.capabilities.prompts != nil) - // #expect(result.capabilities.tools != nil) - // #expect(result.protocolVersion == Version.latest) - // } - // try await withThrowingTaskGroup(of: Void.self) { group in - // group.addTask { - // try await Task.sleep(for: .seconds(1)) - // initTask.cancel() - // throw CancellationError() - // } - // group.addTask { - // try await initTask.value - // } - // try await group.next() - // group.cancelAll() - // } + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await Task.sleep(for: .seconds(1)) + listToolsTask.cancel() + throw CancellationError() + } + group.addTask { + try await callToolTask.value + } + try await group.next() + group.cancelAll() + } await server.stop() await client.disconnect()