diff --git a/Sources/MCP/Base/Error.swift b/Sources/MCP/Base/Error.swift index f1d4dcc3..ee2f768b 100644 --- a/Sources/MCP/Base/Error.swift +++ b/Sources/MCP/Base/Error.swift @@ -21,6 +21,7 @@ public enum MCPError: Swift.Error, Sendable { // Transport specific errors case connectionClosed case transportError(Swift.Error) + case requestTimedOut(String?) /// The JSON-RPC 2.0 error code public var code: Int { @@ -33,6 +34,7 @@ public enum MCPError: Swift.Error, Sendable { case .serverError(let code, _): return code case .connectionClosed: return -32000 case .transportError: return -32001 + case .requestTimedOut: return -32002 } } @@ -72,6 +74,8 @@ extension MCPError: LocalizedError { return "Connection closed" case .transportError(let error): return "Transport error: \(error.localizedDescription)" + case .requestTimedOut(let detail): + return "Request timed out" + (detail.map { ": \($0)" } ?? "") } } @@ -93,6 +97,8 @@ extension MCPError: LocalizedError { return "The connection to the server was closed" case .transportError(let error): return (error as? LocalizedError)?.failureReason ?? error.localizedDescription + case .requestTimedOut: + return "Request exceeded the client-side timeout duration, default time is 10 seconds" } } @@ -108,6 +114,8 @@ extension MCPError: LocalizedError { return "Verify the parameters match the method's expected parameters" case .connectionClosed: return "Try reconnecting to the server" + case .requestTimedOut: + return "Try sending the request again, or increase the timeout if necessary" default: return nil } @@ -147,7 +155,8 @@ extension MCPError: Codable { .invalidRequest(let detail), .methodNotFound(let detail), .invalidParams(let detail), - .internalError(let detail): + .internalError(let detail), + .requestTimedOut(let detail): if let detail = detail { try container.encode(["detail": detail], forKey: .data) } @@ -204,6 +213,8 @@ extension MCPError: Codable { userInfo: [NSLocalizedDescriptionKey: underlyingErrorString] ) ) + case -32002: + self = .requestTimedOut(unwrapDetail(message)) default: self = .serverError(code: code, message: message) } @@ -240,6 +251,8 @@ extension MCPError: Hashable { break case .transportError(let error): hasher.combine(error.localizedDescription) + case .requestTimedOut(let detail): + hasher.combine(detail) } } } diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index fa4fbe09..d27134a5 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -252,7 +252,9 @@ public actor Client { // MARK: - Requests /// Send a request and receive its response - public func send(_ request: Request) async throws -> M.Result { + public func send(_ request: Request, timeout: Duration = .seconds(10.0)) + async throws -> M.Result + { guard let connection = connection else { throw MCPError.internalError("Client connection not initialized") } @@ -262,21 +264,46 @@ public actor Client { // Store the pending request first return try await withCheckedThrowingContinuation { continuation in - Task { - self.addPendingRequest( - id: request.id, - continuation: continuation, - type: M.Result.self - ) + self.addPendingRequest( + id: request.id, + continuation: continuation, + type: M.Result.self + ) + + // Send the request data + var sendRequestTask: Task? = nil - // Send the request data + // A timeout task is created to remove a request if it is still pending after time out duration + var timeoutTask: Task? = nil + + sendRequestTask = Task { do { // Use the existing connection send try await connection.send(requestData) } catch { - // If send fails immediately, resume continuation and remove pending request - continuation.resume(throwing: error) + // If send fails immediately, remove pending request and cancel timeout task self.removePendingRequest(id: request.id) // Ensure cleanup on send error + timeoutTask?.cancel() + continuation.resume(throwing: error) + } + } + + timeoutTask = Task { + do { + try await Task.sleep(until: .now + timeout) + + // If timed out, remove pending request and cancel send request task + if self.pendingRequests.keys.contains(request.id) { + self.removePendingRequest(id: request.id) // Ensure cleanup on send error + sendRequestTask?.cancel() + continuation.resume( + throwing: MCPError.requestTimedOut( + "Request timed out after \(timeout)" + ) + ) + } + } catch { + // Do nothing here if the task is cancaled } } } @@ -457,9 +484,13 @@ public actor Client { return result } - public func ping() async throws { + public func ping(timeout: Duration? = nil) async throws { let request = Ping.request() - _ = try await send(request) + if let timeout { + _ = try await send(request, timeout: timeout) + } else { + _ = try await send(request) + } } // MARK: - Prompts diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index e7e2c023..bdb9778b 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -363,4 +363,107 @@ struct ClientTests { await client.disconnect() } + + @Test("Request timeout - request should time out if server does not respond") + func testRequestTimesOut() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + do { + // Do not queue any response on the transport + // so the client never receives a response + try await client.ping(timeout: .milliseconds(100)) + #expect(Bool(false), "Expected request to time out, but it succeeded") + + } catch let error as MCPError { + switch error { + case .requestTimedOut(let detail): + // This is the expected error + #expect(Bool(true), "Got requestTimedOut as expected: \(detail ?? "")") + default: + // If it is a different MCPError, fail + #expect(Bool(false), "Expected requestTimedOut, got \(error)") + } + } catch { + #expect(Bool(false), "Expected an MCPError, but got \(error)") + } + + await client.disconnect() + } + + @Test("Request timeout - request should time out if server responds too late") + func testRequestTimesOutIfResponseIsLate() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + + // Prepare a ping which will be sent with a short 100ms timeout + let request = Ping.request() + + // Prepare a task to queue a response after 200ms + // which is beyond the 100ms timeout + Task { + try? await Task.sleep(for: .milliseconds(200)) + let response = Response(id: request.id, result: .init()) + let anyResponse = try? AnyResponse(response) + if let anyResponse { + try? await transport.queue(response: anyResponse) + } else { + #expect(Bool(false), "Failed to produce any response") + } + } + + do { + try await _ = client.send(request, timeout: .milliseconds(100)) + #expect(Bool(false), "Expected request to time out, but it succeeded") + } catch let error as MCPError { + switch error { + case .requestTimedOut(let detail): + // This is the expected error + #expect(Bool(true), "Got requestTimedOut as expected: \(detail ?? "")") + default: + // If it is a different MCPError, fail + #expect(Bool(false), "Expected requestTimedOut, got \(error)") + } + } catch { + #expect(Bool(false), "Expected an MCPError, but got \(error)") + } + + await client.disconnect() + } + + @Test("Request timeout - request should succeed if server responds before timeout") + func testRequestDoesNotTimeOutIfResponseIsFast() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + + // Prepare a ping which will be sent with a short 200ms timeout + let request = Ping.request() + + // Prepare a task to queue a response after 100ms + // which is less than the specified timeout (200ms) + Task { + try? await Task.sleep(for: .milliseconds(100)) + let response = Response(id: request.id, result: .init()) + let anyResponse = try? AnyResponse(response) + if let anyResponse { + try? await transport.queue(response: anyResponse) + } else { + #expect(Bool(false), "Failed to produce any response") + } + } + + do { + _ = try await client.send(request, timeout: .milliseconds(200)) + #expect(Bool(true), "Request succeeded before timeout") + } catch let error { + #expect(Bool(false), "Did not expect an error here, but got \(error)") + } + + await client.disconnect() + } }