diff --git a/Sources/MCP/Base/Transports/HTTPClientTransport.swift b/Sources/MCP/Base/Transports/HTTPClientTransport.swift index f9f7e40d..678131b3 100644 --- a/Sources/MCP/Base/Transports/HTTPClientTransport.swift +++ b/Sources/MCP/Base/Transports/HTTPClientTransport.swift @@ -9,28 +9,66 @@ import Logging import FoundationNetworking #endif -public actor HTTPClientTransport: Actor, Transport { +/// An implementation of the MCP Streamable HTTP transport protocol for clients. +/// +/// This transport implements the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) +/// specification from the Model Context Protocol. +/// +/// It supports: +/// - Sending JSON-RPC messages via HTTP POST requests +/// - Receiving responses via both direct JSON responses and SSE streams +/// - Session management using the `Mcp-Session-Id` header +/// - Automatic reconnection for dropped SSE streams +/// - Platform-specific optimizations for different operating systems +/// +/// The transport supports two modes: +/// - Regular HTTP (`streaming=false`): Simple request/response pattern +/// - Streaming HTTP with SSE (`streaming=true`): Enables server-to-client push messages +/// +/// - Important: Server-Sent Events (SSE) functionality is not supported on Linux platforms. +public actor HTTPClientTransport: Transport { + /// The server endpoint URL to connect to public let endpoint: URL private let session: URLSession + + /// The session ID assigned by the server, used for maintaining state across requests public private(set) var sessionID: String? private let streaming: Bool private var streamingTask: Task? + + /// Logger instance for transport-related events public nonisolated let logger: Logger + /// Maximum time to wait for a session ID before proceeding with SSE connection + public let sseInitializationTimeout: TimeInterval + private var isConnected = false private let messageStream: AsyncThrowingStream private let messageContinuation: AsyncThrowingStream.Continuation + private var initialSessionIDSignalTask: Task? + private var initialSessionIDContinuation: CheckedContinuation? + + /// Creates a new HTTP transport client with the specified endpoint + /// + /// - Parameters: + /// - endpoint: The server URL to connect to + /// - configuration: URLSession configuration to use for HTTP requests + /// - streaming: Whether to enable SSE streaming mode (default: true) + /// - sseInitializationTimeout: Maximum time to wait for session ID before proceeding with SSE (default: 10 seconds) + /// - logger: Optional logger instance for transport events public init( endpoint: URL, configuration: URLSessionConfiguration = .default, streaming: Bool = true, + sseInitializationTimeout: TimeInterval = 10, logger: Logger? = nil ) { self.init( endpoint: endpoint, session: URLSession(configuration: configuration), streaming: streaming, + sseInitializationTimeout: sseInitializationTimeout, logger: logger ) } @@ -39,11 +77,13 @@ public actor HTTPClientTransport: Actor, Transport { endpoint: URL, session: URLSession, streaming: Bool = false, + sseInitializationTimeout: TimeInterval = 10, logger: Logger? = nil ) { self.endpoint = endpoint self.session = session self.streaming = streaming + self.sseInitializationTimeout = sseInitializationTimeout // Create message stream var continuation: AsyncThrowingStream.Continuation! @@ -58,11 +98,37 @@ public actor HTTPClientTransport: Actor, Transport { ) } + // Setup the initial session ID signal + private func setupInitialSessionIDSignal() { + self.initialSessionIDSignalTask = Task { + await withCheckedContinuation { continuation in + self.initialSessionIDContinuation = continuation + // This task will suspend here until continuation.resume() is called + } + } + } + + // Trigger the initial session ID signal when a session ID is established + private func triggerInitialSessionIDSignal() { + if let continuation = self.initialSessionIDContinuation { + continuation.resume() + self.initialSessionIDContinuation = nil // Consume the continuation + logger.debug("Initial session ID signal triggered for SSE task.") + } + } + /// Establishes connection with the transport + /// + /// This prepares the transport for communication and sets up SSE streaming + /// if streaming mode is enabled. The actual HTTP connection happens with the + /// first message sent. public func connect() async throws { guard !isConnected else { return } isConnected = true + // Setup initial session ID signal + setupInitialSessionIDSignal() + if streaming { // Start listening to server events streamingTask = Task { await startListeningForServerEvents() } @@ -72,6 +138,9 @@ public actor HTTPClientTransport: Actor, Transport { } /// Disconnects from the transport + /// + /// This terminates any active connections, cancels the streaming task, + /// and releases any resources being used by the transport. public func disconnect() async { guard isConnected else { return } isConnected = false @@ -86,10 +155,28 @@ public actor HTTPClientTransport: Actor, Transport { // Clean up message stream messageContinuation.finish() + // Cancel the initial session ID signal task if active + initialSessionIDSignalTask?.cancel() + initialSessionIDSignalTask = nil + // Resume the continuation if it's still pending to avoid leaks + initialSessionIDContinuation?.resume() + initialSessionIDContinuation = nil + logger.info("HTTP clienttransport disconnected") } /// Sends data through an HTTP POST request + /// + /// This sends a JSON-RPC message to the server via HTTP POST and processes + /// the response according to the MCP Streamable HTTP specification. It handles: + /// + /// - Adding appropriate Accept headers for both JSON and SSE + /// - Including the session ID in requests if one has been established + /// - Processing different response types (JSON vs SSE) + /// - Handling HTTP error codes according to the specification + /// + /// - Parameter data: The JSON-RPC message to send + /// - Throws: MCPError for transport failures or server errors public func send(_ data: Data) async throws { guard isConnected else { throw MCPError.internalError("Transport not connected") @@ -129,7 +216,12 @@ public actor HTTPClientTransport: Actor, Transport { // Extract session ID if present if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") { + let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID + if wasSessionIDNil { + // Trigger signal on first session ID + triggerInitialSessionIDSignal() + } logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) } @@ -161,7 +253,12 @@ public actor HTTPClientTransport: Actor, Transport { // Extract session ID if present if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") { + let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID + if wasSessionIDNil { + // Trigger signal on first session ID + triggerInitialSessionIDSignal() + } logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) } @@ -238,6 +335,14 @@ public actor HTTPClientTransport: Actor, Transport { } /// Receives data in an async sequence + /// + /// This returns an AsyncThrowingStream that emits Data objects representing + /// each JSON-RPC message received from the server. This includes: + /// + /// - Direct responses to client requests + /// - Server-initiated messages delivered via SSE streams + /// + /// - Returns: An AsyncThrowingStream of Data objects public func receive() -> AsyncThrowingStream { return messageStream } @@ -245,6 +350,14 @@ public actor HTTPClientTransport: Actor, Transport { // MARK: - SSE /// Starts listening for server events using SSE + /// + /// This establishes a long-lived HTTP connection using Server-Sent Events (SSE) + /// to enable server-to-client push messaging. It handles: + /// + /// - Waiting for session ID if needed + /// - Opening the SSE connection + /// - Automatic reconnection on connection drops + /// - Processing received events private func startListeningForServerEvents() async { #if os(Linux) // SSE is not fully supported on Linux @@ -257,6 +370,63 @@ public actor HTTPClientTransport: Actor, Transport { // This is the original code for platforms that support SSE guard isConnected else { return } + // Wait for the initial session ID signal, but only if sessionID isn't already set + if self.sessionID == nil, let signalTask = self.initialSessionIDSignalTask { + logger.debug("SSE streaming task waiting for initial sessionID signal...") + + // Race the signalTask against a timeout + let timeoutTask = Task { + try? await Task.sleep(for: .seconds(self.sseInitializationTimeout)) + return false + } + + let signalCompletionTask = Task { + await signalTask.value + return true // Indicates signal received + } + + // Use TaskGroup to race the two tasks + var signalReceived = false + do { + signalReceived = try await withThrowingTaskGroup(of: Bool.self) { group in + group.addTask { + await signalCompletionTask.value + } + group.addTask { + await timeoutTask.value + } + + // Take the first result and cancel the other task + if let firstResult = try await group.next() { + group.cancelAll() + return firstResult + } + return false + } + } catch { + logger.error("Error while waiting for session ID signal: \(error)") + } + + // Clean up tasks + timeoutTask.cancel() + + if signalReceived { + logger.debug("SSE streaming task proceeding after initial sessionID signal.") + } else { + logger.warning( + "Timeout waiting for initial sessionID signal. SSE stream will proceed (sessionID might be nil)." + ) + } + } else if self.sessionID != nil { + logger.debug( + "Initial sessionID already available. Proceeding with SSE streaming task immediately." + ) + } else { + logger.info( + "Proceeding with SSE connection attempt; sessionID is nil. This might be expected for stateless servers or if initialize hasn't provided one yet." + ) + } + // Retry loop for connection drops while isConnected && !Task.isCancelled { do { @@ -274,6 +444,11 @@ public actor HTTPClientTransport: Actor, Transport { #if !os(Linux) /// Establishes an SSE connection to the server + /// + /// This initiates a GET request to the server endpoint with appropriate + /// headers to establish an SSE stream according to the MCP specification. + /// + /// - Throws: MCPError for connection failures or server errors private func connectToEventStream() async throws { guard isConnected else { return } @@ -309,13 +484,23 @@ public actor HTTPClientTransport: Actor, Transport { // Extract session ID if present if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") { + let wasSessionIDNil = (self.sessionID == nil) self.sessionID = newSessionID + if wasSessionIDNil { + // Trigger signal on first session ID, though this is unlikely to happen here + // as GET usually follows a POST that would have already set the session ID + triggerInitialSessionIDSignal() + } logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"]) } try await self.processSSE(stream) } + /// Processes an SSE byte stream, extracting events and delivering them + /// + /// - Parameter stream: The URLSession.AsyncBytes stream to process + /// - Throws: Error for stream processing failures private func processSSE(_ stream: URLSession.AsyncBytes) async throws { do { for try await event in stream.events { diff --git a/Tests/MCPTests/HTTPClientTransportTests.swift b/Tests/MCPTests/HTTPClientTransportTests.swift index 05721398..c2149bb7 100644 --- a/Tests/MCPTests/HTTPClientTransportTests.swift +++ b/Tests/MCPTests/HTTPClientTransportTests.swift @@ -415,24 +415,45 @@ import Testing endpoint: testEndpoint, configuration: configuration, streaming: true, + sseInitializationTimeout: 1, logger: nil ) let eventString = "id: event1\ndata: {\"key\":\"value\"}\n\n" let sseEventData = eventString.data(using: .utf8)! + // First, set up a handler for the initial POST that will provide a session ID await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in + let response = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: [ + "Content-Type": "text/plain", + "Mcp-Session-Id": "test-session-123", + ])! + return (response, Data()) + } + + // Connect and send a dummy message to get the session ID + try await transport.connect() + try await transport.send(Data()) + + // Now set up the handler for the SSE GET request + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint, sseEventData] (request: URLRequest) in // sseEventData is now empty Data() #expect(request.url == testEndpoint) #expect(request.httpMethod == "GET") #expect(request.value(forHTTPHeaderField: "Accept") == "text/event-stream") + #expect( + request.value(forHTTPHeaderField: "Mcp-Session-Id") == "test-session-123") + let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: ["Content-Type": "text/event-stream"])! - return (response, sseEventData) + + return (response, sseEventData) // Will return empty Data for SSE } - try await transport.connect() try await Task.sleep(for: .milliseconds(100)) let stream = await transport.receive() @@ -442,7 +463,10 @@ import Testing let receivedData = try await iterator.next() #expect(receivedData == expectedData) + + await transport.disconnect() } + @Test("Receive Server-Sent Event (SSE) (CR-NL)", .httpClientTransportSetup) func testReceiveSSE_CRNL() async throws { let configuration = URLSessionConfiguration.ephemeral @@ -452,24 +476,46 @@ import Testing endpoint: testEndpoint, configuration: configuration, streaming: true, + sseInitializationTimeout: 1, logger: nil ) let eventString = "id: event1\r\ndata: {\"key\":\"value\"}\r\n\n" let sseEventData = eventString.data(using: .utf8)! + // First, set up a handler for the initial POST that will provide a session ID + // Use text/plain to prevent its (empty) body from being yielded to messageStream await MockURLProtocol.requestHandlerStorage.setHandler { [testEndpoint] (request: URLRequest) in + let response = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: [ + "Content-Type": "text/plain", + "Mcp-Session-Id": "test-session-123", + ])! + return (response, Data()) + } + + // Connect and send a dummy message to get the session ID + try await transport.connect() + try await transport.send(Data()) + + // Now set up the handler for the SSE GET request + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint, sseEventData] (request: URLRequest) in #expect(request.url == testEndpoint) #expect(request.httpMethod == "GET") #expect(request.value(forHTTPHeaderField: "Accept") == "text/event-stream") + #expect( + request.value(forHTTPHeaderField: "Mcp-Session-Id") == "test-session-123") + let response = HTTPURLResponse( url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", headerFields: ["Content-Type": "text/event-stream"])! + return (response, sseEventData) } - try await transport.connect() try await Task.sleep(for: .milliseconds(100)) let stream = await transport.receive() @@ -479,6 +525,8 @@ import Testing let receivedData = try await iterator.next() #expect(receivedData == expectedData) + + await transport.disconnect() } #endif // !canImport(FoundationNetworking)