diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index f6ec77a3..f45b4b96 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -191,10 +191,8 @@ public actor Client { // Try decoding as a batch response first if let batchResponse = try? decoder.decode([AnyResponse].self, from: data) { await handleBatchResponse(batchResponse) - } else if let response = try? decoder.decode(AnyResponse.self, from: data), - let request = pendingRequests[response.id] - { - await handleResponse(response, for: request) + } else if let response = try? decoder.decode(AnyResponse.self, from: data) { + await handleResponse(response) } else if let message = try? decoder.decode(AnyMessage.self, from: data) { await handleMessage(message) } else { @@ -217,23 +215,49 @@ public actor Client { break } } while true + await self.logger?.info("Client message handling loop task is terminating.") } } /// Disconnect the client and cancel all pending requests public func disconnect() async { - // Cancel all pending requests - for (id, request) in pendingRequests { + await logger?.info("Initiating client disconnect...") + + // Part 1: Inside actor - Grab state and clear internal references + let taskToCancel = self.task + let connectionToDisconnect = self.connection + let pendingRequestsToCancel = self.pendingRequests + + self.task = nil + self.connection = nil + self.pendingRequests = [:] // Use empty dictionary literal + + // Part 2: Outside actor - Resume continuations, disconnect transport, await task + + // Resume continuations first + for (_, request) in pendingRequestsToCancel { request.resume(throwing: MCPError.internalError("Client disconnected")) - pendingRequests.removeValue(forKey: id) } + await logger?.info("Pending requests cancelled.") - task?.cancel() - task = nil - if let connection = connection { - await connection.disconnect() + // Cancel the task + taskToCancel?.cancel() + await logger?.info("Message loop task cancellation requested.") + + // Disconnect the transport *before* awaiting the task + // This should ensure the transport stream is finished, unblocking the loop. + if let conn = connectionToDisconnect { + await conn.disconnect() + await logger?.info("Transport disconnected.") + } else { + await logger?.info("No active transport connection to disconnect.") } - connection = nil + + // Await the task completion *after* transport disconnect + _ = await taskToCancel?.value + await logger?.info("Client message loop task finished.") + + await logger?.info("Client disconnect complete.") } // MARK: - Registration @@ -267,12 +291,12 @@ public actor Client { throw MCPError.internalError("Client connection not initialized") } - // Use the actor's encoder let requestData = try encoder.encode(request) // Store the pending request first return try await withCheckedThrowingContinuation { continuation in Task { + // Add the pending request before attempting to send self.addPendingRequest( id: request.id, continuation: continuation, @@ -284,9 +308,15 @@ public actor Client { // Use the existing connection send try await connection.send(requestData) } catch { - // If send fails immediately, resume continuation and remove pending request - continuation.resume(throwing: error) - self.removePendingRequest(id: request.id) // Ensure cleanup on send error + // If send fails, try to remove the pending request. + // Resume with the send error only if we successfully removed the request, + // indicating the response handler hasn't processed it yet. + if self.removePendingRequest(id: request.id) != nil { + continuation.resume(throwing: error) + } + // Otherwise, the request was already removed by the response handler + // or by disconnect, so the continuation was already resumed. + // Do nothing here. } } } @@ -300,8 +330,8 @@ public actor Client { pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation)) } - private func removePendingRequest(id: ID) { - pendingRequests.removeValue(forKey: id) + private func removePendingRequest(id: ID) -> AnyPendingRequest? { + return pendingRequests.removeValue(forKey: id) } // MARK: - Batching @@ -555,21 +585,29 @@ public actor Client { // MARK: - - private func handleResponse(_ response: Response, for request: AnyPendingRequest) - async - { + private func handleResponse(_ response: Response) async { await logger?.debug( "Processing response", metadata: ["id": "\(response.id)"]) - switch response.result { - case .success(let value): - request.resume(returning: value) - case .failure(let error): - request.resume(throwing: error) + // Attempt to remove the pending request using the response ID. + // Resume with the response only if it hadn't yet been removed. + if let removedRequest = self.removePendingRequest(id: response.id) { + // If we successfully removed it, resume its continuation. + switch response.result { + case .success(let value): + removedRequest.resume(returning: value) + case .failure(let error): + removedRequest.resume(throwing: error) + } + } else { + // Request was already removed (e.g., by send error handler or disconnect). + // Log this, but it's not an error in race condition scenarios. + await logger?.warning( + "Attempted to handle response for already removed request", + metadata: ["id": "\(response.id)"] + ) } - - removePendingRequest(id: response.id) } private func handleMessage(_ message: Message) async { @@ -619,14 +657,21 @@ public actor Client { private func handleBatchResponse(_ responses: [AnyResponse]) async { await logger?.debug("Processing batch response", metadata: ["count": "\(responses.count)"]) for response in responses { - // Look up the pending request for this specific ID within the batch - if let request = pendingRequests[response.id] { - // Reuse the existing single response handler logic - await handleResponse(response, for: request) + // Attempt to remove the pending request. + // If successful, pendingRequest contains the request. + if let pendingRequest = self.removePendingRequest(id: response.id) { + // If we successfully removed it, handle the response using the pending request. + switch response.result { + case .success(let value): + pendingRequest.resume(returning: value) + case .failure(let error): + pendingRequest.resume(throwing: error) + } } else { - // Log if a response ID doesn't match any pending request + // If removal failed, it means the request ID was not found (or already handled). + // Log a warning. await logger?.warning( - "Received response in batch for unknown request ID", + "Received response in batch for unknown or already handled request ID", metadata: ["id": "\(response.id)"] ) } diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index cfbedeb1..9e649250 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -479,4 +479,86 @@ struct ClientTests { await client.disconnect() } + + @Test("Race condition between send error and response") + func testSendErrorResponseRace() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) + + // Set up the transport to fail sends from the start + await transport.setFailSend(true) + + // Create a ping request to get the ID + let request = Ping.request() + + // Create a response for the request and queue it immediately + let response = Response(id: request.id, result: .init()) + let anyResponse = try AnyResponse(response) + try await transport.queue(response: anyResponse) + + // Now attempt to send the request - this should fail due to send error + // but the response handler might also try to process the queued response + do { + _ = try await client.ping() + #expect(Bool(false), "Expected send to fail") + } catch let error as MCPError { + if case .transportError = error { + #expect(Bool(true)) + } else { + #expect(Bool(false), "Expected transport error, got \(error)") + } + } catch { + #expect(Bool(false), "Expected MCPError, got \(error)") + } + + // Verify no continuation misuse occurred + // (If it did, the test would have crashed) + + await client.disconnect() + } + + @Test("Race condition between response and send error") + func testResponseSendErrorRace() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) + + // Create a ping request to get the ID + let request = Ping.request() + + // Create a response for the request and queue it immediately + let response = Response(id: request.id, result: .init()) + let anyResponse = try AnyResponse(response) + try await transport.queue(response: anyResponse) + + // Set up the transport to fail sends + await transport.setFailSend(true) + + // Now attempt to send the request + // The response might be processed before the send error occurs + do { + _ = try await client.ping() + // In this case, the response handler won the race and the request succeeded + #expect(Bool(true), "Response handler won the race - request succeeded") + } catch let error as MCPError { + if case .transportError = error { + // In this case, the send error handler won the race + #expect(Bool(true), "Send error handler won the race - request failed") + } else { + #expect(Bool(false), "Expected transport error, got \(error)") + } + } catch { + #expect(Bool(false), "Expected MCPError, got \(error)") + } + + // Verify no continuation misuse occurred + // (If it did, the test would have crashed) + + await client.disconnect() + } }