Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 154 additions & 31 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -252,31 +252,154 @@ public actor Client {
// MARK: - Requests

/// Send a request and receive its response
public func send<M: Method>(_ request: Request<M>) async throws -> M.Result {
guard let connection = connection else {
throw MCPError.internalError("Client connection not initialized")
}
///
/// This method returns a cancellable task that represents the in-flight request.
/// The task can be cancelled at any time, and can be combined with task groups
/// to add timeout behavior.
///
/// ## Examples
///
/// ### Basic tool call
///
/// ```swift
/// let searchResults = try await client.send(CallTool.request(.init(
/// name: "search_web",
/// arguments: ["query": .string("mcp")]
/// ))).value
/// print("Found \(searchResults.content.count) results")
/// ```
///
/// ### Concurrent calculations
///
/// ```swift
/// // Execute calculations concurrently
/// async let firstResult = client.send(CallTool.request(.init(
/// name: "calculate",
/// arguments: ["expression": .string("1 + 1")]
/// ))).value
///
/// async let secondResult = client.send(CallTool.request(.init(
/// name: "calculate",
/// arguments: ["expression": .string("2 + 2")]
/// ))).value
///
/// // Wait for both results and combine them
/// let (result1, result2) = try await (firstResult, secondResult)
///
/// // Extract numeric values from results
/// if case .text(let text1) = result1.content.first, let num1 = Int(text1),
/// case .text(let text2) = result2.content.first, let num2 = Int(text2) {
/// let sum = num1 + num2
/// print("Combined result: \(sum)") // 6
/// }
/// ```
///
/// ### Using TaskGroup for multiple media generations
///
/// ```swift
/// try await withThrowingTaskGroup(of: (String, CallTool.Result?).self) { group in
/// // Add image generation task
/// group.addTask {
/// let result = try await client.send(CallTool.request(.init(
/// name: "generate_image",
/// arguments: ["prompt": .string("sunset over mountains")]
/// ))).value
/// return ("image", result)
/// }
///
/// // Add audio generation task
/// group.addTask {
/// let result = try await client.send(CallTool.request(.init(
/// name: "generate_audio",
/// arguments: ["text": .string("Welcome to the application")]
/// ))).value
/// return ("audio", result)
/// }
///
/// // Add a timeout task that cancels the entire operation after 30 seconds
/// group.addTask {
/// do {
/// try await Task.sleep(for: .seconds(30))
///
/// // Cancel all other tasks in the group
/// group.cancelAll()
///
/// return ("timeout", nil)
/// } catch {
/// return ("timeout", nil)
/// }
/// }
///
/// // Process results as they complete
/// for try await (type, result) in group {
/// if type == "timeout" {
/// print("Operation timed out after 30 seconds")
/// continue
/// }
///
/// guard let result = result else {
/// print("\(type) generation failed")
/// continue
/// }
///
/// switch type {
/// case "image":
/// if case .image(let data, _, _) = result.content.first {
/// print("Image generated: \(data.prefix(10))...")
/// }
/// case "audio":
/// if case .resource(let uri, _, _) = result.content.first {
/// print("Audio available at: \(uri)")
/// }
/// default:
/// break
/// }
/// }
/// }
/// ```
///
/// - Parameters:
/// - request: The request to send
/// - priority: The priority of the task. Defaults to inheriting the current task's priority
/// - Returns: A cancellable task that will complete with the result or throw an error
public func send<M: Method>(
_ request: Request<M>,
priority: TaskPriority? = nil
) -> Task<M.Result, Swift.Error> {
Task(priority: priority) {
guard let connection = connection else {
throw MCPError.internalError("Client connection not initialized")
}

// Use the actor's encoder
let requestData = try encoder.encode(request)
let requestData = try encoder.encode(request)

// Store the pending request first
return try await withCheckedThrowingContinuation { continuation in
Task {
self.addPendingRequest(
id: request.id,
continuation: continuation,
type: M.Result.self
)
// Check for task cancellation before proceeding
try Task.checkCancellation()

// Send the request data
do {
// Use the existing connection send
try await connection.send(requestData)
} catch {
// If send fails immediately, resume continuation and remove pending request
continuation.resume(throwing: error)
self.removePendingRequest(id: request.id) // Ensure cleanup on send error
// Store the pending request
return try await withCheckedThrowingContinuation { continuation in
Task {
do {
// Add the pending request to our tracking dictionary
self.addPendingRequest(
id: request.id,
continuation: continuation,
type: M.Result.self
)

// Send the request data
try await connection.send(requestData)

// Check for cancellation after sending
if Task.isCancelled {
continuation.resume(throwing: CancellationError())
self.removePendingRequest(id: request.id)
}
} catch {
// If send fails immediately, resume continuation and remove pending request
continuation.resume(throwing: error)
self.removePendingRequest(id: request.id)
}
}
}
}
Expand Down Expand Up @@ -448,7 +571,7 @@ public actor Client {
clientInfo: clientInfo
))

let result = try await send(request)
let result = try await send(request).value

self.serverCapabilities = result.capabilities
self.serverVersion = result.protocolVersion
Expand All @@ -459,7 +582,7 @@ public actor Client {

public func ping() async throws {
let request = Ping.request()
_ = try await send(request)
_ = try await send(request).value
}

// MARK: - Prompts
Expand All @@ -469,7 +592,7 @@ public actor Client {
{
try validateServerCapability(\.prompts, "Prompts")
let request = GetPrompt.request(.init(name: name, arguments: arguments))
let result = try await send(request)
let result = try await send(request).value
return (description: result.description, messages: result.messages)
}

Expand All @@ -483,7 +606,7 @@ public actor Client {
} else {
request = ListPrompts.request(.init())
}
let result = try await send(request)
let result = try await send(request).value
return (prompts: result.prompts, nextCursor: result.nextCursor)
}

Expand All @@ -492,7 +615,7 @@ public actor Client {
public func readResource(uri: String) async throws -> [Resource.Content] {
try validateServerCapability(\.resources, "Resources")
let request = ReadResource.request(.init(uri: uri))
let result = try await send(request)
let result = try await send(request).value
return result.contents
}

Expand All @@ -506,14 +629,14 @@ public actor Client {
} else {
request = ListResources.request(.init())
}
let result = try await send(request)
let result = try await send(request).value
return (resources: result.resources, nextCursor: result.nextCursor)
}

public func subscribeToResource(uri: String) async throws {
try validateServerCapability(\.resources?.subscribe, "Resource subscription")
let request = ResourceSubscribe.request(.init(uri: uri))
_ = try await send(request)
_ = try await send(request).value
}

// MARK: - Tools
Expand All @@ -528,7 +651,7 @@ public actor Client {
} else {
request = ListTools.request(.init())
}
let result = try await send(request)
let result = try await send(request).value
return (tools: result.tools, nextCursor: result.nextCursor)
}

Expand All @@ -537,7 +660,7 @@ public actor Client {
) {
try validateServerCapability(\.tools, "Tools")
let request = CallTool.request(.init(name: name, arguments: arguments))
let result = try await send(request)
let result = try await send(request).value
return (content: result.content, isError: result.isError)
}

Expand Down