diff --git a/Sources/MCP/Base/Error.swift b/Sources/MCP/Base/Error.swift index 6ce4e1a5..f1d4dcc3 100644 --- a/Sources/MCP/Base/Error.swift +++ b/Sources/MCP/Base/Error.swift @@ -170,25 +170,38 @@ extension MCPError: Codable { let message = try container.decode(String.self, forKey: .message) let data = try container.decodeIfPresent([String: Value].self, forKey: .data) + // Helper to extract detail from data, falling back to message if needed + let unwrapDetail: (String?) -> String? = { fallback in + guard let detailValue = data?["detail"] else { return fallback } + if case .string(let str) = detailValue { return str } + return fallback + } + switch code { case -32700: - self = .parseError(data?["detail"] as? String ?? message) + self = .parseError(unwrapDetail(message)) case -32600: - self = .invalidRequest(data?["detail"] as? String ?? message) + self = .invalidRequest(unwrapDetail(message)) case -32601: - self = .methodNotFound(data?["detail"] as? String ?? message) + self = .methodNotFound(unwrapDetail(message)) case -32602: - self = .invalidParams(data?["detail"] as? String ?? message) + self = .invalidParams(unwrapDetail(message)) case -32603: - self = .internalError(data?["detail"] as? String ?? message) + self = .internalError(unwrapDetail(nil)) case -32000: self = .connectionClosed case -32001: + // Extract underlying error string if present + let underlyingErrorString = + data?["error"].flatMap { val -> String? in + if case .string(let str) = val { return str } + return nil + } ?? message self = .transportError( NSError( domain: "org.jsonrpc.error", code: code, - userInfo: [NSLocalizedDescriptionKey: message] + userInfo: [NSLocalizedDescriptionKey: underlyingErrorString] ) ) default: diff --git a/Sources/MCP/Base/Messages.swift b/Sources/MCP/Base/Messages.swift index e5effeb5..29d8725f 100644 --- a/Sources/MCP/Base/Messages.swift +++ b/Sources/MCP/Base/Messages.swift @@ -30,7 +30,7 @@ public protocol Method { } /// Type-erased method for request/response handling -struct AnyMethod: Method { +struct AnyMethod: Method, Sendable { static var name: String { "" } typealias Parameters = Value typealias Result = Value @@ -139,9 +139,19 @@ extension Request { /// A type-erased request for request/response handling typealias AnyRequest = Request +extension AnyRequest { + init(_ request: Request) throws { + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + let data = try encoder.encode(request) + self = try decoder.decode(AnyRequest.self, from: data) + } +} + /// A box for request handlers that can be type-erased class RequestHandlerBox: @unchecked Sendable { - func callAsFunction(_ request: Request) async throws -> Response { + func callAsFunction(_ request: AnyRequest) async throws -> AnyResponse { fatalError("Must override") } } @@ -155,8 +165,7 @@ final class TypedRequestHandler: RequestHandlerBox, @unchecked Sendab super.init() } - override func callAsFunction(_ request: Request) async throws -> Response - { + override func callAsFunction(_ request: AnyRequest) async throws -> AnyResponse { let encoder = JSONEncoder() let decoder = JSONDecoder() @@ -238,10 +247,28 @@ public struct Response: Hashable, Identifiable, Codable, Sendable { /// A type-erased response for request/response handling typealias AnyResponse = Response +extension AnyResponse { + init(_ response: Response) throws { + // Instead of re-encoding/decoding which might double-wrap the error, + // directly transfer the properties + self.id = response.id + switch response.result { + case .success(let result): + // For success, we still need to convert the result to a Value + let data = try JSONEncoder().encode(result) + let resultValue = try JSONDecoder().decode(Value.self, from: data) + self.result = .success(resultValue) + case .failure(let error): + // Keep the original error without re-encoding/decoding + self.result = .failure(error) + } + } +} + // MARK: - /// A notification message. -public protocol Notification { +public protocol Notification: Hashable, Codable, Sendable { /// The parameters of the notification. associatedtype Parameters: Hashable, Codable, Sendable = Empty /// The name of the notification. @@ -249,11 +276,21 @@ public protocol Notification { } /// A type-erased notification for message handling -struct AnyNotification: Notification { +struct AnyNotification: Notification, Sendable { static var name: String { "" } typealias Parameters = Empty } +extension AnyNotification { + init(_ notification: some Notification) throws { + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + let data = try encoder.encode(notification) + self = try decoder.decode(AnyNotification.self, from: data) + } +} + /// A message that can be used to send notifications. public struct Message: Hashable, Codable, Sendable { /// The method name. diff --git a/Sources/MCP/Base/Transports/StdioTransport.swift b/Sources/MCP/Base/Transports/StdioTransport.swift index 7978a6dd..6b307e6f 100644 --- a/Sources/MCP/Base/Transports/StdioTransport.swift +++ b/Sources/MCP/Base/Transports/StdioTransport.swift @@ -17,6 +17,12 @@ import struct Foundation.Data #if canImport(Darwin) || canImport(Glibc) /// Standard input/output transport implementation + /// + /// This transport supports JSON-RPC 2.0 messages, including individual requests, + /// notifications, responses, and batches containing multiple requests/notifications. + /// + /// Messages are delimited by newlines and must not contain embedded newlines. + /// Each message must be a complete, valid JSON object or array (for batches). public actor StdioTransport: Transport { private let input: FileDescriptor private let output: FileDescriptor @@ -131,6 +137,13 @@ import struct Foundation.Data logger.info("Transport disconnected") } + /// Sends a message over the transport. + /// + /// This method supports sending both individual JSON-RPC messages and JSON-RPC batches. + /// Batches should be encoded as a JSON array containing multiple request/notification objects + /// according to the JSON-RPC 2.0 specification. + /// + /// - Parameter message: The message data to send (without a trailing newline) public func send(_ message: Data) async throws { guard isConnected else { throw MCPError.transportError(Errno(rawValue: ENOTCONN)) @@ -158,6 +171,11 @@ import struct Foundation.Data } } + /// Receives messages from the transport. + /// + /// Messages may be individual JSON-RPC requests, notifications, responses, + /// or batches containing multiple requests/notifications encoded as JSON arrays. + /// Each message is guaranteed to be a complete JSON object or array. public func receive() -> AsyncThrowingStream { return AsyncThrowingStream { continuation in Task { diff --git a/Sources/MCP/Client/Client.swift b/Sources/MCP/Client/Client.swift index 19c60297..fa4fbe09 100644 --- a/Sources/MCP/Client/Client.swift +++ b/Sources/MCP/Client/Client.swift @@ -153,6 +153,9 @@ public actor Client { /// A dictionary of type-erased pending requests, keyed by request ID private var pendingRequests: [ID: AnyPendingRequest] = [:] + // Add reusable JSON encoder/decoder + private let encoder = JSONEncoder() + private let decoder = JSONDecoder() public init( name: String, @@ -184,9 +187,11 @@ public actor Client { for try await data in stream { if Task.isCancelled { break } // Check inside loop too - // Attempt to decode data as AnyResponse or AnyMessage - let decoder = JSONDecoder() - if let response = try? decoder.decode(AnyResponse.self, from: data), + // Attempt to decode data + // Try decoding as a batch response first + if let batchResponse = try? decoder.decode([AnyResponse].self, from: data) { + await handleBatchResponse(batchResponse) + } else if let response = try? decoder.decode(AnyResponse.self, from: data), let request = pendingRequests[response.id] { await handleResponse(response, for: request) @@ -198,7 +203,9 @@ public actor Client { metadata["message"] = .string(string) } await logger?.warning( - "Unexpected message received by client", metadata: metadata) + "Unexpected message received by client (not single/batch response or notification)", + metadata: metadata + ) } } } catch let error where MCPError.isResourceTemporarilyUnavailable(error) { @@ -250,7 +257,8 @@ public actor Client { throw MCPError.internalError("Client connection not initialized") } - let requestData = try JSONEncoder().encode(request) + // Use the actor's encoder + let requestData = try encoder.encode(request) // Store the pending request first return try await withCheckedThrowingContinuation { continuation in @@ -263,10 +271,12 @@ public actor Client { // Send the request data do { + // Use the existing connection send try await connection.send(requestData) } catch { + // If send fails immediately, resume continuation and remove pending request continuation.resume(throwing: error) - self.removePendingRequest(id: request.id) + self.removePendingRequest(id: request.id) // Ensure cleanup on send error } } } @@ -275,7 +285,7 @@ public actor Client { private func addPendingRequest( id: ID, continuation: CheckedContinuation, - type: T.Type + type: T.Type // Keep type for AnyPendingRequest internal logic ) { pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation)) } @@ -284,6 +294,150 @@ public actor Client { pendingRequests.removeValue(forKey: id) } + // MARK: - Batching + + /// A batch of requests. + /// + /// Objects of this type are passed as an argument to the closure + /// of the ``Client/withBatch(_:)`` method. + public actor Batch { + unowned let client: Client + var requests: [AnyRequest] = [] + + init(client: Client) { + self.client = client + } + + /// Adds a request to the batch and prepares its expected response task. + /// The actual sending happens when the `withBatch` scope completes. + /// - Returns: A `Task` that will eventually produce the result or throw an error. + public func addRequest(_ request: Request) async throws -> Task< + M.Result, Swift.Error + > { + requests.append(try AnyRequest(request)) + + // Return a Task that registers the pending request and awaits its result. + // The continuation is resumed when the response arrives. + return Task { + try await withCheckedThrowingContinuation { continuation in + // We are already inside a Task, but need another Task + // to bridge to the client actor's context. + Task { + await client.addPendingRequest( + id: request.id, + continuation: continuation, + type: M.Result.self + ) + } + } + } + } + } + + /// Executes multiple requests in a single batch. + /// + /// This method allows you to group multiple MCP requests together, + /// which are then sent to the server as a single JSON array. + /// The server processes these requests and sends back a corresponding + /// JSON array of responses. + /// + /// Within the `body` closure, use the provided `Batch` actor to add + /// requests using `batch.addRequest(_:)`. Each call to `addRequest` + /// returns a `Task` handle representing the asynchronous operation + /// for that specific request's result. + /// + /// It's recommended to collect these `Task` handles into an array + /// within the `body` closure`. After the `withBatch` method returns + /// (meaning the batch request has been sent), you can then process + /// the results by awaiting each `Task` in the collected array. + /// + /// Example 1: Batching multiple tool calls and collecting typed tasks: + /// ```swift + /// // Array to hold the task handles for each tool call + /// var toolTasks: [Task] = [] + /// try await client.withBatch { batch in + /// for i in 0..<10 { + /// toolTasks.append( + /// try await batch.addRequest( + /// CallTool.request(.init(name: "square", arguments: ["n": i])) + /// ) + /// ) + /// } + /// } + /// + /// // Process results after the batch is sent + /// print("Processing \(toolTasks.count) tool results...") + /// for (index, task) in toolTasks.enumerated() { + /// do { + /// let result = try await task.value + /// print("\(index): \(result.content)") + /// } catch { + /// print("\(index) failed: \(error)") + /// } + /// } + /// ``` + /// + /// Example 2: Batching different request types and awaiting individual tasks: + /// ```swift + /// // Declare optional task variables beforehand + /// var pingTask: Task? + /// var promptTask: Task? + /// + /// try await client.withBatch { batch in + /// // Assign the tasks within the batch closure + /// pingTask = try await batch.addRequest(Ping.request()) + /// promptTask = try await batch.addRequest(GetPrompt.request(.init(name: "greeting"))) + /// } + /// + /// // Await the results after the batch is sent + /// do { + /// if let pingTask = pingTask { + /// try await pingTask.value // Await ping result (throws if ping failed) + /// print("Ping successful") + /// } + /// if let promptTask = promptTask { + /// let promptResult = try await promptTask.value // Await prompt result + /// print("Prompt description: \(promptResult.description ?? "None")") + /// } + /// } catch { + /// print("Error processing batch results: \(error)") + /// } + /// ``` + /// + /// - Parameter body: An asynchronous closure that takes a `Batch` object as input. + /// Use this object to add requests to the batch. + /// - Throws: `MCPError.internalError` if the client is not connected. + /// Can also rethrow errors from the `body` closure or from sending the batch request. + public func withBatch(body: @escaping (Batch) async throws -> Void) async throws { + guard let connection = connection else { + throw MCPError.internalError("Client connection not initialized") + } + + // Create Batch actor, passing self (Client) + let batch = Batch(client: self) + + // Populate the batch actor by calling the user's closure. + try await body(batch) + + // Get the collected requests from the batch actor + let requests = await batch.requests + + // Check if there are any requests to send + guard !requests.isEmpty else { + await logger?.info("Batch requested but no requests were added.") + return // Nothing to send + } + + await logger?.debug( + "Sending batch request", metadata: ["count": "\(requests.count)"]) + + // Encode the array of AnyMethod requests into a single JSON payload + let data = try encoder.encode(requests) + try await connection.send(data) + + // Responses will be handled asynchronously by the message loop and handleBatchResponse/handleResponse. + } + // MARK: - Lifecycle public func initialize() async throws -> Initialize.Result { @@ -364,7 +518,9 @@ public actor Client { // MARK: - Tools - public func listTools(cursor: String? = nil) async throws -> (tools: [Tool], nextCursor: String?) { + public func listTools(cursor: String? = nil) async throws -> ( + tools: [Tool], nextCursor: String? + ) { try validateServerCapability(\.tools, "Tools") let request: Request if let cursor = cursor { @@ -446,4 +602,22 @@ public actor Client { } } } + + // Add handler for batch responses + private func handleBatchResponse(_ responses: [AnyResponse]) async { + await logger?.debug("Processing batch response", metadata: ["count": "\(responses.count)"]) + for response in responses { + // Look up the pending request for this specific ID within the batch + if let request = pendingRequests[response.id] { + // Reuse the existing single response handler logic + await handleResponse(response, for: request) + } else { + // Log if a response ID doesn't match any pending request + await logger?.warning( + "Received response in batch for unknown request ID", + metadata: ["id": "\(response.id)"] + ) + } + } + } } diff --git a/Sources/MCP/Server/Server.swift b/Sources/MCP/Server/Server.swift index ef50e2c5..9049d75a 100644 --- a/Sources/MCP/Server/Server.swift +++ b/Sources/MCP/Server/Server.swift @@ -177,10 +177,12 @@ public actor Server { var requestID: ID? do { - // Attempt to decode string data as AnyRequest or AnyMessage + // Attempt to decode as batch first, then as individual request or notification let decoder = JSONDecoder() - if let request = try? decoder.decode(AnyRequest.self, from: data) { - try await handleRequest(request) + if let batch = try? decoder.decode(Server.Batch.self, from: data) { + try await handleBatch(batch) + } else if let request = try? decoder.decode(AnyRequest.self, from: data) { + _ = try await handleRequest(request, sendResponse: true) } else if let message = try? decoder.decode(AnyMessage.self, from: data) { try await handleMessage(message) } else { @@ -288,9 +290,91 @@ public actor Server { try await connection.send(notificationData) } - // MARK: - + /// A JSON-RPC batch containing multiple requests and/or notifications + struct Batch: Sendable { + /// An item in a JSON-RPC batch + enum Item: Sendable { + case request(Request) + case notification(Message) + + } + + var items: [Item] + + init(items: [Item]) { + self.items = items + } + } + + /// Process a batch of requests and/or notifications + private func handleBatch(_ batch: Batch) async throws { + await logger?.debug("Processing batch request", metadata: ["size": "\(batch.items.count)"]) + + if batch.items.isEmpty { + // Empty batch is invalid according to JSON-RPC spec + let error = MCPError.invalidRequest("Batch array must not be empty") + let response = AnyMethod.response(id: .random, error: error) + try await send(response) + return + } + + // Process each item in the batch and collect responses + var responses: [Response] = [] + + for item in batch.items { + do { + switch item { + case .request(let request): + // For batched requests, collect responses instead of sending immediately + if let response = try await handleRequest(request, sendResponse: false) { + responses.append(response) + } + + case .notification(let notification): + // Handle notification (no response needed) + try await handleMessage(notification) + } + } catch { + // Only add errors to response for requests (notifications don't have responses) + if case .request(let request) = item { + let mcpError = error as? MCPError ?? MCPError.internalError(error.localizedDescription) + responses.append(AnyMethod.response(id: request.id, error: mcpError)) + } + } + } + + // Send collected responses if any + if !responses.isEmpty { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] + let responseData = try encoder.encode(responses) + + guard let connection = connection else { + throw MCPError.internalError("Server connection not initialized") + } + + try await connection.send(responseData) + } + } + + // MARK: - Request and Message Handling + + /// Handle a request and either send the response immediately or return it + /// + /// - Parameters: + /// - request: The request to handle + /// - sendResponse: Whether to send the response immediately (true) or return it (false) + /// - Returns: The response when sendResponse is false + private func handleRequest(_ request: Request, sendResponse: Bool = true) async throws -> Response? { + // Check if this is a pre-processed error request (empty method) + if request.method.isEmpty && !sendResponse { + // This is a placeholder for an invalid request that couldn't be parsed in batch mode + return AnyMethod.response( + id: request.id, + error: MCPError.invalidRequest("Invalid batch item format") + ) + } - private func handleRequest(_ request: Request) async throws { await logger?.debug( "Processing request", metadata: [ @@ -313,19 +397,35 @@ public actor Server { guard let handler = methodHandlers[request.method] else { let error = MCPError.methodNotFound("Unknown method: \(request.method)") let response = AnyMethod.response(id: request.id, error: error) - try await send(response) - throw error + + if sendResponse { + try await send(response) + return nil + } + + return response } do { // Handle request and get response let response = try await handler(request) - try await send(response) + + if sendResponse { + try await send(response) + return nil + } + + return response } catch { let mcpError = error as? MCPError ?? MCPError.internalError(error.localizedDescription) let response = AnyMethod.response(id: request.id, error: mcpError) - try await send(response) - throw error + + if sendResponse { + try await send(response) + return nil + } + + return response } } @@ -425,3 +525,50 @@ public actor Server { self.isInitialized = true } } + +extension Server.Batch: Codable { + init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + + let encoder = JSONEncoder() + let decoder = JSONDecoder() + + var items: [Item] = [] + for item in try container.decode([Value].self) { + let data = try encoder.encode(item) + try items.append(decoder.decode(Item.self, from: data)) + } + + self.items = items + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(items) + } +} + +extension Server.Batch.Item: Codable { + private enum CodingKeys: String, CodingKey { + case id + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + // Check if it's a request (has id) or notification (no id) + if container.contains(.id) { + self = .request(try Request(from: decoder)) + } else { + self = .notification(try Message(from: decoder)) + } + } + + func encode(to encoder: Encoder) throws { + switch self { + case .request(let request): + try request.encode(to: encoder) + case .notification(let notification): + try notification.encode(to: encoder) + } + } +} \ No newline at end of file diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index 502bc728..e7e2c023 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -226,4 +226,141 @@ struct ClientTests { // Disconnect client await client.disconnect() } + + @Test("Batch request - success") + func testBatchRequestSuccess() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) // Allow connection tasks + + let request1 = Ping.request() + let request2 = Ping.request() + var resultTask1: Task? + var resultTask2: Task? + + try await client.withBatch { batch in + resultTask1 = try await batch.addRequest(request1) + resultTask2 = try await batch.addRequest(request2) + } + + // Check if one batch message was sent + let sentMessages = await transport.sentMessages + #expect(sentMessages.count == 1) + + guard let batchData = sentMessages.first?.data(using: .utf8) else { + #expect(Bool(false), "Failed to get batch data") + return + } + + // Verify the sent batch contains the two requests + let decoder = JSONDecoder() + let sentRequests = try decoder.decode([AnyRequest].self, from: batchData) + #expect(sentRequests.count == 2) + #expect(sentRequests.first?.id == request1.id) + #expect(sentRequests.first?.method == Ping.name) + #expect(sentRequests.last?.id == request2.id) + #expect(sentRequests.last?.method == Ping.name) + + // Prepare batch response + let response1 = Response(id: request1.id, result: .init()) + let response2 = Response(id: request2.id, result: .init()) + let anyResponse1 = try AnyResponse(response1) + let anyResponse2 = try AnyResponse(response2) + + // Queue the batch response + try await transport.queue(batch: [anyResponse1, anyResponse2]) + + // Wait for results and verify + guard let task1 = resultTask1, let task2 = resultTask2 else { + #expect(Bool(false), "Result tasks not created") + return + } + + _ = try await task1.value // Should succeed + _ = try await task2.value // Should succeed + + #expect(Bool(true)) // Reaching here means success + + await client.disconnect() + } + + @Test("Batch request - mixed success/error") + func testBatchRequestMixed() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) + + let request1 = Ping.request() // Success + let request2 = Ping.request() // Error + + var resultTasks: [Task] = [] + + try await client.withBatch { batch in + resultTasks.append(try await batch.addRequest(request1)) + resultTasks.append(try await batch.addRequest(request2)) + } + + // Check if one batch message was sent + #expect(await transport.sentMessages.count == 1) + + // Prepare batch response (success for 1, error for 2) + let response1 = Response(id: request1.id, result: .init()) + let error = MCPError.internalError("Simulated batch error") + let response2 = Response(id: request2.id, error: error) + let anyResponse1 = try AnyResponse(response1) + let anyResponse2 = try AnyResponse(response2) + + // Queue the batch response + try await transport.queue(batch: [anyResponse1, anyResponse2]) + + // Wait for results and verify + #expect(resultTasks.count == 2) + guard resultTasks.count == 2 else { + #expect(Bool(false), "Expected 2 result tasks") + return + } + + let task1 = resultTasks[0] + let task2 = resultTasks[1] + + _ = try await task1.value // Task 1 should succeed + + do { + _ = try await task2.value // Task 2 should fail + #expect(Bool(false), "Task 2 should have thrown an error") + } catch let mcpError as MCPError { + if case .internalError(let message) = mcpError { + #expect(message == "Simulated batch error") + } else { + #expect(Bool(false), "Expected internalError, got \(mcpError)") + } + } catch { + #expect(Bool(false), "Expected MCPError, got \(error)") + } + + await client.disconnect() + } + + @Test("Batch request - empty") + func testBatchRequestEmpty() async throws { + let transport = MockTransport() + let client = Client(name: "TestClient", version: "1.0") + + try await client.connect(transport: transport) + try await Task.sleep(for: .milliseconds(10)) + + // Call withBatch but don't add any requests + try await client.withBatch { _ in + // No requests added + } + + // Check that no messages were sent + #expect(await transport.sentMessages.isEmpty) + + await client.disconnect() + } } diff --git a/Tests/MCPTests/Helpers/MockTransport.swift b/Tests/MCPTests/Helpers/MockTransport.swift index bcd8dd9d..5979e9ad 100644 --- a/Tests/MCPTests/Helpers/MockTransport.swift +++ b/Tests/MCPTests/Helpers/MockTransport.swift @@ -80,8 +80,7 @@ actor MockTransport: Transport { shouldFailSend = shouldFail } - func queue(request: Request) throws { - let data = try encoder.encode(request) + func queue(data: Data) { if let continuation = dataStreamContinuation { continuation.yield(data) } else { @@ -89,14 +88,24 @@ actor MockTransport: Transport { } } + func queue(request: Request) throws { + queue(data: try encoder.encode(request)) + } + func queue(response: Response) throws { - let data = try encoder.encode(response) - dataToReceive.append(data) + queue(data: try encoder.encode(response)) } func queue(notification: Message) throws { - let data = try encoder.encode(notification) - dataToReceive.append(data) + queue(data: try encoder.encode(notification)) + } + + func queue(batch requests: [AnyRequest]) throws { + queue(data: try encoder.encode(requests)) + } + + func queue(batch responses: [AnyResponse]) throws { + queue(data: try encoder.encode(responses)) } func decodeLastSentMessage() -> T? { diff --git a/Tests/MCPTests/ResponseTests.swift b/Tests/MCPTests/ResponseTests.swift index ff82e64b..cb918521 100644 --- a/Tests/MCPTests/ResponseTests.swift +++ b/Tests/MCPTests/ResponseTests.swift @@ -78,7 +78,7 @@ struct ResponseTests { #expect(decodedError.code == -32700) #expect( decodedError.localizedDescription - == "Parse error: Invalid JSON: Parse error: Invalid JSON: Invalid syntax") + == "Parse error: Invalid JSON: Invalid syntax") } else { #expect(Bool(false), "Expected error result") } diff --git a/Tests/MCPTests/ServerTests.swift b/Tests/MCPTests/ServerTests.swift index c7db67f9..91b97bd7 100644 --- a/Tests/MCPTests/ServerTests.swift +++ b/Tests/MCPTests/ServerTests.swift @@ -144,4 +144,62 @@ struct ServerTests { await server.stop() await transport.disconnect() } + + @Test("JSON-RPC batch processing") + func testJSONRPCBatchProcessing() async throws { + let transport = MockTransport() + let server = Server(name: "TestServer", version: "1.0") + + // Start the server + try await server.start(transport: transport) + + // Initialize the server first + try await transport.queue( + request: Initialize.request( + .init( + protocolVersion: Version.latest, + capabilities: .init(), + clientInfo: .init(name: "TestClient", version: "1.0") + ) + ) + ) + + // Wait for server to initialize and respond + try await Task.sleep(for: .milliseconds(100)) + + // Clear sent messages + await transport.clearMessages() + + // Create a batch with multiple requests + let batchJSON = """ + [ + {"jsonrpc":"2.0","id":1,"method":"ping","params":{}}, + {"jsonrpc":"2.0","id":2,"method":"ping","params":{}} + ] + """ + let batch = try JSONDecoder().decode([AnyRequest].self, from: batchJSON.data(using: .utf8)!) + + // Send the batch request + try await transport.queue(batch: batch) + + // Wait for batch processing + try await Task.sleep(for: .milliseconds(100)) + + // Verify response + let sentMessages = await transport.sentMessages + #expect(sentMessages.count == 1) + + if let batchResponse = sentMessages.first { + // Should be an array + #expect(batchResponse.hasPrefix("[")) + #expect(batchResponse.hasSuffix("]")) + + // Should contain both request IDs + #expect(batchResponse.contains("\"id\":1")) + #expect(batchResponse.contains("\"id\":2")) + } + + await server.stop() + await transport.disconnect() + } }