diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index fa4fbe09..d22dc42a 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -252,31 +252,154 @@ public actor Client { // MARK: - Requests /// Send a request and receive its response - public func send(_ request: Request) async throws -> M.Result { - guard let connection = connection else { - throw MCPError.internalError("Client connection not initialized") - } + /// + /// This method returns a cancellable task that represents the in-flight request. + /// The task can be cancelled at any time, and can be combined with task groups + /// to add timeout behavior. + /// + /// ## Examples + /// + /// ### Basic tool call + /// + /// ```swift + /// let searchResults = try await client.send(CallTool.request(.init( + /// name: "search_web", + /// arguments: ["query": .string("mcp")] + /// ))).value + /// print("Found \(searchResults.content.count) results") + /// ``` + /// + /// ### Concurrent calculations + /// + /// ```swift + /// // Execute calculations concurrently + /// async let firstResult = client.send(CallTool.request(.init( + /// name: "calculate", + /// arguments: ["expression": .string("1 + 1")] + /// ))).value + /// + /// async let secondResult = client.send(CallTool.request(.init( + /// name: "calculate", + /// arguments: ["expression": .string("2 + 2")] + /// ))).value + /// + /// // Wait for both results and combine them + /// let (result1, result2) = try await (firstResult, secondResult) + /// + /// // Extract numeric values from results + /// if case .text(let text1) = result1.content.first, let num1 = Int(text1), + /// case .text(let text2) = result2.content.first, let num2 = Int(text2) { + /// let sum = num1 + num2 + /// print("Combined result: \(sum)") // 6 + /// } + /// ``` + /// + /// ### Using TaskGroup for multiple media generations + /// + /// ```swift + /// try await withThrowingTaskGroup(of: (String, CallTool.Result?).self) { group in + /// // Add image generation task + /// group.addTask { + /// let result = try await client.send(CallTool.request(.init( + /// name: "generate_image", + /// arguments: ["prompt": .string("sunset over mountains")] + /// ))).value + /// return ("image", result) + /// } + /// + /// // Add audio generation task + /// group.addTask { + /// let result = try await client.send(CallTool.request(.init( + /// name: "generate_audio", + /// arguments: ["text": .string("Welcome to the application")] + /// ))).value + /// return ("audio", result) + /// } + /// + /// // Add a timeout task that cancels the entire operation after 30 seconds + /// group.addTask { + /// do { + /// try await Task.sleep(for: .seconds(30)) + /// + /// // Cancel all other tasks in the group + /// group.cancelAll() + /// + /// return ("timeout", nil) + /// } catch { + /// return ("timeout", nil) + /// } + /// } + /// + /// // Process results as they complete + /// for try await (type, result) in group { + /// if type == "timeout" { + /// print("Operation timed out after 30 seconds") + /// continue + /// } + /// + /// guard let result = result else { + /// print("\(type) generation failed") + /// continue + /// } + /// + /// switch type { + /// case "image": + /// if case .image(let data, _, _) = result.content.first { + /// print("Image generated: \(data.prefix(10))...") + /// } + /// case "audio": + /// if case .resource(let uri, _, _) = result.content.first { + /// print("Audio available at: \(uri)") + /// } + /// default: + /// break + /// } + /// } + /// } + /// ``` + /// + /// - Parameters: + /// - request: The request to send + /// - priority: The priority of the task. Defaults to inheriting the current task's priority + /// - Returns: A cancellable task that will complete with the result or throw an error + public func send( + _ request: Request, + priority: TaskPriority? = nil + ) -> Task { + Task(priority: priority) { + guard let connection = connection else { + throw MCPError.internalError("Client connection not initialized") + } - // Use the actor's encoder - let requestData = try encoder.encode(request) + let requestData = try encoder.encode(request) - // Store the pending request first - return try await withCheckedThrowingContinuation { continuation in - Task { - self.addPendingRequest( - id: request.id, - continuation: continuation, - type: M.Result.self - ) + // Check for task cancellation before proceeding + try Task.checkCancellation() - // Send the request data - do { - // Use the existing connection send - try await connection.send(requestData) - } catch { - // If send fails immediately, resume continuation and remove pending request - continuation.resume(throwing: error) - self.removePendingRequest(id: request.id) // Ensure cleanup on send error + // Store the pending request + return try await withCheckedThrowingContinuation { continuation in + Task { + do { + // Add the pending request to our tracking dictionary + self.addPendingRequest( + id: request.id, + continuation: continuation, + type: M.Result.self + ) + + // Send the request data + try await connection.send(requestData) + + // Check for cancellation after sending + if Task.isCancelled { + continuation.resume(throwing: CancellationError()) + self.removePendingRequest(id: request.id) + } + } catch { + // If send fails immediately, resume continuation and remove pending request + continuation.resume(throwing: error) + self.removePendingRequest(id: request.id) + } } } } @@ -448,7 +571,7 @@ public actor Client { clientInfo: clientInfo )) - let result = try await send(request) + let result = try await send(request).value self.serverCapabilities = result.capabilities self.serverVersion = result.protocolVersion @@ -459,7 +582,7 @@ public actor Client { public func ping() async throws { let request = Ping.request() - _ = try await send(request) + _ = try await send(request).value } // MARK: - Prompts @@ -469,7 +592,7 @@ public actor Client { { try validateServerCapability(\.prompts, "Prompts") let request = GetPrompt.request(.init(name: name, arguments: arguments)) - let result = try await send(request) + let result = try await send(request).value return (description: result.description, messages: result.messages) } @@ -483,7 +606,7 @@ public actor Client { } else { request = ListPrompts.request(.init()) } - let result = try await send(request) + let result = try await send(request).value return (prompts: result.prompts, nextCursor: result.nextCursor) } @@ -492,7 +615,7 @@ public actor Client { public func readResource(uri: String) async throws -> [Resource.Content] { try validateServerCapability(\.resources, "Resources") let request = ReadResource.request(.init(uri: uri)) - let result = try await send(request) + let result = try await send(request).value return result.contents } @@ -506,14 +629,14 @@ public actor Client { } else { request = ListResources.request(.init()) } - let result = try await send(request) + let result = try await send(request).value return (resources: result.resources, nextCursor: result.nextCursor) } public func subscribeToResource(uri: String) async throws { try validateServerCapability(\.resources?.subscribe, "Resource subscription") let request = ResourceSubscribe.request(.init(uri: uri)) - _ = try await send(request) + _ = try await send(request).value } // MARK: - Tools @@ -528,7 +651,7 @@ public actor Client { } else { request = ListTools.request(.init()) } - let result = try await send(request) + let result = try await send(request).value return (tools: result.tools, nextCursor: result.nextCursor) } @@ -537,7 +660,7 @@ public actor Client { ) { try validateServerCapability(\.tools, "Tools") let request = CallTool.request(.init(name: name, arguments: arguments)) - let result = try await send(request) + let result = try await send(request).value return (content: result.content, isError: result.isError) }