diff --git a/Sources/MCP/Base/Transports.swift b/Sources/MCP/Base/Transports.swift index 1cc050c4..bd439321 100644 --- a/Sources/MCP/Base/Transports.swift +++ b/Sources/MCP/Base/Transports.swift @@ -19,11 +19,11 @@ public protocol Transport: Actor { /// Disconnects from the transport func disconnect() async - /// Sends a message string - func send(_ message: String) async throws + /// Sends data + func send(_ data: Data) async throws - /// Receives message strings as an async sequence - func receive() -> AsyncThrowingStream + /// Receives data in an async sequence + func receive() -> AsyncThrowingStream } /// Standard input/output transport implementation @@ -33,8 +33,8 @@ public actor StdioTransport: Transport { public nonisolated let logger: Logger private var isConnected = false - private let messageStream: AsyncStream - private let messageContinuation: AsyncStream.Continuation + private let messageStream: AsyncStream + private let messageContinuation: AsyncStream.Continuation public init( input: FileDescriptor = FileDescriptor.standardInput, @@ -50,7 +50,7 @@ public actor StdioTransport: Transport { factory: { _ in SwiftLogNoOpLogHandler() }) // Create message stream - var continuation: AsyncStream.Continuation! + var continuation: AsyncStream.Continuation! self.messageStream = AsyncStream { continuation = $0 } self.messageContinuation = continuation } @@ -105,15 +105,13 @@ public actor StdioTransport: Transport { let messageData = pendingData[.. AsyncThrowingStream { + public func receive() -> AsyncThrowingStream { return AsyncThrowingStream { continuation in Task { for await message in messageStream { @@ -182,8 +179,8 @@ public actor StdioTransport: Transport { public nonisolated let logger: Logger private var isConnected = false - private let messageStream: AsyncThrowingStream - private let messageContinuation: AsyncThrowingStream.Continuation + private let messageStream: AsyncThrowingStream + private let messageContinuation: AsyncThrowingStream.Continuation // Track connection state for continuations private var connectionContinuationResumed = false @@ -198,7 +195,7 @@ public actor StdioTransport: Transport { ) // Create message stream - var continuation: AsyncThrowingStream.Continuation! + var continuation: AsyncThrowingStream.Continuation! self.messageStream = AsyncThrowingStream { continuation = $0 } self.messageContinuation = continuation } @@ -289,14 +286,14 @@ public actor StdioTransport: Transport { logger.info("Network transport disconnected") } - public func send(_ message: String) async throws { + public func send(_ message: Data) async throws { guard isConnected else { throw MCP.Error.internalError("Transport not connected") } - guard let data = (message + "\n").data(using: .utf8) else { - throw MCP.Error.internalError("Failed to encode message") - } + // Add newline as delimiter + var messageWithNewline = message + messageWithNewline.append(UInt8(ascii: "\n")) // Use a local actor-isolated variable to track continuation state var sendContinuationResumed = false @@ -309,7 +306,7 @@ public actor StdioTransport: Transport { } connection.send( - content: data, + content: messageWithNewline, completion: .contentProcessed { [weak self] error in guard let self = self else { return } @@ -329,7 +326,7 @@ public actor StdioTransport: Transport { } } - public func receive() -> AsyncThrowingStream { + public func receive() -> AsyncThrowingStream { return AsyncThrowingStream { continuation in Task { do { @@ -357,11 +354,10 @@ public actor StdioTransport: Transport { let messageData = buffer[..(_ response: Response) async throws { guard let connection = connection else { throw Error.internalError("Server connection not initialized") } + let encoder = JSONEncoder() encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] let responseData = try encoder.encode(response) - - if let responseStr = String(data: responseData, encoding: .utf8) { - try await connection.send(responseStr) - } + try await connection.send(responseData) } /// Send a notification to connected clients @@ -291,10 +285,7 @@ public actor Server { encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes] let notificationData = try encoder.encode(notification) - - if let notificationStr = String(data: notificationData, encoding: .utf8) { - try await connection.send(notificationStr) - } + try await connection.send(notificationData) } // MARK: - @@ -407,7 +398,7 @@ public actor Server { // Send initialized notification after a short delay Task { - try? await Task.sleep(nanoseconds: 100_000_000) // 100ms + try? await Task.sleep(for: .milliseconds(10)) try? await self.notify(InitializedNotification.message()) } diff --git a/Tests/MCPTests/ClientTests.swift b/Tests/MCPTests/ClientTests.swift index d34c6018..ebbf0261 100644 --- a/Tests/MCPTests/ClientTests.swift +++ b/Tests/MCPTests/ClientTests.swift @@ -27,7 +27,7 @@ struct ClientTests { try await client.connect(transport: transport) // Small delay to ensure message loop is started - try await Task.sleep(nanoseconds: 10_000_000) // 10ms + try await Task.sleep(for: .milliseconds(10)) // Create a task for initialize that we'll cancel let initTask = Task { @@ -35,19 +35,19 @@ struct ClientTests { } // Give it a moment to send the request - try await Task.sleep(nanoseconds: 10_000_000) // 10ms + try await Task.sleep(for: .milliseconds(10)) #expect(await transport.sentMessages.count == 1) - #expect(await transport.sentMessages[0].contains(Initialize.name)) - #expect(await transport.sentMessages[0].contains(client.name)) - #expect(await transport.sentMessages[0].contains(client.version)) + #expect(await transport.sentMessages.first?.contains(Initialize.name) == true) + #expect(await transport.sentMessages.first?.contains(client.name) == true) + #expect(await transport.sentMessages.first?.contains(client.version) == true) // Cancel the initialize task initTask.cancel() // Disconnect client to clean up message loop and give time for continuation cleanup await client.disconnect() - try await Task.sleep(nanoseconds: 50_000_000) // 50ms + try await Task.sleep(for: .milliseconds(50)) } @Test( @@ -60,7 +60,7 @@ struct ClientTests { try await client.connect(transport: transport) // Small delay to ensure message loop is started - try await Task.sleep(nanoseconds: 10_000_000) // 10ms + try await Task.sleep(for: .milliseconds(10)) // Create a task for the ping that we'll cancel let pingTask = Task { @@ -68,17 +68,17 @@ struct ClientTests { } // Give it a moment to send the request - try await Task.sleep(nanoseconds: 10_000_000) // 10ms + try await Task.sleep(for: .milliseconds(10)) #expect(await transport.sentMessages.count == 1) - #expect(await transport.sentMessages[0].contains(Ping.name)) + #expect(await transport.sentMessages.first?.contains(Ping.name) == true) // Cancel the ping task pingTask.cancel() // Disconnect client to clean up message loop and give time for continuation cleanup await client.disconnect() - try await Task.sleep(nanoseconds: 50_000_000) // 50ms + try await Task.sleep(for: .milliseconds(50)) } @Test("Connection failure handling") @@ -168,7 +168,7 @@ struct ClientTests { // Wait a bit for any setup to complete try await Task.sleep(for: .milliseconds(10)) - + // Send the listPrompts request and immediately provide an error response let promptsTask = Task { do { @@ -187,7 +187,7 @@ struct ClientTests { id: decodedRequest.id, error: Error.methodNotFound("Test: Prompts capability not available") ) - try await transport.queueResponse(errorResponse) + try await transport.queue(response: errorResponse) // Try the request now that we have a response queued do { diff --git a/Tests/MCPTests/Helpers/MockTransport.swift b/Tests/MCPTests/Helpers/MockTransport.swift index 3abb57ff..571715ef 100644 --- a/Tests/MCPTests/Helpers/MockTransport.swift +++ b/Tests/MCPTests/Helpers/MockTransport.swift @@ -6,10 +6,28 @@ import Logging /// Mock transport for testing actor MockTransport: Transport { var logger: Logger + + let encoder = JSONEncoder() + let decoder = JSONDecoder() + var isConnected = false - private(set) var sentMessages: [String] = [] - private var messagesToReceive: [String] = [] - private var messageStreamContinuation: AsyncThrowingStream.Continuation? + + private(set) var sentData: [Data] = [] + var sentMessages: [String] { + return sentData.compactMap { data in + guard let string = String(data: data, encoding: .utf8) else { + logger.error("Failed to decode sent data as UTF-8") + return nil + } + return string + } + } + + private var dataToReceive: [Data] = [] + private(set) var receivedMessages: [String] = [] + + private var dataStreamContinuation: AsyncThrowingStream.Continuation? + var shouldFailConnect = false var shouldFailSend = false @@ -17,90 +35,77 @@ actor MockTransport: Transport { self.logger = logger } - func connect() async throws { + public func connect() async throws { if shouldFailConnect { throw Error.transportError(POSIXError(.ECONNREFUSED)) } isConnected = true } - func disconnect() async { + public func disconnect() async { isConnected = false - messageStreamContinuation?.finish() - messageStreamContinuation = nil + dataStreamContinuation?.finish() + dataStreamContinuation = nil } - func send(_ message: T) async throws { + public func send(_ message: Data) async throws { if shouldFailSend { throw Error.transportError(POSIXError(.EIO)) } - let data = try JSONEncoder().encode(message) - let str = String(data: data, encoding: .utf8)! - sentMessages.append(str) + sentData.append(message) } - func receive() -> AsyncThrowingStream { - return AsyncThrowingStream { continuation in - messageStreamContinuation = continuation - // Send any queued messages - for message in messagesToReceive { + public func receive() -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + dataStreamContinuation = continuation + for message in dataToReceive { continuation.yield(message) + if let string = String(data: message, encoding: .utf8) { + receivedMessages.append(string) + } } - messagesToReceive.removeAll() + dataToReceive.removeAll() } } - func queueRequest(_ request: Request) throws { - let data = try JSONEncoder().encode(request) - let str = String(data: data, encoding: .utf8)! - if let continuation = messageStreamContinuation { - continuation.yield(str) - } else { - sentMessages.append(str) - } + func setFailConnect(_ shouldFail: Bool) { + shouldFailConnect = shouldFail } - func queueResponse(_ response: Response) throws { - let data = try JSONEncoder().encode(response) - let str = String(data: data, encoding: .utf8)! - if let continuation = messageStreamContinuation { - continuation.yield(str) - } else { - messagesToReceive.append(str) - } + func setFailSend(_ shouldFail: Bool) { + shouldFailSend = shouldFail } - func queueNotification(_ notification: Message) throws { - let data = try JSONEncoder().encode(notification) - let str = String(data: data, encoding: .utf8)! - if let continuation = messageStreamContinuation { - continuation.yield(str) + func queue(request: Request) throws { + let data = try encoder.encode(request) + if let continuation = dataStreamContinuation { + continuation.yield(data) } else { - messagesToReceive.append(str) + dataToReceive.append(data) } } - func getLastSentMessage() -> T? { - print("SENT:", sentMessages) - guard let lastMessage = sentMessages.last else { return nil } + func queue(response: Response) throws { + let data = try encoder.encode(response) + dataToReceive.append(data) + } + + func queue(notification: Message) throws { + let data = try encoder.encode(notification) + dataToReceive.append(data) + } + + func decodeLastSentMessage() -> T? { + guard let lastMessage = sentData.last else { return nil } do { - let data = lastMessage.data(using: .utf8)! - return try JSONDecoder().decode(T.self, from: data) + return try decoder.decode(T.self, from: lastMessage) } catch { return nil } } func clearMessages() { - sentMessages.removeAll() - messagesToReceive.removeAll() - } - - func setFailConnect(_ shouldFail: Bool) { - shouldFailConnect = shouldFail - } - - func setFailSend(_ shouldFail: Bool) { - shouldFailSend = shouldFail + sentData.removeAll() + dataToReceive.removeAll() } } diff --git a/Tests/MCPTests/ServerTests.swift b/Tests/MCPTests/ServerTests.swift index 97e02601..f55706db 100644 --- a/Tests/MCPTests/ServerTests.swift +++ b/Tests/MCPTests/ServerTests.swift @@ -22,8 +22,8 @@ struct ServerTests { let transport = MockTransport() // Queue an initialize request - try await transport.queueRequest( - Initialize.request( + try await transport.queue( + request: Initialize.request( .init( protocolVersion: Version.latest, capabilities: .init(), @@ -42,7 +42,11 @@ struct ServerTests { try await Task.sleep(nanoseconds: 10_000_000) // 10ms #expect(await transport.sentMessages.count == 1) - #expect(await transport.sentMessages[0].contains(Initialize.name)) + + let messages = await transport.sentMessages + if let response = messages.first { + #expect(response.contains("serverInfo")) + } // Clean up await server.stop() @@ -68,13 +72,13 @@ struct ServerTests { #expect(clientInfo.version == "1.0") await state.setHookCalled() } - + // Wait for server to initialize - try await Task.sleep(nanoseconds: 10_000_000) // 10ms + try await Task.sleep(for: .milliseconds(10)) // Queue an initialize request - try await transport.queueRequest( - Initialize.request( + try await transport.queue( + request: Initialize.request( .init( protocolVersion: Version.latest, capabilities: .init(), @@ -83,7 +87,7 @@ struct ServerTests { )) // Wait for message processing and hook execution - try await Task.sleep(nanoseconds: 200_000_000) // 200ms + try await Task.sleep(for: .milliseconds(500)) #expect(await state.wasHookCalled() == true) #expect(await transport.sentMessages.count >= 1) @@ -94,6 +98,7 @@ struct ServerTests { } await server.stop() + await transport.disconnect() } @Test("Initialize hook - rejection") @@ -107,13 +112,13 @@ struct ServerTests { throw Error.invalidRequest("Client not allowed") } } - + // Wait for server to initialize try await Task.sleep(nanoseconds: 10_000_000) // 10ms // Queue an initialize request from blocked client - try await transport.queueRequest( - Initialize.request( + try await transport.queue( + request: Initialize.request( .init( protocolVersion: Version.latest, capabilities: .init(), @@ -124,13 +129,15 @@ struct ServerTests { // Wait for message processing try await Task.sleep(nanoseconds: 200_000_000) // 200ms - #expect(await transport.sentMessages.count >= 2) + #expect(await transport.sentMessages.count >= 1) let messages = await transport.sentMessages if let response = messages.first { #expect(response.contains("error")) #expect(response.contains("Client not allowed")) } + await server.stop() + await transport.disconnect() } } diff --git a/Tests/MCPTests/TransportTests.swift b/Tests/MCPTests/TransportTests.swift index 5a2b02ab..c0823da3 100644 --- a/Tests/MCPTests/TransportTests.swift +++ b/Tests/MCPTests/TransportTests.swift @@ -29,7 +29,7 @@ struct StdioTransportTests { // Test sending a simple message let message = #"{"key":"value"}"# - try await transport.send(message) + try await transport.send(message.data(using: .utf8)!) // Read and verify the output var buffer = [UInt8](repeating: 0, count: 1024) @@ -57,12 +57,12 @@ struct StdioTransportTests { try writer.close() // Start receiving messages - let stream: AsyncThrowingStream = await transport.receive() + let stream: AsyncThrowingStream = await transport.receive() var iterator = stream.makeAsyncIterator() // Get first message let received = try await iterator.next() - #expect(received == #"{"key":"value"}"#) + #expect(received == #"{"key":"value"}"#.data(using: .utf8)!) await transport.disconnect() } @@ -79,7 +79,7 @@ struct StdioTransportTests { try writer.writeAll(invalidJSON.data(using: .utf8)!) try writer.close() - let stream: AsyncThrowingStream = await transport.receive() + let stream: AsyncThrowingStream = await transport.receive() var iterator = stream.makeAsyncIterator() _ = try await iterator.next()