diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b67469b..d51838a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,16 +1,14 @@ name: test on: + push: + branches: [ main ] pull_request: - push: { branches: [ main ] } - + branches: [ main ] + workflow_dispatch: jobs: + lint: + uses: graphqlswift/ci/.github/workflows/lint.yaml@main test: - strategy: - matrix: - os: [ubuntu-latest, macos-latest] - runs-on: ${{ matrix.os }} - steps: - - uses: fwal/setup-swift@v1 - - uses: actions/checkout@v2 - - name: Run tests - run: swift test + uses: graphqlswift/ci/.github/workflows/test.yaml@main + with: + include_android: false diff --git a/Package.resolved b/Package.resolved index 7ce8815..3e34fcf 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,61 +1,33 @@ { - "object": { - "pins": [ - { - "package": "Graphiti", - "repositoryURL": "https://github.com/GraphQLSwift/Graphiti.git", - "state": { - "branch": null, - "revision": "c9bc9d1cc9e62e71a824dc178630bfa8b8a6e2a4", - "version": "1.0.0" - } - }, - { - "package": "GraphQL", - "repositoryURL": "https://github.com/GraphQLSwift/GraphQL.git", - "state": { - "branch": null, - "revision": "283cc4de56b994a00b2724328221b7a1bc846ddc", - "version": "2.2.1" - } - }, - { - "package": "GraphQLRxSwift", - "repositoryURL": "https://github.com/GraphQLSwift/GraphQLRxSwift.git", - "state": { - "branch": null, - "revision": "c7ec6595f92ef5d77c06852e4acc4cd46a753622", - "version": "0.0.4" - } - }, - { - "package": "RxSwift", - "repositoryURL": "https://github.com/ReactiveX/RxSwift.git", - "state": { - "branch": null, - "revision": "b4307ba0b6425c0ba4178e138799946c3da594f8", - "version": "6.5.0" - } - }, - { - "package": "swift-collections", - "repositoryURL": "https://github.com/apple/swift-collections", - "state": { - "branch": null, - "revision": "48254824bb4248676bf7ce56014ff57b142b77eb", - "version": "1.0.2" - } - }, - { - "package": "swift-nio", - "repositoryURL": "https://github.com/apple/swift-nio.git", - "state": { - "branch": null, - "revision": "6aa9347d9bc5bbfe6a84983aec955c17ffea96ef", - "version": "2.33.0" - } + "originHash" : "30951f6d77c03868bb74b0838ce93637391a168c6668a029c8a8a1dd9fb01aa5", + "pins" : [ + { + "identity" : "graphiti", + "kind" : "remoteSourceControl", + "location" : "https://github.com/GraphQLSwift/Graphiti.git", + "state" : { + "revision" : "a23a3d232df202fc158ad2d698926325b470523c", + "version" : "3.0.0" } - ] - }, - "version": 1 + }, + { + "identity" : "graphql", + "kind" : "remoteSourceControl", + "location" : "https://github.com/GraphQLSwift/GraphQL.git", + "state" : { + "revision" : "0fe18bc0bbbc9ab8929c285f419adea7c8fc7da2", + "version" : "4.0.1" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections", + "state" : { + "revision" : "8c0c0a8b49e080e54e5e328cc552821ff07cd341", + "version" : "1.2.1" + } + } + ], + "version" : 3 } diff --git a/Package.swift b/Package.swift index 2d77e15..db11fb4 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,7 @@ import PackageDescription let package = Package( name: "GraphQLTransportWS", + platforms: [.macOS(.v10_15)], products: [ .library( name: "GraphQLTransportWS", @@ -11,25 +12,21 @@ let package = Package( ), ], dependencies: [ - .package(name: "Graphiti", url: "https://github.com/GraphQLSwift/Graphiti.git", from: "1.0.0"), - .package(name: "GraphQL", url: "https://github.com/GraphQLSwift/GraphQL.git", from: "2.2.1"), - .package(name: "GraphQLRxSwift", url: "https://github.com/GraphQLSwift/GraphQLRxSwift.git", from: "0.0.4"), - .package(name: "RxSwift", url: "https://github.com/ReactiveX/RxSwift.git", from: "6.1.0"), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), + .package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"), + .package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.1"), ], targets: [ .target( name: "GraphQLTransportWS", dependencies: [ .product(name: "Graphiti", package: "Graphiti"), - .product(name: "GraphQLRxSwift", package: "GraphQLRxSwift"), .product(name: "GraphQL", package: "GraphQL"), - .product(name: "NIO", package: "swift-nio"), - .product(name: "RxSwift", package: "RxSwift") - ]), + ] + ), .testTarget( name: "GraphQLTransportWSTests", dependencies: ["GraphQLTransportWS"] ), - ] + ], + swiftLanguageVersions: [.v5, .version("6")] ) diff --git a/README.md b/README.md index 65b6646..f9d2d2b 100644 --- a/README.md +++ b/README.md @@ -27,32 +27,32 @@ import GraphQLTransportWS /// Messenger wrapper for WebSockets class WebSocketMessenger: Messenger { private weak var websocket: WebSocket? - private var onReceive: (String) -> Void = { _ in } - + private var onReceive: (String) async throws -> Void = { _ in } + init(websocket: WebSocket) { self.websocket = websocket websocket.onText { _, message in - self.onReceive(message) + try await self.onReceive(message) } } - - func send(_ message: S) where S: Collection, S.Element == Character { + + func send(_ message: S) where S: Collection, S.Element == Character async throws { guard let websocket = websocket else { return } - websocket.send(message) + try await websocket.send(message) } - - func onReceive(callback: @escaping (String) -> Void) { + + func onReceive(callback: @escaping (String) async throws -> Void) { self.onReceive = callback } - - func error(_ message: String, code: Int) { + + func error(_ message: String, code: Int) async throws { guard let websocket = websocket else { return } - websocket.send("\(code): \(message)") + try await websocket.send("\(code): \(message)") } - - func close() { + + func close() async throws { guard let websocket = websocket else { return } - _ = websocket.close() + try await websocket.close() } } ``` @@ -67,7 +67,7 @@ routes.webSocket( let server = GraphQLTransportWS.Server( messenger: messenger, onExecute: { graphQLRequest in - api.execute( + try await api.execute( request: graphQLRequest.query, context: context, on: self.eventLoop, @@ -76,7 +76,7 @@ routes.webSocket( ) }, onSubscribe: { graphQLRequest in - api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context, on: self.eventLoop, @@ -128,8 +128,8 @@ If the `payload` field is not required on your server, you may make Server's gen ## Memory Management -Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket -implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the +Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket +implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes, diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index c4eaa62..bcd82f1 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -5,16 +5,16 @@ import GraphQL public class Client { // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - var onConnectionAck: (ConnectionAckResponse, Client) -> Void = { _, _ in } - var onNext: (NextResponse, Client) -> Void = { _, _ in } - var onError: (ErrorResponse, Client) -> Void = { _, _ in } - var onComplete: (CompleteResponse, Client) -> Void = { _, _ in } - var onMessage: (String, Client) -> Void = { _, _ in } - + + var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } + var onNext: (NextResponse, Client) async throws -> Void = { _, _ in } + var onError: (ErrorResponse, Client) async throws -> Void = { _, _ in } + var onComplete: (CompleteResponse, Client) async throws -> Void = { _, _ in } + var onMessage: (String, Client) async throws -> Void = { _, _ in } + let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() - + /// Create a new client. /// /// - Parameters: @@ -24,123 +24,122 @@ public class Client { ) { self.messenger = messenger messenger.onReceive { message in - self.onMessage(message, self) - + try await self.onMessage(message, self) + // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages return } - + guard let json = message.data(using: .utf8) else { - self.error(.invalidEncoding()) + try await self.error(.invalidEncoding()) return } - + let response: Response do { response = try self.decoder.decode(Response.self, from: json) - } - catch { - self.error(.noType()) + } catch { + try await self.error(.noType()) return } - + switch response.type { - case .connectionAck: - guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .connectionAck)) - return - } - self.onConnectionAck(connectionAckResponse, self) - case .next: - guard let nextResponse = try? self.decoder.decode(NextResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .next)) - return - } - self.onNext(nextResponse, self) - case .error: - guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .error)) - return - } - self.onError(errorResponse, self) - case .complete: - guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .complete)) - return - } - self.onComplete(completeResponse, self) - case .unknown: - self.error(.invalidType()) + case .connectionAck: + guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .connectionAck)) + return + } + try await self.onConnectionAck(connectionAckResponse, self) + case .next: + guard let nextResponse = try? self.decoder.decode(NextResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .next)) + return + } + try await self.onNext(nextResponse, self) + case .error: + guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .error)) + return + } + try await self.onError(errorResponse, self) + case .complete: + guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .complete)) + return + } + try await self.onComplete(completeResponse, self) + case .unknown: + try await self.error(.invalidType()) } } } - + /// Define the callback run on receipt of a `connection_ack` message /// - Parameter callback: The callback to assign - public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) -> Void) { - self.onConnectionAck = callback + public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) async throws -> Void) { + onConnectionAck = callback } - + /// Define the callback run on receipt of a `next` message /// - Parameter callback: The callback to assign - public func onNext(_ callback: @escaping (NextResponse, Client) -> Void) { - self.onNext = callback + public func onNext(_ callback: @escaping (NextResponse, Client) async throws -> Void) { + onNext = callback } - + /// Define the callback run on receipt of an `error` message /// - Parameter callback: The callback to assign - public func onError(_ callback: @escaping (ErrorResponse, Client) -> Void) { - self.onError = callback + public func onError(_ callback: @escaping (ErrorResponse, Client) async throws -> Void) { + onError = callback } - + /// Define the callback run on receipt of a `complete` message /// - Parameter callback: The callback to assign - public func onComplete(_ callback: @escaping (CompleteResponse, Client) -> Void) { - self.onComplete = callback + public func onComplete(_ callback: @escaping (CompleteResponse, Client) async throws -> Void) { + onComplete = callback } - + /// Define the callback run on receipt of any message /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String, Client) -> Void) { - self.onMessage = callback + public func onMessage(_ callback: @escaping (String, Client) async throws -> Void) { + onMessage = callback } - + /// Send a `connection_init` request through the messenger - public func sendConnectionInit(payload: InitPayload) { + public func sendConnectionInit(payload: InitPayload) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionInitRequest( payload: payload ).toJSON(encoder) ) } - + /// Send a `subscribe` request through the messenger - public func sendStart(payload: GraphQLRequest, id: String) { + public func sendStart(payload: GraphQLRequest, id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( SubscribeRequest( payload: payload, id: id ).toJSON(encoder) ) } - + /// Send a `complete` request through the messenger - public func sendStop(id: String) { + public func sendStop(id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( CompleteRequest( id: id ).toJSON(encoder) ) } - + /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLTransportWSError) { + private func error(_ error: GraphQLTransportWSError) async throws { guard let messenger = messenger else { return } - messenger.error(error.message, code: error.code.rawValue) + try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index 2ccf56f..3fda638 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -3,82 +3,82 @@ import GraphQL struct GraphQLTransportWSError: Error { let message: String let code: ErrorCode - + init(_ message: String, code: ErrorCode) { self.message = message self.code = code } - + static func unauthorized() -> Self { return self.init( "Unauthorized", code: .unauthorized ) } - + static func notInitialized() -> Self { return self.init( "Connection not initialized", code: .notInitialized ) } - + static func tooManyInitializations() -> Self { return self.init( "Too many initialisation requests", code: .tooManyInitializations ) } - + static func subscriberAlreadyExists(id: String) -> Self { return self.init( "Subscriber for \(id) already exists", code: .subscriberAlreadyExists ) } - + static func invalidEncoding() -> Self { return self.init( "Message was not encoded in UTF8", code: .invalidEncoding ) } - + static func noType() -> Self { return self.init( "Message has no 'type' field", code: .noType ) } - + static func invalidType() -> Self { return self.init( "Message 'type' value does not match supported types", code: .invalidType ) } - + static func invalidRequestFormat(messageType: RequestMessageType) -> Self { return self.init( "Request message doesn't match '\(messageType.rawValue)' JSON format", code: .invalidRequestFormat ) } - + static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { return self.init( "Response message doesn't match '\(messageType.rawValue)' JSON format", code: .invalidResponseFormat ) } - + static func internalAPIStreamIssue(errors: [GraphQLError]) -> Self { return self.init( - "API Response did not result in a stream type, contained errors\n \(errors.map { $0.message}.joined(separator: "\n"))", + "API Response did not result in a stream type, contained errors\n \(errors.map { $0.message }.joined(separator: "\n"))", code: .internalAPIStreamIssue ) } - + static func graphQLError(_ error: Error) -> Self { return self.init( "\(error)", @@ -88,28 +88,28 @@ struct GraphQLTransportWSError: Error { } /// Error codes for miscellaneous issues -public enum ErrorCode: Int, CustomStringConvertible { +public enum ErrorCode: Int, CustomStringConvertible, Sendable { // Miscellaneous case miscellaneous = 4400 - + // Internal errors case graphQLError = 4401 case internalAPIStreamIssue = 4402 - + // Message errors case invalidEncoding = 4410 case noType = 4411 case invalidType = 4412 case invalidRequestFormat = 4413 case invalidResponseFormat = 4414 - + // Initialization errors case unauthorized = 4430 case notInitialized = 4431 case tooManyInitializations = 4432 case subscriberAlreadyExists = 4433 - + public var description: String { - return "\(self.rawValue)" + return "\(rawValue)" } } diff --git a/Sources/GraphQLTransportWS/InitPayloads.swift b/Sources/GraphQLTransportWS/InitPayloads.swift index 41a6cc2..8a50b36 100644 --- a/Sources/GraphQLTransportWS/InitPayloads.swift +++ b/Sources/GraphQLTransportWS/InitPayloads.swift @@ -1,12 +1,12 @@ // Contains convenient `connection_init` payloads for users of this package /// `connection_init` `payload` that is empty -public struct EmptyInitPayload: Equatable & Codable { } +public struct EmptyInitPayload: Equatable & Codable & Sendable {} /// `connection_init` `payload` that includes an `authToken` field -public struct TokenInitPayload: Equatable & Codable { +public struct TokenInitPayload: Equatable & Codable & Sendable { public let authToken: String - + public init(authToken: String) { self.authToken = authToken } diff --git a/Sources/GraphQLTransportWS/JsonEncodable.swift b/Sources/GraphQLTransportWS/JsonEncodable.swift index 51c8673..b54f881 100644 --- a/Sources/GraphQLTransportWS/JsonEncodable.swift +++ b/Sources/GraphQLTransportWS/JsonEncodable.swift @@ -12,8 +12,7 @@ extension JsonEncodable { let data: Data do { data = try encoder.encode(self) - } - catch { + } catch { return EncodingErrorResponse("Unable to encode response").toJSON(encoder) } guard let body = String(data: data, encoding: .utf8) else { diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 7e01402..3a9c157 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -1,23 +1,22 @@ import Foundation -import NIO -/// Protocol for an object that can send and recieve messages. This allows mocking in tests. +/// Protocol for an object that can send and recieve messages. This allows mocking in tests public protocol Messenger: AnyObject { // AnyObject compliance requires that the implementing object is a class and we can reference it weakly - + /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) -> Void where S: Collection, S.Element == Character - + func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character + /// Set the callback that should be run when a message is recieved - func onReceive(callback: @escaping (String) -> Void) -> Void - + func onReceive(callback: @escaping (String) async throws -> Void) + /// Close the messenger - func close() -> Void - + func close() async throws + /// Indicate that the messenger experienced an error. /// - Parameters: /// - message: The message describing the error /// - code: An error code - func error(_ message: String, code: Int) -> Void + func error(_ message: String, code: Int) async throws } diff --git a/Sources/GraphQLTransportWS/Requests.swift b/Sources/GraphQLTransportWS/Requests.swift index 48d474b..98267ca 100644 --- a/Sources/GraphQLTransportWS/Requests.swift +++ b/Sources/GraphQLTransportWS/Requests.swift @@ -42,8 +42,8 @@ enum RequestMessageType: String, Codable { case subscribe case complete case unknown - - public init(from decoder: Decoder) throws { + + init(from decoder: Decoder) throws { guard let value = try? decoder.singleValueContainer().decode(String.self) else { self = .unknown return diff --git a/Sources/GraphQLTransportWS/Responses.swift b/Sources/GraphQLTransportWS/Responses.swift index d8cab1e..34da4fd 100644 --- a/Sources/GraphQLTransportWS/Responses.swift +++ b/Sources/GraphQLTransportWS/Responses.swift @@ -10,9 +10,9 @@ struct Response: Equatable, JsonEncodable { public struct ConnectionAckResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [String: Map]? - + init(_ payload: [String: Map]? = nil) { - self.type = .connectionAck + type = .connectionAck self.payload = payload } } @@ -22,9 +22,9 @@ public struct NextResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: GraphQLResult? public let id: String - + init(_ payload: GraphQLResult? = nil, id: String) { - self.type = .next + type = .next self.payload = payload self.id = id } @@ -34,9 +34,9 @@ public struct NextResponse: Equatable, JsonEncodable { public struct CompleteResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let id: String - + init(id: String) { - self.type = .complete + type = .complete self.id = id } } @@ -46,18 +46,18 @@ public struct ErrorResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [GraphQLError] public let id: String - + init(_ errors: [Error], id: String) { let graphQLErrors = errors.map { error -> GraphQLError in switch error { - case let graphQLError as GraphQLError: - return graphQLError - default: - return GraphQLError(error) + case let graphQLError as GraphQLError: + return graphQLError + default: + return GraphQLError(error) } } - self.type = .error - self.payload = graphQLErrors + type = .error + payload = graphQLErrors self.id = id } } @@ -69,7 +69,7 @@ enum ResponseMessageType: String, Codable { case error case complete case unknown - + init(from decoder: Decoder) throws { guard let value = try? decoder.singleValueContainer().decode(String.self) else { self = .unknown @@ -84,9 +84,9 @@ enum ResponseMessageType: String, Codable { struct EncodingErrorResponse: Equatable, Codable, JsonEncodable { let type: ResponseMessageType let payload: [String: String] - + init(_ errorMessage: String) { - self.type = .error - self.payload = ["error": errorMessage] + type = .error + payload = ["error": errorMessage] } } diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index b18c2ec..74db3d5 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -1,276 +1,270 @@ import Foundation import GraphQL -import GraphQLRxSwift -import NIO -import RxSwift /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server { +public class Server< + InitPayload: Equatable & Codable & Sendable, + SubscriptionSequenceType: AsyncSequence & Sendable +>: @unchecked Sendable where + SubscriptionSequenceType.Element == GraphQLResult +{ // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - let onExecute: (GraphQLRequest) -> EventLoopFuture - let onSubscribe: (GraphQLRequest) -> EventLoopFuture - var auth: (InitPayload) throws -> EventLoopFuture - - var onExit: () -> Void = { } - var onOperationComplete: (String) -> Void = { _ in } - var onOperationError: (String) -> Void = { _ in } - var onMessage: (String) -> Void = { _ in } - + + let onExecute: (GraphQLRequest) async throws -> GraphQLResult + let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType + var auth: (InitPayload) async throws -> Void + + var onExit: () async throws -> Void = {} + var onMessage: (String) async throws -> Void = { _ in } + var onOperationComplete: (String) async throws -> Void = { _ in } + var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } + var initialized = false - - let disposeBag = DisposeBag() - let encoder = GraphQLJSONEncoder() + let decoder = JSONDecoder() - + let encoder = GraphQLJSONEncoder() + + private var subscriptionTasks = [String: Task]() + /// Create a new server /// /// - Parameters: /// - messenger: The messenger to bind the server to. - /// - onExecute: Callback run during `subscribe` resolution for non-streaming queries. Typically this is `API.execute`. - /// - onSubscribe: Callback run during `subscribe` resolution for streaming queries. Typically this is `API.subscribe`. - /// - eventLoop: EventLoop on which to perform server operations. + /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. + /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. public init( messenger: Messenger, - onExecute: @escaping (GraphQLRequest) -> EventLoopFuture, - onSubscribe: @escaping (GraphQLRequest) -> EventLoopFuture, - eventLoop: EventLoop + onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, + onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType ) { self.messenger = messenger self.onExecute = onExecute self.onSubscribe = onSubscribe - self.auth = { _ in eventLoop.makeSucceededVoidFuture() } - + auth = { _ in } + messenger.onReceive { message in - self.onMessage(message) - + guard let messenger = self.messenger else { return } + + try await self.onMessage(message) + // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages return } - - guard let data = message.data(using: .utf8) else { - self.error(.invalidEncoding()) + + guard let json = message.data(using: .utf8) else { + try await self.error(.invalidEncoding()) return } - + let request: Request do { - request = try self.decoder.decode(Request.self, from: data) - } - catch { - self.error(.noType()) + request = try self.decoder.decode(Request.self, from: json) + } catch { + try await self.error(.noType()) return } - + // handle incoming message switch request.type { - case .connectionInit: - guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: data) else { - self.error(.invalidRequestFormat(messageType: .connectionInit)) - return - } - self.onConnectionInit(connectionInitRequest) - case .subscribe: - guard let subscribeRequest = try? self.decoder.decode(SubscribeRequest.self, from: data) else { - self.error(.invalidRequestFormat(messageType: .subscribe)) - return - } - self.onSubscribe(subscribeRequest) - case .complete: - guard let completeRequest = try? self.decoder.decode(CompleteRequest.self, from: data) else { - self.error(.invalidRequestFormat(messageType: .complete)) - return - } - self.onOperationComplete(completeRequest.id) - case .unknown: - self.error(.invalidType()) + case .connectionInit: + guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .connectionInit)) + return + } + try await self.onConnectionInit(connectionInitRequest, messenger) + case .subscribe: + guard let subscribeRequest = try? self.decoder.decode(SubscribeRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .subscribe)) + return + } + try await self.onSubscribe(subscribeRequest) + case .complete: + guard let completeRequest = try? self.decoder.decode(CompleteRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .complete)) + return + } + try await self.onOperationComplete(completeRequest) + case .unknown: + try await self.error(.invalidType()) } } } - - /// Define the callback run during `connection_init` resolution that allows authorization using the `payload`. - /// Throw or fail the future to indicate that authorization has failed. - /// - Parameter callback: The callback to assign - public func auth(_ callback: @escaping (InitPayload) throws -> EventLoopFuture) { - self.auth = callback + + deinit { + subscriptionTasks.values.forEach { $0.cancel() } + } + + /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. + /// Throw from this closure to indicate that authorization has failed. + /// - Parameter callback: The callback to assign + public func auth(_ callback: @escaping (InitPayload) async throws -> Void) { + auth = callback } - + /// Define the callback run when the communication is shut down, either by the client or server /// - Parameter callback: The callback to assign public func onExit(_ callback: @escaping () -> Void) { - self.onExit = callback + onExit = callback } - + /// Define the callback run on receipt of any message /// - Parameter callback: The callback to assign public func onMessage(_ callback: @escaping (String) -> Void) { - self.onMessage = callback + onMessage = callback } - + /// Define the callback run on the completion a full operation (query/mutation, end of subscription) - /// - Parameter callback: The callback to assign, taking a string parameter for the ID of the operation + /// - Parameter callback: The callback to assign public func onOperationComplete(_ callback: @escaping (String) -> Void) { - self.onOperationComplete = callback + onOperationComplete = callback } - + /// Define the callback to run on error of any full operation (failed query, interrupted subscription) - /// - Parameter callback: The callback to assign, taking a string parameter for the ID of the operation - public func onOperationError(_ callback: @escaping (String) -> Void) { - self.onOperationError = callback + /// - Parameter callback: The callback to assign + public func onOperationError(_ callback: @escaping (String, [Error]) -> Void) { + onOperationError = callback } - - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest) { + + private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { guard !initialized else { - self.error(.tooManyInitializations()) + try await error(.tooManyInitializations()) return } - + do { - let authResult = try self.auth(connectionInitRequest.payload) - authResult.whenSuccess { - self.initialized = true - self.sendConnectionAck() - } - authResult.whenFailure { error in - self.error(.unauthorized()) - return - } - } - catch { - self.error(.unauthorized()) + try await auth(connectionInitRequest.payload) + } catch { + try await self.error(.unauthorized()) return } + initialized = true + try await sendConnectionAck() + // TODO: Should we send the `ka` message? } - - private func onSubscribe(_ subscribeRequest: SubscribeRequest) { + + private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws { guard initialized else { - self.error(.notInitialized()) + try await error(.notInitialized()) return } - + let id = subscribeRequest.id + if subscriptionTasks[id] != nil { + try await error(.subscriberAlreadyExists(id: id)) + } + let graphQLRequest = subscribeRequest.payload - + var isStreaming = false do { isStreaming = try graphQLRequest.isSubscription() - } - catch { - self.sendError(error, id: id) + } catch { + try await sendError(error, id: id) return } - + if isStreaming { - let subscribeFuture = onSubscribe(graphQLRequest) - subscribeFuture.whenSuccess { [weak self] result in - guard let self = self else { return } - guard let streamOpt = result.stream else { - // API issue - subscribe resolver isn't stream - self.sendError(result.errors, id: id) - return - } - let stream = streamOpt as! ObservableSubscriptionEventStream - let observable = stream.observable - - observable.subscribe( - onNext: { [weak self] resultFuture in - guard let self = self else { return } - resultFuture.whenSuccess { result in - self.sendNext(result, id: id) - } - resultFuture.whenFailure { error in - self.sendError(error, id: id) - } - }, - onError: { [weak self] error in - guard let self = self else { return } - self.sendError(error, id: id) - }, - onCompleted: { [weak self] in - guard let self = self else { return } - self.sendComplete(id: id) + subscriptionTasks[id] = Task { + do { + let stream = try await onSubscribe(graphQLRequest) + for try await event in stream { + try Task.checkCancellation() + try await self.sendNext(event, id: id) } - ).disposed(by: self.disposeBag) + } catch { + try await sendError(error, id: id) + subscriptionTasks.removeValue(forKey: id) + throw error + } + try await self.sendComplete(id: id) + subscriptionTasks.removeValue(forKey: id) } - subscribeFuture.whenFailure { error in - self.sendError(error, id: id) + } else { + do { + let result = try await onExecute(graphQLRequest) + try await sendNext(result, id: id) + try await sendComplete(id: id) + } catch { + try await sendError(error, id: id) } + try await messenger?.close() } - else { - let executeFuture = onExecute(graphQLRequest) - executeFuture.whenSuccess { result in - self.sendNext(result, id: id) - self.sendComplete(id: id) - self.messenger?.close() - } - executeFuture.whenFailure { error in - self.sendError(error, id: id) - self.sendComplete(id: id) - self.messenger?.close() - } + } + + private func onOperationComplete(_ completeRequest: CompleteRequest) async throws { + guard initialized else { + try await error(.notInitialized()) + return + } + + let id = completeRequest.id + if let task = subscriptionTasks[id] { + task.cancel() + subscriptionTasks.removeValue(forKey: id) } + try await onOperationComplete(id) } - + /// Send a `connection_ack` response through the messenger - private func sendConnectionAck(_ payload: [String: Map]? = nil) { + private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionAckResponse(payload).toJSON(encoder) ) } - + /// Send a `next` response through the messenger - private func sendNext(_ payload: GraphQLResult? = nil, id: String) { + private func sendNext(_ payload: GraphQLResult? = nil, id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( NextResponse( payload, id: id ).toJSON(encoder) ) } - + /// Send a `complete` response through the messenger - private func sendComplete(id: String) { + private func sendComplete(id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( CompleteResponse( id: id ).toJSON(encoder) ) - self.onOperationComplete(id) + try await onOperationComplete(id) } - + /// Send an `error` response through the messenger - private func sendError(_ errors: [Error], id: String) { + private func sendError(_ errors: [Error], id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ErrorResponse( errors, id: id ).toJSON(encoder) ) - self.onOperationError(id) + try await onOperationError(id, errors) } - + /// Send an `error` response through the messenger - private func sendError(_ error: Error, id: String) { - self.sendError([error], id: id) + private func sendError(_ error: Error, id: String) async throws { + try await sendError([error], id: id) } - + /// Send an `error` response through the messenger - private func sendError(_ errorMessage: String, id: String) { - self.sendError(GraphQLError(message: errorMessage), id: id) + private func sendError(_ errorMessage: String, id: String) async throws { + try await sendError(GraphQLError(message: errorMessage), id: id) } - + /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLTransportWSError) { + private func error(_ error: GraphQLTransportWSError) async throws { guard let messenger = messenger else { return } - messenger.error(error.message, code: error.code.rawValue) + try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 72f067d..e26de4c 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -1,7 +1,6 @@ import Foundation import GraphQL -import NIO import XCTest @testable import GraphQLTransportWS @@ -9,229 +8,220 @@ import XCTest class GraphqlTransportWSTests: XCTestCase { var clientMessenger: TestMessenger! var serverMessenger: TestMessenger! - var server: Server! - var eventLoop: EventLoop! - + var server: Server>! + var context: TestContext! + var subscribeReady: Bool! = false + override func setUp() { // Point the client and server at each other clientMessenger = TestMessenger() serverMessenger = TestMessenger() clientMessenger.other = serverMessenger serverMessenger.other = clientMessenger - - eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() + let api = TestAPI() let context = TestContext() - - server = Server( + + server = .init( messenger: serverMessenger, onExecute: { graphQLRequest in - api.execute( + try await api.execute( request: graphQLRequest.query, - context: context, - on: self.eventLoop + context: context ) }, onSubscribe: { graphQLRequest in - api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: context, - on: self.eventLoop - ) - }, - eventLoop: self.eventLoop + context: context + ).get() + self.subscribeReady = true + return subscription + } ) + self.context = context } - + /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() throws { - var messages = [String]() - let completeExpectation = XCTestExpectation() - + func testInitialize() async throws { let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onMessage { message, _ in + continuation.yield(message) + // Expect only one message + continuation.finish() + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } } - - client.sendStart( + + try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: UUID().uuidString ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages, ["\(ErrorCode.notInitialized): Connection not initialized"] ) } - + /// Tests that throwing in the authorization callback forces an unauthorized error - func testAuthWithThrow() throws { - server.auth { payload in + func testAuthWithThrow() async throws { + server.auth { _ in throw TestError.couldBeAnything } - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onMessage { message, _ in + continuation.yield(message) + // Expect only one message + continuation.finish() + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) - XCTAssertEqual( - messages, - ["\(ErrorCode.unauthorized): Unauthorized"] - ) - } - - /// Tests that failing a future in the authorization callback forces an unauthorized error - func testAuthWithFailedFuture() throws { - server.auth { payload in - self.eventLoop.makeFailedFuture(TestError.couldBeAnything) - } - - var messages = [String]() - let completeExpectation = XCTestExpectation() - - let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) } - - client.sendConnectionInit( - payload: TokenInitPayload( - authToken: "" - ) - ) - - wait(for: [completeExpectation], timeout: 2) XCTAssertEqual( messages, ["\(ErrorCode.unauthorized): Unauthorized"] ) } - + /// Tests a single-op conversation - func testSingleOp() throws { + func testSingleOp() async throws { let id = UUID().description - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + let client = Client(messenger: clientMessenger) - client.onConnectionAck { _, client in - client.sendStart( - payload: GraphQLRequest( - query: """ - query { - hello - } - """ - ), - id: id - ) - } - client.onError { _, _ in - completeExpectation.fulfill() - } - client.onComplete { _, _ in - completeExpectation.fulfill() - } - client.onMessage { message, _ in - messages.append(message) + let messageStream = AsyncThrowingStream { continuation in + client.onConnectionAck { _, client in + try await client.sendStart( + payload: GraphQLRequest( + query: """ + query { + hello + } + """ + ), + id: id + ) + } + client.onMessage { message, _ in + continuation.yield(message) + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } + client.onComplete { _, _ in + continuation.finish() + } } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages.count, 3, // 1 connection_ack, 1 next, 1 complete "Messages: \(messages.description)" ) } - + /// Tests a streaming conversation - func testStreaming() throws { + func testStreaming() async throws { let id = UUID().description - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + var dataIndex = 1 let dataIndexMax = 3 - + let client = Client(messenger: clientMessenger) - client.onConnectionAck { _, client in - client.sendStart( - payload: GraphQLRequest( - query: """ - subscription { - hello - } - """ - ), - id: id - ) - - // Short sleep to allow for server to register subscription - usleep(3000) - - pubsub.onNext("hello \(dataIndex)") - } - client.onNext { _, _ in - dataIndex = dataIndex + 1 - if dataIndex <= dataIndexMax { - pubsub.onNext("hello \(dataIndex)") - } else { - pubsub.onCompleted() + let messageStream = AsyncThrowingStream { continuation in + client.onConnectionAck { _, client in + try await client.sendStart( + payload: GraphQLRequest( + query: """ + subscription { + hello + } + """ + ), + id: id + ) + + // Wait until server has registered subscription + var i = 0 + while !self.subscribeReady, i < 50 { + usleep(1000) + i = i + 1 + } + if i == 50 { + XCTFail("Subscription timeout: Took longer than 50ms to set up") + } + + self.context.publisher.emit(event: "hello \(dataIndex)") + } + client.onNext { _, _ in + dataIndex = dataIndex + 1 + if dataIndex <= dataIndexMax { + self.context.publisher.emit(event: "hello \(dataIndex)") + } else { + self.context.publisher.cancel() + } + } + client.onMessage { message, _ in + continuation.yield(message) + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } + client.onComplete { _, _ in + continuation.finish() } } - client.onError { _, _ in - completeExpectation.fulfill() - } - client.onComplete { _, _ in - completeExpectation.fulfill() - } - client.onMessage { message, _ in - messages.append(message) - } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages.count, 5, // 1 connection_ack, 3 next, 1 complete "Messages: \(messages.description)" ) } - + enum TestError: Error { case couldBeAnything } diff --git a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift index 4b7ea03..8867da1 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift @@ -1,15 +1,11 @@ import Foundation -import GraphQL import Graphiti -import GraphQLRxSwift -import RxSwift - -let pubsub = PublishSubject() +import GraphQL struct TestAPI: API { let resolver = TestResolver() let context = TestContext() - + let schema = try! Schema { Query { Field("hello", at: TestResolver.hello) @@ -20,7 +16,9 @@ struct TestAPI: API { } } -final class TestContext { +final class TestContext: Sendable { + let publisher = SimplePubSub() + func hello() -> String { "world" } @@ -30,8 +28,48 @@ struct TestResolver { func hello(context: TestContext, arguments _: NoArguments) -> String { context.hello() } - - func subscribeHello(context: TestContext, arguments: NoArguments) -> EventStream { - pubsub.toEventStream() + + func subscribeHello(context: TestContext, arguments _: NoArguments) -> AsyncThrowingStream { + context.publisher.subscribe() } } + +/// A very simple publish/subscriber used for testing +class SimplePubSub: @unchecked Sendable { + private var subscribers: [Subscriber] + + init() { + subscribers = [] + } + + func emit(event: T) { + for subscriber in subscribers { + subscriber.callback(event) + } + } + + func cancel() { + for subscriber in subscribers { + subscriber.cancel() + } + } + + func subscribe() -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + let subscriber = Subscriber( + callback: { newValue in + continuation.yield(newValue) + }, + cancel: { + continuation.finish() + } + ) + subscribers.append(subscriber) + } + } +} + +struct Subscriber { + let callback: (T) -> Void + let cancel: () -> Void +} diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index 1e7c274..a35aa09 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -7,32 +7,28 @@ import Foundation /// /// Note that this only retains a weak reference to 'other', so the client should retain references /// or risk them being deinitialized early -class TestMessenger: Messenger { +class TestMessenger: Messenger, @unchecked Sendable { weak var other: TestMessenger? - var onReceive: (String) -> Void = { _ in } + var onReceive: (String) async throws -> Void = { _ in } let queue: DispatchQueue = .init(label: "Test messenger") - + init() {} - - func send(_ message: S) where S: Collection, S.Element == Character { + + func send(_ message: S) async throws where S: Collection, S.Element == Character { guard let other = other else { return } - - // Run the other message asyncronously to avoid nesting issues - queue.async { - other.onReceive(String(message)) - } + try await other.onReceive(String(message)) } - - func onReceive(callback: @escaping (String) -> Void) { - self.onReceive = callback + + func onReceive(callback: @escaping (String) async throws -> Void) { + onReceive = callback } - - func error(_ message: String, code: Int) { - self.send("\(code): \(message)") + + func error(_ message: String, code: Int) async throws { + try await send("\(code): \(message)") } - + func close() { // This is a testing no-op }