Skip to content

Commit 14bb443

Browse files
committed
Prevent double continuation resumption in client
Consolidate request removal and continuation resumption logic to ensure each request's continuation is resumed exactly once, preventing "SWIFT TASK CONTINUATION MISUSE" errors during network failures.
1 parent 87f33d0 commit 14bb443

File tree

2 files changed: +183 −34 lines changed

Sources/MCP/Client/Client.swift

Lines changed: 79 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -191,10 +191,8 @@ public actor Client {
191191
// Try decoding as a batch response first
192192
if let batchResponse = try? decoder.decode([AnyResponse].self, from: data) {
193193
await handleBatchResponse(batchResponse)
194-
} else if let response = try? decoder.decode(AnyResponse.self, from: data),
195-
let request = pendingRequests[response.id]
196-
{
197-
await handleResponse(response, for: request)
194+
} else if let response = try? decoder.decode(AnyResponse.self, from: data) {
195+
await handleResponse(response)
198196
} else if let message = try? decoder.decode(AnyMessage.self, from: data) {
199197
await handleMessage(message)
200198
} else {
@@ -217,23 +215,49 @@ public actor Client {
217215
break
218216
}
219217
} while true
218+
await self.logger?.info("Client message handling loop task is terminating.")
220219
}
221220
}
222221

