Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 31 additions & 35 deletions Sources/MCP/Base/Transports.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ public protocol Transport: Actor {
/// Disconnects from the transport
func disconnect() async

/// Sends a message string
func send(_ message: String) async throws
/// Sends data
func send(_ data: Data) async throws

/// Receives message strings as an async sequence
func receive() -> AsyncThrowingStream<String, Swift.Error>
/// Receives data in an async sequence
func receive() -> AsyncThrowingStream<Data, Swift.Error>
}

/// Standard input/output transport implementation
Expand All @@ -33,8 +33,8 @@ public actor StdioTransport: Transport {
public nonisolated let logger: Logger

private var isConnected = false
private let messageStream: AsyncStream<String>
private let messageContinuation: AsyncStream<String>.Continuation
private let messageStream: AsyncStream<Data>
private let messageContinuation: AsyncStream<Data>.Continuation

public init(
input: FileDescriptor = FileDescriptor.standardInput,
Expand All @@ -50,7 +50,7 @@ public actor StdioTransport: Transport {
factory: { _ in SwiftLogNoOpLogHandler() })

// Create message stream
var continuation: AsyncStream<String>.Continuation!
var continuation: AsyncStream<Data>.Continuation!
self.messageStream = AsyncStream { continuation = $0 }
self.messageContinuation = continuation
}
Expand Down Expand Up @@ -105,15 +105,13 @@ public actor StdioTransport: Transport {
let messageData = pendingData[..<newlineIndex]
pendingData = pendingData[(newlineIndex + 1)...]

if let message = String(data: messageData, encoding: .utf8),
!message.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
{
logger.debug("Message received", metadata: ["message": "\(message)"])
messageContinuation.yield(message)
if !messageData.isEmpty {
logger.debug("Message received", metadata: ["size": "\(messageData.count)"])
messageContinuation.yield(Data(messageData))
}
}
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms backoff
try? await Task.sleep(for: .milliseconds(10))
continue
} catch {
if !Task.isCancelled {
Expand All @@ -133,17 +131,16 @@ public actor StdioTransport: Transport {
logger.info("Transport disconnected")
}

public func send(_ message: String) async throws {
public func send(_ message: Data) async throws {
guard isConnected else {
throw Error.transportError(Errno.socketNotConnected)
}

let message = message + "\n"
guard let data = message.data(using: .utf8) else {
throw Error.transportError(Errno.invalidArgument)
}
// Add newline as delimiter
var messageWithNewline = message
messageWithNewline.append(UInt8(ascii: "\n"))

var remaining = data
var remaining = messageWithNewline
while !remaining.isEmpty {
do {
let written = try remaining.withUnsafeBytes { buffer in
Expand All @@ -153,15 +150,15 @@ public actor StdioTransport: Transport {
remaining = remaining.dropFirst(written)
}
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
try await Task.sleep(nanoseconds: 10_000_000) // 10ms backoff
try await Task.sleep(for: .milliseconds(10))
continue
} catch {
throw Error.transportError(error)
}
}
}

public func receive() -> AsyncThrowingStream<String, Swift.Error> {
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
return AsyncThrowingStream { continuation in
Task {
for await message in messageStream {
Expand All @@ -182,8 +179,8 @@ public actor StdioTransport: Transport {
public nonisolated let logger: Logger

private var isConnected = false
private let messageStream: AsyncThrowingStream<String, Swift.Error>
private let messageContinuation: AsyncThrowingStream<String, Swift.Error>.Continuation
private let messageStream: AsyncThrowingStream<Data, Swift.Error>
private let messageContinuation: AsyncThrowingStream<Data, Swift.Error>.Continuation

// Track connection state for continuations
private var connectionContinuationResumed = false
Expand All @@ -198,7 +195,7 @@ public actor StdioTransport: Transport {
)

// Create message stream
var continuation: AsyncThrowingStream<String, Swift.Error>.Continuation!
var continuation: AsyncThrowingStream<Data, Swift.Error>.Continuation!
self.messageStream = AsyncThrowingStream { continuation = $0 }
self.messageContinuation = continuation
}
Expand Down Expand Up @@ -289,14 +286,14 @@ public actor StdioTransport: Transport {
logger.info("Network transport disconnected")
}

public func send(_ message: String) async throws {
public func send(_ message: Data) async throws {
guard isConnected else {
throw MCP.Error.internalError("Transport not connected")
}

guard let data = (message + "\n").data(using: .utf8) else {
throw MCP.Error.internalError("Failed to encode message")
}
// Add newline as delimiter
var messageWithNewline = message
messageWithNewline.append(UInt8(ascii: "\n"))

// Use a local actor-isolated variable to track continuation state
var sendContinuationResumed = false
Expand All @@ -309,7 +306,7 @@ public actor StdioTransport: Transport {
}

connection.send(
content: data,
content: messageWithNewline,
completion: .contentProcessed { [weak self] error in
guard let self = self else { return }

Expand All @@ -329,7 +326,7 @@ public actor StdioTransport: Transport {
}
}

public func receive() -> AsyncThrowingStream<String, Swift.Error> {
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
return AsyncThrowingStream { continuation in
Task {
do {
Expand Down Expand Up @@ -357,11 +354,10 @@ public actor StdioTransport: Transport {
let messageData = buffer[..<newlineIndex]
buffer = buffer[(newlineIndex + 1)...]

if let message = String(data: messageData, encoding: .utf8),
!message.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
{
logger.debug("Message received", metadata: ["message": "\(message)"])
messageContinuation.yield(message)
if !messageData.isEmpty {
logger.debug(
"Message received", metadata: ["size": "\(messageData.count)"])
messageContinuation.yield(Data(messageData))
}
}
} catch let error as NWError {
Expand Down
20 changes: 6 additions & 14 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,10 @@ public actor Client {

do {
let stream = await connection.receive()
for try await string in stream {
for try await data in stream {
if Task.isCancelled { break } // Check inside loop too

// Decode and handle incoming message
guard let data = string.data(using: .utf8) else {
throw Error.parseError("Invalid UTF-8 data")
}

// Attempt to decode string data as AnyResponse or AnyMessage
// Attempt to decode data as AnyResponse or AnyMessage
let decoder = JSONDecoder()
if let response = try? decoder.decode(AnyResponse.self, from: data),
let request = pendingRequests[response.id]
Expand All @@ -207,7 +202,7 @@ public actor Client {
}
}
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms
try? await Task.sleep(for: .milliseconds(10))
continue
} catch {
await logger?.error(
Expand Down Expand Up @@ -256,22 +251,19 @@ public actor Client {
}

let requestData = try JSONEncoder().encode(request)
guard let requestString = String(data: requestData, encoding: .utf8) else {
throw Error.internalError("Failed to encode request")
}

// Store the pending request first
return try await withCheckedThrowingContinuation { continuation in
// Store the pending request first
Task {
self.addPendingRequest(
id: request.id,
continuation: continuation,
type: M.Result.self
)

// Send the request
// Send the request data
do {
try await connection.send(requestString)
try await connection.send(requestData)
} catch {
continuation.resume(throwing: error)
self.removePendingRequest(id: request.id)
Expand Down
23 changes: 7 additions & 16 deletions Sources/MCP/Server/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,11 @@ public actor Server {
task = Task {
do {
let stream = await transport.receive()
for try await string in stream {
for try await data in stream {
if Task.isCancelled { break } // Check cancellation inside loop

var requestID: ID?
do {
guard let data = string.data(using: .utf8) else {
throw Error.parseError("Invalid UTF-8 data")
}

// Attempt to decode string data as AnyRequest or AnyMessage
let decoder = JSONDecoder()
if let request = try? decoder.decode(AnyRequest.self, from: data) {
Expand All @@ -203,7 +199,7 @@ public actor Server {
}
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
// Resource temporarily unavailable, retry after a short delay
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms
try? await Task.sleep(for: .milliseconds(10))
continue
} catch {
await logger?.error(
Expand Down Expand Up @@ -266,19 +262,17 @@ public actor Server {

// MARK: - Sending

/// Send a response to a client
/// Send a response to a request
public func send<M: Method>(_ response: Response<M>) async throws {
guard let connection = connection else {
throw Error.internalError("Server connection not initialized")
}

let encoder = JSONEncoder()
encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes]

let responseData = try encoder.encode(response)

if let responseStr = String(data: responseData, encoding: .utf8) {
try await connection.send(responseStr)
}
try await connection.send(responseData)
}

/// Send a notification to connected clients
Expand All @@ -291,10 +285,7 @@ public actor Server {
encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes]

let notificationData = try encoder.encode(notification)

if let notificationStr = String(data: notificationData, encoding: .utf8) {
try await connection.send(notificationStr)
}
try await connection.send(notificationData)
}

// MARK: -
Expand Down Expand Up @@ -407,7 +398,7 @@ public actor Server {

// Send initialized notification after a short delay
Task {
try? await Task.sleep(nanoseconds: 100_000_000) // 100ms
try? await Task.sleep(for: .milliseconds(10))
try? await self.notify(InitializedNotification.message())
}

Expand Down
24 changes: 12 additions & 12 deletions Tests/MCPTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ struct ClientTests {

try await client.connect(transport: transport)
// Small delay to ensure message loop is started
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
try await Task.sleep(for: .milliseconds(10))

// Create a task for initialize that we'll cancel
let initTask = Task {
try await client.initialize()
}

// Give it a moment to send the request
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
try await Task.sleep(for: .milliseconds(10))

#expect(await transport.sentMessages.count == 1)
#expect(await transport.sentMessages[0].contains(Initialize.name))
#expect(await transport.sentMessages[0].contains(client.name))
#expect(await transport.sentMessages[0].contains(client.version))
#expect(await transport.sentMessages.first?.contains(Initialize.name) == true)
#expect(await transport.sentMessages.first?.contains(client.name) == true)
#expect(await transport.sentMessages.first?.contains(client.version) == true)

// Cancel the initialize task
initTask.cancel()

// Disconnect client to clean up message loop and give time for continuation cleanup
await client.disconnect()
try await Task.sleep(nanoseconds: 50_000_000) // 50ms
try await Task.sleep(for: .milliseconds(50))
}

@Test(
Expand All @@ -60,25 +60,25 @@ struct ClientTests {

try await client.connect(transport: transport)
// Small delay to ensure message loop is started
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
try await Task.sleep(for: .milliseconds(10))

// Create a task for the ping that we'll cancel
let pingTask = Task {
try await client.ping()
}

// Give it a moment to send the request
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
try await Task.sleep(for: .milliseconds(10))

#expect(await transport.sentMessages.count == 1)
#expect(await transport.sentMessages[0].contains(Ping.name))
#expect(await transport.sentMessages.first?.contains(Ping.name) == true)

// Cancel the ping task
pingTask.cancel()

// Disconnect client to clean up message loop and give time for continuation cleanup
await client.disconnect()
try await Task.sleep(nanoseconds: 50_000_000) // 50ms
try await Task.sleep(for: .milliseconds(50))
}

@Test("Connection failure handling")
Expand Down Expand Up @@ -168,7 +168,7 @@ struct ClientTests {

// Wait a bit for any setup to complete
try await Task.sleep(for: .milliseconds(10))

// Send the listPrompts request and immediately provide an error response
let promptsTask = Task {
do {
Expand All @@ -187,7 +187,7 @@ struct ClientTests {
id: decodedRequest.id,
error: Error.methodNotFound("Test: Prompts capability not available")
)
try await transport.queueResponse(errorResponse)
try await transport.queue(response: errorResponse)

// Try the request now that we have a response queued
do {
Expand Down
Loading