diff --git a/README.md b/README.md index c5ae240e..e48732c2 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,29 @@ try await client.connect(transport: transport) let result = try await client.initialize() ``` +### Streaming HTTP Transport + +The HTTP transport supports streaming mode for real-time communication using Server-Sent Events (SSE): + +```swift +import MCP + +// Create a streaming HTTP transport +let transport = HTTPClientTransport( + endpoint: URL(string: "http://localhost:8080")!, +) + +// Initialize the client with streaming transport +let client = Client(name: "MyApp", version: "1.0.0") +try await client.connect(transport: transport) + +// Initialize the connection +let result = try await client.initialize() + +// The transport will automatically handle SSE events +// and deliver them through the client's notification handlers +``` + ### Basic Server Setup ```swift diff --git a/Sources/MCP/Base/Transports/HTTPClientTransport.swift b/Sources/MCP/Base/Transports/HTTPClientTransport.swift index 33f40e65..9cf93a0c 100644 --- a/Sources/MCP/Base/Transports/HTTPClientTransport.swift +++ b/Sources/MCP/Base/Transports/HTTPClientTransport.swift @@ -21,7 +21,7 @@ public actor HTTPClientTransport: Actor, Transport { public init( endpoint: URL, configuration: URLSessionConfiguration = .default, - streaming: Bool = false, + streaming: Bool = true, logger: Logger? = nil ) { self.init( @@ -269,7 +269,7 @@ public actor HTTPClientTransport: Actor, Transport { if line.hasSuffix("\r") { line = line.dropLast() } - + // Lines starting with ":" are comments if line.hasPrefix(":") { continue } diff --git a/Tests/MCPTests/HTTPClientTransportTests.swift b/Tests/MCPTests/HTTPClientTransportTests.swift index fdbff85d..8085f5a5 100644 --- a/Tests/MCPTests/HTTPClientTransportTests.swift +++ b/Tests/MCPTests/HTTPClientTransportTests.swift @@ -51,7 +51,11 @@ import Testing func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { guard let handler = requestHandler else { - throw MockURLProtocolError.noRequestHandler + throw NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [ + NSLocalizedDescriptionKey: "No request handler set" + ]) } return try await handler(request) } @@ -123,11 +127,6 @@ import Testing override func stopLoading() {} } - enum MockURLProtocolError: Swift.Error { - case noRequestHandler - case invalidURL - } - // MARK: - @Suite("HTTP Client Transport Tests", .serialized) @@ -140,7 +139,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration, streaming: false, logger: nil) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() await transport.disconnect() @@ -152,7 +155,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration, streaming: false, logger: nil) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() let messageData = #"{"jsonrpc":"2.0","method":"initialize","id":1}"#.data(using: .utf8)! @@ -190,7 +197,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration, streaming: false, logger: nil) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() let messageData = #"{"jsonrpc":"2.0","method":"initialize","id":1}"#.data(using: .utf8)! @@ -220,7 +231,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration, streaming: false, logger: nil) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() let initialSessionID = "existing-session-abc" @@ -265,7 +280,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() let messageData = #"{"jsonrpc":"2.0","method":"test","id":3}"#.data(using: .utf8)! @@ -298,7 +317,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() let messageData = #"{"jsonrpc":"2.0","method":"test","id":4}"#.data(using: .utf8)! @@ -331,7 +354,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration) + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) try await transport.connect() let initialSessionID = "expired-session-xyz" @@ -385,8 +412,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration, streaming: true, - logger: nil) + endpoint: testEndpoint, + configuration: configuration, + streaming: true, + logger: nil + ) let eventString = "id: event1\ndata: {\"key\":\"value\"}\n\n" let sseEventData = eventString.data(using: .utf8)! @@ -419,8 +449,11 @@ import Testing configuration.protocolClasses = [MockURLProtocol.self] let transport = HTTPClientTransport( - endpoint: testEndpoint, configuration: configuration, streaming: true, - logger: nil) + endpoint: testEndpoint, + configuration: configuration, + streaming: true, + logger: nil + ) let eventString = "id: event1\r\ndata: {\"key\":\"value\"}\r\n\n" let sseEventData = eventString.data(using: .utf8)! @@ -448,6 +481,152 @@ import Testing #expect(receivedData == expectedData) } #endif // !canImport(FoundationNetworking) - } + @Test( + "Client with HTTP Transport complete flow", .httpClientTransportSetup, + .timeLimit(.minutes(1))) + func testClientFlow() async throws { + let configuration = URLSessionConfiguration.ephemeral + configuration.protocolClasses = [MockURLProtocol.self] + + let transport = HTTPClientTransport( + endpoint: testEndpoint, + configuration: configuration, + streaming: false, + logger: nil + ) + + let client = Client(name: "TestClient", version: "1.0.0") + + // Use an actor to track request sequence + actor RequestTracker { + enum RequestType { + case initialize + case callTool + } + + private(set) var lastRequest: RequestType? + + func setRequest(_ type: RequestType) { + lastRequest = type + } + + func getLastRequest() -> RequestType? { + return lastRequest + } + } + + let tracker = RequestTracker() + + // Setup mock responses + await MockURLProtocol.requestHandlerStorage.setHandler { + [testEndpoint, tracker] (request: URLRequest) in + switch request.httpMethod { + case "GET": + #expect( + request.allHTTPHeaderFields?["Accept"]?.contains("text/event-stream") + == true) + case "POST": + #expect( + request.allHTTPHeaderFields?["Accept"]?.contains("application/json") == true + ) + default: + Issue.record( + "Unsupported HTTP method \(String(describing: request.httpMethod))") + } + + #expect(request.url == testEndpoint) + + let bodyData = request.readBody() + + guard let bodyData = bodyData, + let json = try JSONSerialization.jsonObject(with: bodyData) as? [String: Any], + let method = json["method"] as? String + else { + throw NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [ + NSLocalizedDescriptionKey: "Invalid JSON-RPC message \(#file):\(#line)" + ]) + } + + if method == "initialize" { + await tracker.setRequest(.initialize) + + let requestID = json["id"] as! String + let result = Initialize.Result( + protocolVersion: Version.latest, + capabilities: .init(tools: .init()), + serverInfo: .init(name: "Mock Server", version: "0.0.1"), + instructions: nil + ) + let response = Initialize.response(id: .string(requestID), result: result) + let responseData = try JSONEncoder().encode(response) + + let httpResponse = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (httpResponse, responseData) + } else if method == "tools/call" { + // Verify initialize was called first + if let lastRequest = await tracker.getLastRequest(), lastRequest != .initialize + { + #expect(Bool(false), "Initialize should be called before callTool") + } + + await tracker.setRequest(.callTool) + + let params = json["params"] as? [String: Any] + let toolName = params?["name"] as? String + #expect(toolName == "calculator") + + let requestID = json["id"] as! String + let result = CallTool.Result(content: [.text("42")]) + let response = CallTool.response(id: .string(requestID), result: result) + let responseData = try JSONEncoder().encode(response) + + let httpResponse = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (httpResponse, responseData) + } else if method == "notifications/initialized" { + // Ignore initialized notifications + let httpResponse = HTTPURLResponse( + url: testEndpoint, statusCode: 200, httpVersion: "HTTP/1.1", + headerFields: ["Content-Type": "application/json"])! + return (httpResponse, Data()) + } else { + throw NSError( + domain: "MockURLProtocolError", code: 0, + userInfo: [ + NSLocalizedDescriptionKey: + "Unexpected request method: \(method) \(#file):\(#line)" + ]) + } + } + + // Execute the complete flow + try await client.connect(transport: transport) + + // Step 1: Initialize client + let initResult = try await client.initialize() + #expect(initResult.protocolVersion == Version.latest) + #expect(initResult.capabilities.tools != nil) + + // Step 2: Call a tool + let toolResult = try await client.callTool(name: "calculator") + #expect(toolResult.content.count == 1) + if case let .text(text) = toolResult.content[0] { + #expect(text == "42") + } else { + #expect(Bool(false), "Expected text content") + } + + // Step 3: Verify request sequence + #expect(await tracker.getLastRequest() == .callTool) + + // Step 4: Disconnect + await client.disconnect() + } + } #endif // swift(>=6.1)