Skip to content

Commit 18445ca

Browse files
Improved memory management in client callbacks
By adding an additional parameter to the callbacks, the user no longer has to ensure that the references are weak (we manage it internally). This is based on WebSocket implementation.
1 parent 4d54176 commit 18445ca

File tree

2 files changed

+31
-33
lines changed

2 files changed

+31
-33
lines changed

Sources/GraphQLWS/Client.swift

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ import GraphQL
88
class Client {
99
let messenger: Messenger
1010

11-
var onConnectionError: (ConnectionErrorResponse) -> Void = { _ in }
12-
var onConnectionAck: (ConnectionAckResponse) -> Void = { _ in }
13-
var onConnectionKeepAlive: (ConnectionKeepAliveResponse) -> Void = { _ in }
14-
var onData: (DataResponse) -> Void = { _ in }
15-
var onError: (ErrorResponse) -> Void = { _ in }
16-
var onComplete: (CompleteResponse) -> Void = { _ in }
17-
var onMessage: (String) -> Void = { _ in }
11+
var onConnectionError: (ConnectionErrorResponse, Client) -> Void = { _, _ in }
12+
var onConnectionAck: (ConnectionAckResponse, Client) -> Void = { _, _ in }
13+
var onConnectionKeepAlive: (ConnectionKeepAliveResponse, Client) -> Void = { _, _ in }
14+
var onData: (DataResponse, Client) -> Void = { _, _ in }
15+
var onError: (ErrorResponse, Client) -> Void = { _, _ in }
16+
var onComplete: (CompleteResponse, Client) -> Void = { _, _ in }
17+
var onMessage: (String, Client) -> Void = { _, _ in }
1818

1919
let encoder = GraphQLJSONEncoder()
2020
let decoder = JSONDecoder()
@@ -31,7 +31,7 @@ class Client {
3131
self.messenger.onRecieve { [weak self] message in
3232
guard let self = self else { return }
3333

34-
self.onMessage(message)
34+
self.onMessage(message, self)
3535

3636
// Detect and ignore error responses.
3737
if message.starts(with: "44") {
@@ -62,42 +62,42 @@ class Client {
6262
self.messenger.error(error.message, code: error.code)
6363
return
6464
}
65-
self.onConnectionError(connectionErrorResponse)
65+
self.onConnectionError(connectionErrorResponse, self)
6666
case .GQL_CONNECTION_ACK:
6767
guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else {
6868
let error = GraphQLWSError.invalidResponseFormat(messageType: .GQL_CONNECTION_ACK)
6969
self.messenger.error(error.message, code: error.code)
7070
return
7171
}
72-
self.onConnectionAck(connectionAckResponse)
72+
self.onConnectionAck(connectionAckResponse, self)
7373
case .GQL_CONNECTION_KEEP_ALIVE:
7474
guard let connectionKeepAliveResponse = try? self.decoder.decode(ConnectionKeepAliveResponse.self, from: json) else {
7575
let error = GraphQLWSError.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)
7676
self.messenger.error(error.message, code: error.code)
7777
return
7878
}
79-
self.onConnectionKeepAlive(connectionKeepAliveResponse)
79+
self.onConnectionKeepAlive(connectionKeepAliveResponse, self)
8080
case .GQL_DATA:
8181
guard let nextResponse = try? self.decoder.decode(DataResponse.self, from: json) else {
8282
let error = GraphQLWSError.invalidResponseFormat(messageType: .GQL_DATA)
8383
self.messenger.error(error.message, code: error.code)
8484
return
8585
}
86-
self.onData(nextResponse)
86+
self.onData(nextResponse, self)
8787
case .GQL_ERROR:
8888
guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else {
8989
let error = GraphQLWSError.invalidResponseFormat(messageType: .GQL_ERROR)
9090
self.messenger.error(error.message, code: error.code)
9191
return
9292
}
93-
self.onError(errorResponse)
93+
self.onError(errorResponse, self)
9494
case .GQL_COMPLETE:
9595
guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else {
9696
let error = GraphQLWSError.invalidResponseFormat(messageType: .GQL_COMPLETE)
9797
self.messenger.error(error.message, code: error.code)
9898
return
9999
}
100-
self.onComplete(completeResponse)
100+
self.onComplete(completeResponse, self)
101101
case .unknown:
102102
let error = GraphQLWSError.invalidType()
103103
self.messenger.error(error.message, code: error.code)
@@ -107,43 +107,43 @@ class Client {
107107

108108
/// Define the callback run on receipt of a `connection_error` message
109109
/// - Parameter callback: The callback to assign
110-
func onConnectionError(_ callback: @escaping (ConnectionErrorResponse) -> Void) {
110+
func onConnectionError(_ callback: @escaping (ConnectionErrorResponse, Client) -> Void) {
111111
self.onConnectionError = callback
112112
}
113113

114114
/// Define the callback run on receipt of a `connection_ack` message
115115
/// - Parameter callback: The callback to assign
116-
func onConnectionAck(_ callback: @escaping (ConnectionAckResponse) -> Void) {
116+
func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) -> Void) {
117117
self.onConnectionAck = callback
118118
}
119119

120120
/// Define the callback run on receipt of a `connection_ka` message
121121
/// - Parameter callback: The callback to assign
122-
func onConnectionKeepAlive(_ callback: @escaping (ConnectionKeepAliveResponse) -> Void) {
122+
func onConnectionKeepAlive(_ callback: @escaping (ConnectionKeepAliveResponse, Client) -> Void) {
123123
self.onConnectionKeepAlive = callback
124124
}
125125

126126
/// Define the callback run on receipt of a `data` message
127127
/// - Parameter callback: The callback to assign
128-
func onData(_ callback: @escaping (DataResponse) -> Void) {
128+
func onData(_ callback: @escaping (DataResponse, Client) -> Void) {
129129
self.onData = callback
130130
}
131131

132132
/// Define the callback run on receipt of an `error` message
133133
/// - Parameter callback: The callback to assign
134-
func onError(_ callback: @escaping (ErrorResponse) -> Void) {
134+
func onError(_ callback: @escaping (ErrorResponse, Client) -> Void) {
135135
self.onError = callback
136136
}
137137

138138
/// Define the callback run on receipt of any message
139139
/// - Parameter callback: The callback to assign
140-
func onComplete(_ callback: @escaping (CompleteResponse) -> Void) {
140+
func onComplete(_ callback: @escaping (CompleteResponse, Client) -> Void) {
141141
self.onComplete = callback
142142
}
143143

144144
/// Define the callback run on receipt of a `complete` message
145145
/// - Parameter callback: The callback to assign
146-
func onMessage(_ callback: @escaping (String) -> Void) {
146+
func onMessage(_ callback: @escaping (String, Client) -> Void) {
147147
self.onMessage = callback
148148
}
149149

Tests/GraphQLWSTests/GraphQLWSTests.swift

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class GraphqlWsTests: XCTestCase {
5252
let completeExpectation = XCTestExpectation()
5353

5454
let client = Client(messenger: clientMessenger)
55-
client.onMessage { message in
55+
client.onMessage { message, _ in
5656
messages.append(message)
5757
completeExpectation.fulfill()
5858
}
@@ -79,8 +79,7 @@ class GraphqlWsTests: XCTestCase {
7979

8080
let client = Client(messenger: clientMessenger)
8181

82-
client.onConnectionAck { [weak client] _ in
83-
guard let client = client else { return }
82+
client.onConnectionAck { _, client in
8483
client.sendStart(
8584
payload: GraphQLRequest(
8685
query: """
@@ -92,13 +91,13 @@ class GraphqlWsTests: XCTestCase {
9291
id: id
9392
)
9493
}
95-
client.onError { _ in
94+
client.onError { _, _ in
9695
completeExpectation.fulfill()
9796
}
98-
client.onComplete { _ in
97+
client.onComplete { _, _ in
9998
completeExpectation.fulfill()
10099
}
101-
client.onMessage { message in
100+
client.onMessage { message, _ in
102101
messages.append(message)
103102
}
104103

@@ -123,8 +122,7 @@ class GraphqlWsTests: XCTestCase {
123122
let dataIndexMax = 3
124123

125124
let client = Client(messenger: clientMessenger)
126-
client.onConnectionAck { [weak client] _ in
127-
guard let client = client else { return }
125+
client.onConnectionAck { _, client in
128126
client.sendStart(
129127
payload: GraphQLRequest(
130128
query: """
@@ -141,21 +139,21 @@ class GraphqlWsTests: XCTestCase {
141139

142140
pubsub.onNext("hello \(dataIndex)")
143141
}
144-
client.onData { _ in
142+
client.onData { _, _ in
145143
dataIndex = dataIndex + 1
146144
if dataIndex <= dataIndexMax {
147145
pubsub.onNext("hello \(dataIndex)")
148146
} else {
149147
pubsub.onCompleted()
150148
}
151149
}
152-
client.onError { _ in
150+
client.onError { _, _ in
153151
completeExpectation.fulfill()
154152
}
155-
client.onComplete { _ in
153+
client.onComplete { _, _ in
156154
completeExpectation.fulfill()
157155
}
158-
client.onMessage { message in
156+
client.onMessage { message, _ in
159157
messages.append(message)
160158
}
161159

0 commit comments

Comments
 (0)