Skip to content

Commit c81abff

Browse files
committed
Update Client.send to return Task instead of result
1 parent 82f4fd2 commit c81abff

File tree

1 file changed

+154
-31
lines changed

1 file changed

+154
-31
lines changed

Sources/MCP/Client/Client.swift

Lines changed: 154 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -252,31 +252,154 @@ public actor Client {
252252
// MARK: - Requests
253253

254254
/// Send a request and receive its response
255-
public func send<M: Method>(_ request: Request<M>) async throws -> M.Result {
256-
guard let connection = connection else {
257-
throw MCPError.internalError("Client connection not initialized")
258-
}
255+
///
256+
/// This method returns a cancellable task that represents the in-flight request.
257+
/// The task can be cancelled at any time, and can be combined with task groups
258+
/// to add timeout behavior.
259+
///
260+
/// ## Examples
261+
///
262+
/// ### Basic tool call
263+
///
264+
/// ```swift
265+
/// let searchResults = try await client.send(CallTool.request(.init(
266+
/// name: "search_web",
267+
/// arguments: ["query": .string("mcp")]
268+
/// ))).value
269+
/// print("Found \(searchResults.content.count) results")
270+
/// ```
271+
///
272+
/// ### Concurrent calculations
273+
///
274+
/// ```swift
275+
/// // Execute calculations concurrently
276+
/// async let firstResult = client.send(CallTool.request(.init(
277+
/// name: "calculate",
278+
/// arguments: ["expression": .string("1 + 1")]
279+
/// ))).value
280+
///
281+
/// async let secondResult = client.send(CallTool.request(.init(
282+
/// name: "calculate",
283+
/// arguments: ["expression": .string("2 + 2")]
284+
/// ))).value
285+
///
286+
/// // Wait for both results and combine them
287+
/// let (result1, result2) = try await (firstResult, secondResult)
288+
///
289+
/// // Extract numeric values from results
290+
/// if case .text(let text1) = result1.content.first, let num1 = Int(text1),
291+
/// case .text(let text2) = result2.content.first, let num2 = Int(text2) {
292+
/// let sum = num1 + num2
293+
/// print("Combined result: \(sum)") // 6
294+
/// }
295+
/// ```
296+
///
297+
/// ### Using TaskGroup for multiple media generations
298+
///
299+
/// ```swift
300+
/// try await withThrowingTaskGroup(of: (String, CallTool.Result?).self) { group in
301+
/// // Add image generation task
302+
/// group.addTask {
303+
/// let result = try await client.send(CallTool.request(.init(
304+
/// name: "generate_image",
305+
/// arguments: ["prompt": .string("sunset over mountains")]
306+
/// ))).value
307+
/// return ("image", result)
308+
/// }
309+
///
310+
/// // Add audio generation task
311+
/// group.addTask {
312+
/// let result = try await client.send(CallTool.request(.init(
313+
/// name: "generate_audio",
314+
/// arguments: ["text": .string("Welcome to the application")]
315+
/// ))).value
316+
/// return ("audio", result)
317+
/// }
318+
///
319+
/// // Add a timeout task that cancels the entire operation after 30 seconds
320+
/// group.addTask {
321+
/// do {
322+
/// try await Task.sleep(for: .seconds(30))
323+
///
324+
/// // Cancel all other tasks in the group
325+
/// group.cancelAll()
326+
///
327+
/// return ("timeout", nil)
328+
/// } catch {
329+
/// return ("timeout", nil)
330+
/// }
331+
/// }
332+
///
333+
/// // Process results as they complete
334+
/// for try await (type, result) in group {
335+
/// if type == "timeout" {
336+
/// print("Operation timed out after 30 seconds")
337+
/// continue
338+
/// }
339+
///
340+
/// guard let result = result else {
341+
/// print("\(type) generation failed")
342+
/// continue
343+
/// }
344+
///
345+
/// switch type {
346+
/// case "image":
347+
/// if case .image(let data, _, _) = result.content.first {
348+
/// print("Image generated: \(data.prefix(10))...")
349+
/// }
350+
/// case "audio":
351+
/// if case .resource(let uri, _, _) = result.content.first {
352+
/// print("Audio available at: \(uri)")
353+
/// }
354+
/// default:
355+
/// break
356+
/// }
357+
/// }
358+
/// }
359+
/// ```
360+
///
361+
/// - Parameters:
362+
/// - request: The request to send
363+
/// - priority: The priority of the task. Defaults to inheriting the current task's priority
364+
/// - Returns: A cancellable task that will complete with the result or throw an error
365+
public func send<M: Method>(
366+
_ request: Request<M>,
367+
priority: TaskPriority? = nil
368+
) -> Task<M.Result, Swift.Error> {
369+
Task(priority: priority) {
370+
guard let connection = connection else {
371+
throw MCPError.internalError("Client connection not initialized")
372+
}
259373

260-
// Use the actor's encoder
261-
let requestData = try encoder.encode(request)
374+
let requestData = try encoder.encode(request)
262375

263-
// Store the pending request first
264-
return try await withCheckedThrowingContinuation { continuation in
265-
Task {
266-
self.addPendingRequest(
267-
id: request.id,
268-
continuation: continuation,
269-
type: M.Result.self
270-
)
376+
// Check for task cancellation before proceeding
377+
try Task.checkCancellation()
271378

272-
// Send the request data
273-
do {
274-
// Use the existing connection send
275-
try await connection.send(requestData)
276-
} catch {
277-
// If send fails immediately, resume continuation and remove pending request
278-
continuation.resume(throwing: error)
279-
self.removePendingRequest(id: request.id) // Ensure cleanup on send error
379+
// Store the pending request
380+
return try await withCheckedThrowingContinuation { continuation in
381+
Task {
382+
do {
383+
// Add the pending request to our tracking dictionary
384+
self.addPendingRequest(
385+
id: request.id,
386+
continuation: continuation,
387+
type: M.Result.self
388+
)
389+
390+
// Send the request data
391+
try await connection.send(requestData)
392+
393+
// Check for cancellation after sending
394+
if Task.isCancelled {
395+
continuation.resume(throwing: CancellationError())
396+
self.removePendingRequest(id: request.id)
397+
}
398+
} catch {
399+
// If send fails immediately, resume continuation and remove pending request
400+
continuation.resume(throwing: error)
401+
self.removePendingRequest(id: request.id)
402+
}
280403
}
281404
}
282405
}
@@ -448,7 +571,7 @@ public actor Client {
448571
clientInfo: clientInfo
449572
))
450573

451-
let result = try await send(request)
574+
let result = try await send(request).value
452575

453576
self.serverCapabilities = result.capabilities
454577
self.serverVersion = result.protocolVersion
@@ -459,7 +582,7 @@ public actor Client {
459582

460583
public func ping() async throws {
461584
let request = Ping.request()
462-
_ = try await send(request)
585+
_ = try await send(request).value
463586
}
464587

465588
// MARK: - Prompts
@@ -469,7 +592,7 @@ public actor Client {
469592
{
470593
try validateServerCapability(\.prompts, "Prompts")
471594
let request = GetPrompt.request(.init(name: name, arguments: arguments))
472-
let result = try await send(request)
595+
let result = try await send(request).value
473596
return (description: result.description, messages: result.messages)
474597
}
475598

@@ -483,7 +606,7 @@ public actor Client {
483606
} else {
484607
request = ListPrompts.request(.init())
485608
}
486-
let result = try await send(request)
609+
let result = try await send(request).value
487610
return (prompts: result.prompts, nextCursor: result.nextCursor)
488611
}
489612

@@ -492,7 +615,7 @@ public actor Client {
492615
public func readResource(uri: String) async throws -> [Resource.Content] {
493616
try validateServerCapability(\.resources, "Resources")
494617
let request = ReadResource.request(.init(uri: uri))
495-
let result = try await send(request)
618+
let result = try await send(request).value
496619
return result.contents
497620
}
498621

@@ -506,14 +629,14 @@ public actor Client {
506629
} else {
507630
request = ListResources.request(.init())
508631
}
509-
let result = try await send(request)
632+
let result = try await send(request).value
510633
return (resources: result.resources, nextCursor: result.nextCursor)
511634
}
512635

513636
public func subscribeToResource(uri: String) async throws {
514637
try validateServerCapability(\.resources?.subscribe, "Resource subscription")
515638
let request = ResourceSubscribe.request(.init(uri: uri))
516-
_ = try await send(request)
639+
_ = try await send(request).value
517640
}
518641

519642
// MARK: - Tools
@@ -528,7 +651,7 @@ public actor Client {
528651
} else {
529652
request = ListTools.request(.init())
530653
}
531-
let result = try await send(request)
654+
let result = try await send(request).value
532655
return (tools: result.tools, nextCursor: result.nextCursor)
533656
}
534657

@@ -537,7 +660,7 @@ public actor Client {
537660
) {
538661
try validateServerCapability(\.tools, "Tools")
539662
let request = CallTool.request(.init(name: name, arguments: arguments))
540-
let result = try await send(request)
663+
let result = try await send(request).value
541664
return (content: result.content, isError: result.isError)
542665
}
543666

0 commit comments

Comments
 (0)