From 14bb443f4e1d82d3bfc1efef6252b70b9c8e1e39 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 12 May 2025 03:26:59 -0700 Subject: [PATCH 1/2] Prevent double continuation resumption in client Consolidate request removal and continuation resumption logic to ensure each request's continuation is resumed exactly once, preventing "SWIFT TASK CONTINUATION MISUSE" errors during network failures. --- Sources/MCP/Client/Client.swift | 113 +++++++++++++++++++++---------- Tests/MCPTests/ClientTests.swift | 104 ++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+), 34 deletions(-) diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index f6ec77a3..f45b4b96 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -191,10 +191,8 @@ public actor Client { // Try decoding as a batch response first if let batchResponse = try? decoder.decode([AnyResponse].self, from: data) { await handleBatchResponse(batchResponse) - } else if let response = try? decoder.decode(AnyResponse.self, from: data), - let request = pendingRequests[response.id] - { - await handleResponse(response, for: request) + } else if let response = try? decoder.decode(AnyResponse.self, from: data) { + await handleResponse(response) } else if let message = try? decoder.decode(AnyMessage.self, from: data) { await handleMessage(message) } else { @@ -217,23 +215,49 @@ public actor Client { break } } while true + await self.logger?.info("Client message handling loop task is terminating.") } } /// Disconnect the client and cancel all pending requests public func disconnect() async { - // Cancel all pending requests - for (id, request) in pendingRequests { + await logger?.info("Initiating client disconnect...") + + // Part 1: Inside actor - Grab state and clear internal references + let taskToCancel = self.task + let connectionToDisconnect = self.connection + let pendingRequestsToCancel = self.pendingRequests + + self.task = nil + self.connection = nil + self.pendingRequests = [:] // Use empty dictionary literal + + // Part 2: Outside actor - Resume continuations, disconnect transport, await task + + // Resume continuations first + for (_, request) in pendingRequestsToCancel { request.resume(throwing: MCPError.internalError("Client disconnected")) - pendingRequests.removeValue(forKey: id) } + await logger?.info("Pending requests cancelled.") - task?.cancel() - task = nil - if let connection = connection { - await connection.disconnect() + // Cancel the task + taskToCancel?.cancel() + await logger?.info("Message loop task cancellation requested.") + + // Disconnect the transport *before* awaiting the task + // This should ensure the transport stream is finished, unblocking the loop. + if let conn = connectionToDisconnect { + await conn.disconnect() + await logger?.info("Transport disconnected.") + } else { + await logger?.info("No active transport connection to disconnect.") } - connection = nil + + // Await the task completion *after* transport disconnect + _ = await taskToCancel?.value + await logger?.info("Client message loop task finished.") + + await logger?.info("Client disconnect complete.") } // MARK: - Registration @@ -267,12 +291,12 @@ public actor Client { throw MCPError.internalError("Client connection not initialized") } - // Use the actor's encoder let requestData = try encoder.encode(request) // Store the pending request first return try await withCheckedThrowingContinuation { continuation in Task { + // Add the pending request before attempting to send self.addPendingRequest( id: request.id, continuation: continuation, @@ -284,9 +308,15 @@ public actor Client { // Use the existing connection send try await connection.send(requestData) } catch { - // If send fails immediately, resume continuation and remove pending request - continuation.resume(throwing: error) - self.removePendingRequest(id: request.id) // Ensure cleanup on send error + // If send fails, try to remove the pending request. + // Resume with the send error only if we successfully removed the request, + // indicating the response handler hasn't processed it yet. + if self.removePendingRequest(id: request.id) != nil { + continuation.resume(throwing: error) + } + // Otherwise, the request was already removed by the response handler + // or by disconnect, so the continuation was already resumed. + // Do nothing here. } } } @@ -300,8 +330,8 @@ public actor Client { pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation)) } - private func removePendingRequest(id: ID) { - pendingRequests.removeValue(forKey: id) + private func removePendingRequest(id: ID) -> AnyPendingRequest? { + return pendingRequests.removeValue(forKey: id) } // MARK: - Batching @@ -555,21 +585,29 @@ public actor Client { // MARK: - - private func handleResponse(_ response: Response, for request: AnyPendingRequest) - async - { + private func handleResponse(_ response: Response) async { await logger?.debug( "Processing response", metadata: ["id": "\(response.id)"]) - switch response.result { - case .success(let value): - request.resume(returning: value) - case .failure(let error): - request.resume(throwing: error) + // Attempt to remove the pending request using the response ID. + // Resume with the response only if it hadn't yet been removed. + if let removedRequest = self.removePendingRequest(id: response.id) { + // If we successfully removed it, resume its continuation. + switch response.result { + case .success(let value): + removedRequest.resume(returning: value) + case .failure(let error): + removedRequest.resume(throwing: error) + } + } else { + // Request was already removed (e.g., by send error handler or disconnect). + // Log this, but it's not an error in race condition scenarios. + await logger?.warning( + "Attempted to handle response for already removed request", + metadata: ["id": "\(response.id)"] + ) } - - removePendingRequest(id: response.id) } private func handleMessage(_ message: Message) async { @@ -619,14 +657,21 @@ public actor Client { private func handleBatchResponse(_ responses: [AnyResponse]) async { await logger?.debug("Processing batch response", metadata: ["count": "\(responses.count)"]) for response in responses { - // Look up the pending request for this specific ID within the batch - if let request = pendingRequests[response.id] { - // Reuse the existing single response handler logic - await handleResponse(response, for: request) + // Attempt to remove the pending request. + // If successful, pendingRequest contains the request. + if let pendingRequest = self.removePendingRequest(id: response.id) { + // If we successfully removed it, handle the response using the pending request. + switch response.result { + case .success(let value): + pendingRequest.resume(returning: value) + case .failure(let error): + pendingRequest.resume(throwing: error) + } } else { - // Log if a response ID doesn't match any pending request + // If removal failed, it means the request ID was not found (or already handled). + // Log a warning. await logger?.warning( - "Received response in batch for unknown request ID", + "Received response in batch for unknown or already handled request ID", metadata: ["id": "\(response.id)"] ) } diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index cfbedeb1..45345a04 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -479,4 +479,108 @@ struct ClientTests { await client.disconnect() } + + @Test("Race condition between send error and response") + func testSendErrorResponseRace() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) + + // Create a ping request + let request = Ping.request() + + // Create a task that will send the request + let sendTask = Task { + try await client.ping() + } + + // Give it a moment to send the request + try await Task.sleep(for: .milliseconds(10)) + + // Verify the request was sent + #expect(await transport.sentMessages.count == 1) + + // Simulate a network error during send + await transport.setFailSend(true) + + // Create a response for the request + let response = Response(id: request.id, result: .init()) + let anyResponse = try AnyResponse(response) + + // Queue the response + try await transport.queue(response: anyResponse) + + // Wait for the send task to complete + do { + _ = try await sendTask.value + #expect(Bool(false), "Expected send to fail") + } catch let error as MCPError { + if case .transportError = error { + #expect(Bool(true)) + } else { + #expect(Bool(false), "Expected transport error, got \(error)") + } + } catch { + #expect(Bool(false), "Expected MCPError, got \(error)") + } + + // Verify no continuation misuse occurred + // (If it did, the test would have crashed) + +// await client.disconnect() + } + + @Test("Race condition between response and send error") + func testResponseSendErrorRace() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) + + // Create a ping request + let request = Ping.request() + + // Create a response for the request + let response = Response(id: request.id, result: .init()) + let anyResponse = try AnyResponse(response) + + // Queue the response before sending the request + try await transport.queue(response: anyResponse) + + // Create a task that will send the request + let sendTask = Task { + try await client.ping() + } + + // Give it a moment to send the request + try await Task.sleep(for: .milliseconds(10)) + + // Verify the request was sent + #expect(await transport.sentMessages.count == 1) + + // Simulate a network error during send + await transport.setFailSend(true) + + // Wait for the send task to complete + do { + _ = try await sendTask.value + #expect(Bool(false), "Expected send to fail") + } catch let error as MCPError { + if case .transportError = error { + #expect(Bool(true)) + } else { + #expect(Bool(false), "Expected transport error, got \(error)") + } + } catch { + #expect(Bool(false), "Expected MCPError, got \(error)") + } + + // Verify no continuation misuse occurred + // (If it did, the test would have crashed) + + await client.disconnect() + } } From 1f7e722e26236eceedd2e31f8230a47d6571af98 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 26 May 2025 04:31:49 -0700 Subject: [PATCH 2/2] Fix race condition test --- Tests/MCPTests/ClientTests.swift | 60 ++++++++++---------------------- 1 file changed, 19 insertions(+), 41 deletions(-) diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index 45345a04..9e649250 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -488,33 +488,21 @@ struct ClientTests { try await client.connect(transport: transport) try await Task.sleep(for: .milliseconds(10)) - // Create a ping request - let request = Ping.request() - - // Create a task that will send the request - let sendTask = Task { - try await client.ping() - } - - // Give it a moment to send the request - try await Task.sleep(for: .milliseconds(10)) - - // Verify the request was sent - #expect(await transport.sentMessages.count == 1) - - // Simulate a network error during send + // Set up the transport to fail sends from the start await transport.setFailSend(true) - // Create a response for the request + // Create a ping request to get the ID + let request = Ping.request() + + // Create a response for the request and queue it immediately let response = Response(id: request.id, result: .init()) let anyResponse = try AnyResponse(response) - - // Queue the response try await transport.queue(response: anyResponse) - // Wait for the send task to complete + // Now attempt to send the request - this should fail due to send error + // but the response handler might also try to process the queued response do { - _ = try await sendTask.value + _ = try await client.ping() #expect(Bool(false), "Expected send to fail") } catch let error as MCPError { if case .transportError = error { @@ -529,7 +517,7 @@ struct ClientTests { // Verify no continuation misuse occurred // (If it did, the test would have crashed) -// await client.disconnect() + await client.disconnect() } @Test("Race condition between response and send error") @@ -540,37 +528,27 @@ struct ClientTests { try await client.connect(transport: transport) try await Task.sleep(for: .milliseconds(10)) - // Create a ping request + // Create a ping request to get the ID let request = Ping.request() - // Create a response for the request + // Create a response for the request and queue it immediately let response = Response(id: request.id, result: .init()) let anyResponse = try AnyResponse(response) - - // Queue the response before sending the request try await transport.queue(response: anyResponse) - // Create a task that will send the request - let sendTask = Task { - try await client.ping() - } - - // Give it a moment to send the request - try await Task.sleep(for: .milliseconds(10)) - - // Verify the request was sent - #expect(await transport.sentMessages.count == 1) - - // Simulate a network error during send + // Set up the transport to fail sends await transport.setFailSend(true) - // Wait for the send task to complete + // Now attempt to send the request + // The response might be processed before the send error occurs do { - _ = try await sendTask.value - #expect(Bool(false), "Expected send to fail") + _ = try await client.ping() + // In this case, the response handler won the race and the request succeeded + #expect(Bool(true), "Response handler won the race - request succeeded") } catch let error as MCPError { if case .transportError = error { - #expect(Bool(true)) + // In this case, the send error handler won the race + #expect(Bool(true), "Send error handler won the race - request failed") } else { #expect(Bool(false), "Expected transport error, got \(error)") }