223222
/// Disconnect the client and cancel all pending requests
224223
public func disconnect() async {
225-
// Cancel all pending requests
226-
for (id, request) in pendingRequests {
224+
await logger?.info("Initiating client disconnect...")
225+
226+
// Part 1: Inside actor - Grab state and clear internal references
227+
let taskToCancel = self.task
228+
let connectionToDisconnect = self.connection
229+
let pendingRequestsToCancel = self.pendingRequests
230+
231+
self.task = nil
232+
self.connection = nil
233+
self.pendingRequests = [:] // Use empty dictionary literal
234+
235+
// Part 2: Outside actor - Resume continuations, disconnect transport, await task
236+
237+
// Resume continuations first
238+
for (_, request) in pendingRequestsToCancel {
227239
request.resume(throwing: MCPError.internalError("Client disconnected"))
228-
pendingRequests.removeValue(forKey: id)
229240
}
241+
await logger?.info("Pending requests cancelled.")
230242

231-
task?.cancel()
232-
task = nil
233-
if let connection = connection {
234-
await connection.disconnect()
243+
// Cancel the task
244+
taskToCancel?.cancel()
245+
await logger?.info("Message loop task cancellation requested.")
246+
247+
// Disconnect the transport *before* awaiting the task
248+
// This should ensure the transport stream is finished, unblocking the loop.
249+
if let conn = connectionToDisconnect {
250+
await conn.disconnect()
251+
await logger?.info("Transport disconnected.")
252+
} else {
253+
await logger?.info("No active transport connection to disconnect.")
235254
}
236-
connection = nil
255+
256+
// Await the task completion *after* transport disconnect
257+
_ = await taskToCancel?.value
258+
await logger?.info("Client message loop task finished.")
259+
260+
await logger?.info("Client disconnect complete.")
237261
}
238262

239263
// MARK: - Registration
@@ -267,12 +291,12 @@ public actor Client {
267291
throw MCPError.internalError("Client connection not initialized")
268292
}
269293

270-
// Use the actor's encoder
271294
let requestData = try encoder.encode(request)
272295

273296
// Store the pending request first
274297
return try await withCheckedThrowingContinuation { continuation in
275298
Task {
299+
// Add the pending request before attempting to send
276300
self.addPendingRequest(
277301
id: request.id,
278302
continuation: continuation,
@@ -284,9 +308,15 @@ public actor Client {
284308
// Use the existing connection send
285309
try await connection.send(requestData)
286310
} catch {
287-
// If send fails immediately, resume continuation and remove pending request
288-
continuation.resume(throwing: error)
289-
self.removePendingRequest(id: request.id) // Ensure cleanup on send error
311+
// If send fails, try to remove the pending request.
312+
// Resume with the send error only if we successfully removed the request,
313+
// indicating the response handler hasn't processed it yet.
314+
if self.removePendingRequest(id: request.id) != nil {
315+
continuation.resume(throwing: error)
316+
}
317+
// Otherwise, the request was already removed by the response handler
318+
// or by disconnect, so the continuation was already resumed.
319+
// Do nothing here.
290320
}
291321
}
292322
}
@@ -300,8 +330,8 @@ public actor Client {
300330
pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation))
301331
}
302332

303-
private func removePendingRequest(id: ID) {
304-
pendingRequests.removeValue(forKey: id)
333+
private func removePendingRequest(id: ID) -> AnyPendingRequest? {
334+
return pendingRequests.removeValue(forKey: id)
305335
}
306336

307337
// MARK: - Batching
@@ -555,21 +585,29 @@ public actor Client {
555585

556586
// MARK: -
557587

558-
private func handleResponse(_ response: Response<AnyMethod>, for request: AnyPendingRequest)
559-
async
560-
{
588+
private func handleResponse(_ response: Response<AnyMethod>) async {
561589
await logger?.debug(
562590
"Processing response",
563591
metadata: ["id": "\(response.id)"])
564592

565-
switch response.result {
566-
case .success(let value):
567-
request.resume(returning: value)
568-
case .failure(let error):
569-
request.resume(throwing: error)
593+
// Attempt to remove the pending request using the response ID.
594+
// Resume with the response only if it hadn't yet been removed.
595+
if let removedRequest = self.removePendingRequest(id: response.id) {
596+
// If we successfully removed it, resume its continuation.
597+
switch response.result {
598+
case .success(let value):
599+
removedRequest.resume(returning: value)
600+
case .failure(let error):
601+
removedRequest.resume(throwing: error)
602+
}
603+
} else {
604+
// Request was already removed (e.g., by send error handler or disconnect).
605+
// Log this, but it's not an error in race condition scenarios.
606+
await logger?.warning(
607+
"Attempted to handle response for already removed request",
608+
metadata: ["id": "\(response.id)"]
609+
)
570610
}
571-
572-
removePendingRequest(id: response.id)
573611
}
574612

575613
private func handleMessage(_ message: Message<AnyNotification>) async {
@@ -619,14 +657,21 @@ public actor Client {
619657
private func handleBatchResponse(_ responses: [AnyResponse]) async {
620658
await logger?.debug("Processing batch response", metadata: ["count": "\(responses.count)"])
621659
for response in responses {
622-
// Look up the pending request for this specific ID within the batch
623-
if let request = pendingRequests[response.id] {
624-
// Reuse the existing single response handler logic
625-
await handleResponse(response, for: request)
660+
// Attempt to remove the pending request.
661+
// If successful, pendingRequest contains the request.
662+
if let pendingRequest = self.removePendingRequest(id: response.id) {
663+
// If we successfully removed it, handle the response using the pending request.
664+
switch response.result {
665+
case .success(let value):
666+
pendingRequest.resume(returning: value)
667+
case .failure(let error):
668+
pendingRequest.resume(throwing: error)
669+
}
626670
} else {
627-
// Log if a response ID doesn't match any pending request
671+
// If removal failed, it means the request ID was not found (or already handled).
672+
// Log a warning.
628673
await logger?.warning(
629-
"Received response in batch for unknown request ID",
674+
"Received response in batch for unknown or already handled request ID",
630675
metadata: ["id": "\(response.id)"]
631676
)
632677
}

Tests/MCPTests/ClientTests.swift

Lines changed: 104 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -479,4 +479,108 @@ struct ClientTests {
479479

480480
await client.disconnect()
481481
}
482+
483+
@Test("Race condition between send error and response")
484+
func testSendErrorResponseRace() async throws {
485+
let transport = MockTransport()
486+
let client = Client(name: "TestClient", version: "1.0")
487+
488+
try await client.connect(transport: transport)
489+
try await Task.sleep(for: .milliseconds(10))
490+
491+
// Create a ping request
492+
let request = Ping.request()
493+
494+
// Create a task that will send the request
495+
let sendTask = Task {
496+
try await client.ping()
497+
}
498+
499+
// Give it a moment to send the request
500+
try await Task.sleep(for: .milliseconds(10))
501+
502+
// Verify the request was sent
503+
#expect(await transport.sentMessages.count == 1)
504+
505+
// Simulate a network error during send
506+
await transport.setFailSend(true)
507+
508+
// Create a response for the request
509+
let response = Response<Ping>(id: request.id, result: .init())
510+
let anyResponse = try AnyResponse(response)
511+
512+
// Queue the response
513+
try await transport.queue(response: anyResponse)
514+
515+
// Wait for the send task to complete
516+
do {
517+
_ = try await sendTask.value
518+
#expect(Bool(false), "Expected send to fail")
519+
} catch let error as MCPError {
520+
if case .transportError = error {
521+
#expect(Bool(true))
522+
} else {
523+
#expect(Bool(false), "Expected transport error, got \(error)")
524+
}
525+
} catch {
526+
#expect(Bool(false), "Expected MCPError, got \(error)")
527+
}
528+
529+
// Verify no continuation misuse occurred
530+
// (If it did, the test would have crashed)
531+
532+
// await client.disconnect()
533+
}
534+
535+
@Test("Race condition between response and send error")
536+
func testResponseSendErrorRace() async throws {
537+
let transport = MockTransport()
538+
let client = Client(name: "TestClient", version: "1.0")
539+
540+
try await client.connect(transport: transport)
541+
try await Task.sleep(for: .milliseconds(10))
542+
543+
// Create a ping request
544+
let request = Ping.request()
545+
546+
// Create a response for the request
547+
let response = Response<Ping>(id: request.id, result: .init())
548+
let anyResponse = try AnyResponse(response)
549+
550+
// Queue the response before sending the request
551+
try await transport.queue(response: anyResponse)
552+
553+
// Create a task that will send the request
554+
let sendTask = Task {
555+
try await client.ping()
556+
}
557+
558+
// Give it a moment to send the request
559+
try await Task.sleep(for: .milliseconds(10))
560+
561+
// Verify the request was sent
562+
#expect(await transport.sentMessages.count == 1)
563+
564+
// Simulate a network error during send
565+
await transport.setFailSend(true)
566+
567+
// Wait for the send task to complete
568+
do {
569+
_ = try await sendTask.value
570+
#expect(Bool(false), "Expected send to fail")
571+
} catch let error as MCPError {
572+
if case .transportError = error {
573+
#expect(Bool(true))
574+
} else {
575+
#expect(Bool(false), "Expected transport error, got \(error)")
576+
}
577+
} catch {
578+
#expect(Bool(false), "Expected MCPError, got \(error)")
579+
}
580+
581+
// Verify no continuation misuse occurred
582+
// (If it did, the test would have crashed)
583+
584+
await client.disconnect()
585+
}
482586
}

0 commit comments

Comments
 (0)