Skip to content

Commit e84804f

Browse files
feature: Makes connection_init payload type generic
This allows the user to define their own authorization types, but still allows us to encode/decode in a strongly typed way. Some common default payload types are included.
1 parent 7b27770 commit e84804f

File tree

5 files changed

+31
-30
lines changed

5 files changed

+31
-30
lines changed

Sources/GraphQLWS/Client.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Foundation
44
import GraphQL
55

66
/// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose.
7-
public class Client {
7+
public class Client<InitPayload: Equatable & Codable> {
88
// We keep this weak because we strongly inject this object into the messenger callback
99
weak var messenger: Messenger?
1010

@@ -136,7 +136,7 @@ public class Client {
136136
}
137137

138138
/// Send a `connection_init` request through the messenger
139-
public func sendConnectionInit(payload: ConnectionInitAuth?) {
139+
public func sendConnectionInit(payload: InitPayload) {
140140
guard let messenger = messenger else { return }
141141
messenger.send(
142142
ConnectionInitRequest(

Sources/GraphQLWS/InitPayloads.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Contains convenient `connection_init` payloads for users of this package
2+
3+
/// `connection_init` `payload` that is empty
4+
public struct EmptyInitPayload: Equatable & Codable { }
5+
6+
/// `connection_init` `payload` that includes an `authToken` field
7+
public struct TokenInitPayload: Equatable & Codable {
8+
let authToken: String
9+
}

Sources/GraphQLWS/Requests.swift

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,9 @@ struct Request: Equatable, JsonEncodable {
2020
}
2121

2222
/// A websocket `connection_init` request from the client to the server
23-
public struct ConnectionInitRequest: Equatable, JsonEncodable {
23+
public struct ConnectionInitRequest<InitPayload: Codable & Equatable>: Equatable, JsonEncodable {
2424
var type = RequestMessageType.GQL_CONNECTION_INIT
25-
public let payload: ConnectionInitAuth?
26-
}
27-
28-
// TODO: Make this structure user-defined
29-
/// Authorization format for a websocket `connection_init` request from the client to the server
30-
public struct ConnectionInitAuth: Equatable, JsonEncodable {
31-
public let authToken: String
32-
33-
public init(authToken: String) {
34-
self.authToken = authToken
35-
}
25+
let payload: InitPayload
3626
}
3727

3828
/// A websocket `start` request from the client to the server

Sources/GraphQLWS/Server.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ import NIO
77
import RxSwift
88

99
/// Server implements the server-side portion of the protocol, allowing a few callbacks for customization.
10-
public class Server {
10+
public class Server<InitPayload: Equatable & Codable> {
1111
// We keep this weak because we strongly inject this object into the messenger callback
1212
weak var messenger: Messenger?
1313

1414
let onExecute: (GraphQLRequest) -> EventLoopFuture<GraphQLResult>
1515
let onSubscribe: (GraphQLRequest) -> EventLoopFuture<SubscriptionResult>
1616

17-
var auth: (ConnectionInitRequest) throws -> Void = { _ in }
17+
var auth: (InitPayload) throws -> Void = { _ in }
1818
var onExit: () -> Void = { }
1919
var onMessage: (String) -> Void = { _ in }
2020

@@ -66,7 +66,7 @@ public class Server {
6666

6767
switch request.type {
6868
case .GQL_CONNECTION_INIT:
69-
guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else {
69+
guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest<InitPayload>.self, from: json) else {
7070
self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT))
7171
return
7272
}
@@ -96,9 +96,9 @@ public class Server {
9696
}
9797

9898
/// Define the callback run during `connection_init` resolution that allows authorization using the `payload`.
99-
/// Throw to indicate that authorization has failed.
99+
/// Throw from this closure to indicate that authorization has failed.
100100
/// - Parameter callback: The callback to assign
101-
public func auth(_ callback: @escaping (ConnectionInitRequest) throws -> Void) {
101+
public func auth(_ callback: @escaping (InitPayload) throws -> Void) {
102102
self.auth = callback
103103
}
104104

@@ -114,14 +114,14 @@ public class Server {
114114
self.onMessage = callback
115115
}
116116

117-
private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _ messenger: Messenger) {
117+
private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest<InitPayload>, _ messenger: Messenger) {
118118
guard !initialized else {
119119
self.error(.tooManyInitializations())
120120
return
121121
}
122122

123123
do {
124-
try self.auth(connectionInitRequest)
124+
try self.auth(connectionInitRequest.payload)
125125
}
126126
catch {
127127
self.error(.unauthorized())

Tests/GraphQLWSTests/GraphQLWSTests.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import XCTest
1111
class GraphqlWsTests: XCTestCase {
1212
var clientMessenger: TestMessenger!
1313
var serverMessenger: TestMessenger!
14-
var server: Server!
14+
var server: Server<TokenInitPayload>!
1515

1616
override func setUp() {
1717
clientMessenger = TestMessenger()
@@ -24,7 +24,7 @@ class GraphqlWsTests: XCTestCase {
2424
let api = TestAPI()
2525
let context = TestContext()
2626

27-
server = Server(
27+
server = Server<TokenInitPayload>(
2828
messenger: serverMessenger,
2929
onExecute: { graphQLRequest in
3030
api.execute(
@@ -48,7 +48,7 @@ class GraphqlWsTests: XCTestCase {
4848
var messages = [String]()
4949
let completeExpectation = XCTestExpectation()
5050

51-
let client = Client(messenger: clientMessenger)
51+
let client = Client<TokenInitPayload>(messenger: clientMessenger)
5252
client.onMessage { message, _ in
5353
messages.append(message)
5454
completeExpectation.fulfill()
@@ -81,14 +81,14 @@ class GraphqlWsTests: XCTestCase {
8181
var messages = [String]()
8282
let completeExpectation = XCTestExpectation()
8383

84-
let client = Client(messenger: clientMessenger)
84+
let client = Client<TokenInitPayload>(messenger: clientMessenger)
8585
client.onMessage { message, _ in
8686
messages.append(message)
8787
completeExpectation.fulfill()
8888
}
8989

9090
client.sendConnectionInit(
91-
payload: ConnectionInitAuth(
91+
payload: TokenInitPayload(
9292
authToken: ""
9393
)
9494
)
@@ -100,14 +100,15 @@ class GraphqlWsTests: XCTestCase {
100100
)
101101
}
102102

103+
/// Test single op message flow works as expected
103104
func testSingleOp() throws {
104105
let id = UUID().description
105106

106107
// Test single-op conversation
107108
var messages = [String]()
108109
let completeExpectation = XCTestExpectation()
109110

110-
let client = Client(messenger: clientMessenger)
111+
let client = Client<TokenInitPayload>(messenger: clientMessenger)
111112

112113
client.onConnectionAck { _, client in
113114
client.sendStart(
@@ -131,7 +132,7 @@ class GraphqlWsTests: XCTestCase {
131132
messages.append(message)
132133
}
133134

134-
client.sendConnectionInit(payload: ConnectionInitAuth(authToken: ""))
135+
client.sendConnectionInit(payload: TokenInitPayload(authToken: ""))
135136

136137
wait(for: [completeExpectation], timeout: 2)
137138
XCTAssertEqual(
@@ -141,6 +142,7 @@ class GraphqlWsTests: XCTestCase {
141142
)
142143
}
143144

145+
/// Test streaming message flow works as expected
144146
func testStreaming() throws {
145147
let id = UUID().description
146148

@@ -151,7 +153,7 @@ class GraphqlWsTests: XCTestCase {
151153
var dataIndex = 1
152154
let dataIndexMax = 3
153155

154-
let client = Client(messenger: clientMessenger)
156+
let client = Client<TokenInitPayload>(messenger: clientMessenger)
155157
client.onConnectionAck { _, client in
156158
client.sendStart(
157159
payload: GraphQLRequest(
@@ -187,7 +189,7 @@ class GraphqlWsTests: XCTestCase {
187189
messages.append(message)
188190
}
189191

190-
client.sendConnectionInit(payload: ConnectionInitAuth(authToken: ""))
192+
client.sendConnectionInit(payload: TokenInitPayload(authToken: ""))
191193

192194
wait(for: [completeExpectation], timeout: 2)
193195
XCTAssertEqual(

0 commit comments

Comments
 (0)