@@ -28,7 +28,7 @@ import Foundation
28
28
///
29
29
/// This mainly comes into play when we don't want to block developers from sending messages while a
30
30
/// session is being reloaded.
31
- @available ( iOS 15 . 0 , macOS 12 . 0 , macCatalyst 15 . 0 , tvOS 15 . 0 , watchOS 8 . 0 , * )
31
+ @available ( iOS 15 . 0 , macOS 12 . 0 , macCatalyst 15 . 0 , tvOS 15 . 0 , * )
32
32
@available ( watchOS, unavailable)
33
33
actor LiveSessionService {
34
34
let responses : AsyncThrowingStream < LiveServerMessage , Error >
@@ -54,11 +54,6 @@ actor LiveSessionService {
54
54
private let jsonEncoder = JSONEncoder ( )
55
55
private let jsonDecoder = JSONDecoder ( )
56
56
57
- /// Task that doesn't complete until the server sends a setupComplete message.
58
- ///
59
- /// Used to hold off on sending messages until the server is ready.
60
- private var setupTask : Task < Void , Error >
61
-
62
57
/// Long running task that that wraps around the websocket, propogating messages through the
63
58
/// public stream.
64
59
private var responsesTask : Task < Void , Never > ?
@@ -87,11 +82,9 @@ actor LiveSessionService {
87
82
self . toolConfig = toolConfig
88
83
self . systemInstruction = systemInstruction
89
84
self . requestOptions = requestOptions
90
- setupTask = Task { }
91
85
}
92
86
93
87
deinit {
94
- setupTask. cancel ( )
95
88
responsesTask? . cancel ( )
96
89
messageQueueTask? . cancel ( )
97
90
webSocket? . disconnect ( )
@@ -114,29 +107,20 @@ actor LiveSessionService {
114
107
///
115
108
/// Seperated into its own function to make it easier to surface a way to call it seperately when
116
109
/// resuming the same session.
110
+ ///
111
+ /// This function will yield until the websocket is ready to communicate with the client.
117
112
func connect( ) async throws {
118
113
close ( )
119
- // we launch the setup task in a seperate task to allow us to cancel it via close
120
- setupTask = Task { [ weak self] in
121
- // we need a continuation to surface that the setup is complete, while still allowing us to
122
- // listen to the server
123
- try await withCheckedThrowingContinuation { setupContinuation in
124
- // nested task so we can use await
125
- Task { [ weak self] in
126
- guard let self else { return }
127
- await self . listenToServer ( setupContinuation)
128
- }
129
- }
130
- }
131
114
132
- try await setupTask. value
115
+ let stream = try await setupWebsocket ( )
116
+ try await waitForSetupComplete ( stream: stream)
117
+ spawnMessageTasks ( stream: stream)
133
118
}
134
119
135
120
/// Cancel any running tasks and close the websocket.
136
121
///
137
122
/// This method is idempotent; if it's already ran once, it will effectively be a no-op.
138
123
func close( ) {
139
- setupTask. cancel ( )
140
124
responsesTask? . cancel ( )
141
125
messageQueueTask? . cancel ( )
142
126
webSocket? . disconnect ( )
@@ -146,38 +130,19 @@ actor LiveSessionService {
146
130
messageQueueTask = nil
147
131
}
148
132
149
- /// Start a fresh websocket to the backend, and listen for responses .
133
+ /// Performs the initial setup procedure for the model .
150
134
///
151
- /// Will hold off on sending any messages until the server sends a setupComplete message.
135
+ /// The setup procedure with the model follows the procedure:
152
136
///
153
- /// Will also close out the old websocket and the previous long running tasks.
154
- private func listenToServer( _ setupComplete: CheckedContinuation < Void , any Error > ) async {
155
- do {
156
- webSocket = try await createWebsocket ( )
157
- } catch {
158
- let error = LiveSessionSetupError ( underlyingError: error)
159
- close ( )
160
- setupComplete. resume ( throwing: error)
161
- return
162
- }
163
-
137
+ /// - Client sends `BidiGenerateContentSetup`
138
+ /// - Server sends back `BidiGenerateContentSetupComplete` when it's ready
139
+ ///
140
+ /// This function will yield until the setup is complete.
141
+ private func waitForSetupComplete( stream: MappedStream <
142
+ URLSessionWebSocketTask . Message ,
143
+ Data
144
+ > ) async throws {
164
145
guard let webSocket else { return }
165
- let stream = webSocket. connect ( )
166
-
167
- var resumed = false
168
-
169
- // remove the uncommon (and unexpected) responses from the stream, to make normal path cleaner
170
- let dataStream = stream. compactMap { ( message: URLSessionWebSocketTask . Message ) -> Data ? in
171
- switch message {
172
- case let . string( string) :
173
- AILog . error ( code: . liveSessionUnexpectedResponse, " Unexpected string response: \( string) " )
174
- case let . data( data) :
175
- return data
176
- @unknown default :
177
- AILog . error ( code: . liveSessionUnexpectedResponse, " Unknown message received: \( message) " )
178
- }
179
- return nil
180
- }
181
146
182
147
do {
183
148
let setup = BidiGenerateContentSetup (
@@ -194,54 +159,80 @@ actor LiveSessionService {
194
159
} catch {
195
160
let error = LiveSessionSetupError ( underlyingError: error)
196
161
close ( )
197
- setupComplete. resume ( throwing: error)
198
- return
162
+ throw error
199
163
}
200
164
201
- responsesTask = Task {
202
- do {
203
- for try await message in dataStream {
204
- let response : BidiGenerateContentServerMessage
205
- do {
206
- response = try jsonDecoder. decode (
207
- BidiGenerateContentServerMessage . self,
208
- from: message
209
- )
210
- } catch {
211
- // only log the json if it wasn't a decoding error, but an unsupported message type
212
- if error is InvalidMessageTypeError {
213
- AILog . error (
214
- code: . liveSessionUnsupportedMessage,
215
- " The server sent a message that we don't currently have a mapping for. "
216
- )
165
+ do {
166
+ for try await message in stream {
167
+ let response = try decodeServerMessage ( message)
168
+ if case . setupComplete = response. messageType {
169
+ break
170
+ } else {
171
+ AILog . error (
172
+ code: . liveSessionUnexpectedResponse,
173
+ " The model sent us a message that wasn't a setup complete: \( response) "
174
+ )
175
+ }
176
+ }
177
+ } catch {
178
+ if let error = mapWebsocketError ( error) {
179
+ close ( )
180
+ throw error
181
+ }
182
+ }
183
+ }
217
184
218
- AILog . debug (
219
- code: . liveSessionUnsupportedMessagePayload,
220
- message. encodeToJsonString ( ) ?? " \( message) "
221
- )
222
- }
185
+ /// Performs the initial setup procedure for a websocket.
186
+ ///
187
+ /// This includes creating the websocket url and connecting it.
188
+ ///
189
+ /// - Returns: A stream of `Data` frames from the websocket.
190
+ private func setupWebsocket( ) async throws
191
+ -> MappedStream < URLSessionWebSocketTask . Message , Data > {
192
+ do {
193
+ let webSocket = try await createWebsocket ( )
194
+ self . webSocket = webSocket
195
+
196
+ let stream = webSocket. connect ( )
197
+
198
+ // remove the uncommon (and unexpected) frames from the stream, to make normal path cleaner
199
+ return stream. compactMap { message in
200
+ switch message {
201
+ case let . string( string) :
202
+ AILog . error ( code: . liveSessionUnexpectedResponse, " Unexpected string response: \( string) " )
203
+ case let . data( data) :
204
+ return data
205
+ @unknown default :
206
+ AILog . error ( code: . liveSessionUnexpectedResponse, " Unknown message received: \( message) " )
207
+ }
208
+ return nil
209
+ }
210
+ } catch {
211
+ let error = LiveSessionSetupError ( underlyingError: error)
212
+ close ( )
213
+ throw error
214
+ }
215
+ }
223
216
224
- let error = LiveSessionUnsupportedMessageError ( underlyingError: error)
225
- // if we've already finished setting up, then only surface the error through responses
226
- // otherwise, make the setup task error as well
227
- if !resumed {
228
- setupComplete. resume ( throwing: error)
229
- }
230
- throw error
231
- }
217
+ /// Spawn tasks for interacting with the model.
218
+ ///
219
+ /// The following tasks will be spawned:
220
+ ///
221
+ /// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
222
+ /// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
223
+ private func spawnMessageTasks( stream: MappedStream < URLSessionWebSocketTask . Message , Data > ) {
224
+ guard let webSocket else { return }
225
+
226
+ responsesTask = Task {
227
+ do {
228
+ for try await message in stream {
229
+ let response = try decodeServerMessage ( message)
232
230
233
231
if case . setupComplete = response. messageType {
234
- if resumed {
235
- AILog . debug (
236
- code: . duplicateLiveSessionSetupComplete,
237
- " Setup complete was received multiple times; this may be a bug in the model. "
238
- )
239
- } else {
240
- // calling resume multiple times is an error in swift, so we catch multiple calls
241
- // to avoid causing any issues due to model quirks
242
- resumed = true
243
- setupComplete. resume ( )
244
- }
232
+ AILog . debug (
233
+ code: . duplicateLiveSessionSetupComplete,
234
+ " Setup complete was received multiple times; this may be a bug in the model. "
235
+ )
245
236
} else if let liveMessage = LiveServerMessage ( from: response) {
246
237
if case let . goingAwayNotice( message) = liveMessage. payload {
247
238
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
@@ -255,21 +246,7 @@ actor LiveSessionService {
255
246
}
256
247
}
257
248
} catch {
258
- if let error = error as? WebSocketClosedError {
259
- // only raise an error if the session didn't close normally (ie; the user calling close)
260
- if error. closeCode != . goingAway {
261
- let closureError : Error
262
- if let error = error. underlyingError as? NSError , error. domain == NSURLErrorDomain,
263
- error. code == NSURLErrorNetworkConnectionLost {
264
- closureError = LiveSessionLostConnectionError ( underlyingError: error)
265
- } else {
266
- closureError = LiveSessionUnexpectedClosureError ( underlyingError: error)
267
- }
268
- close ( )
269
- responseContinuation. finish ( throwing: closureError)
270
- }
271
- } else {
272
- // an error occurred outside the websocket, so it's likely not closed
249
+ if let error = mapWebsocketError ( error) {
273
250
close ( )
274
251
responseContinuation. finish ( throwing: error)
275
252
}
@@ -278,22 +255,7 @@ actor LiveSessionService {
278
255
279
256
messageQueueTask = Task {
280
257
for await message in messageQueue {
281
- // we don't propogate errors, since those are surfaced in the responses stream
282
- guard let _ = try ? await setupTask. value else {
283
- break
284
- }
285
-
286
- let data : Data
287
- do {
288
- data = try jsonEncoder. encode ( message)
289
- } catch {
290
- AILog . error ( code: . liveSessionFailedToEncodeClientMessage, error. localizedDescription)
291
- AILog . debug (
292
- code: . liveSessionFailedToEncodeClientMessagePayload,
293
- String ( describing: message)
294
- )
295
- continue
296
- }
258
+ guard let data = encodeClientMessage ( message) else { continue }
297
259
298
260
do {
299
261
try await webSocket. send ( . data( data) )
@@ -304,6 +266,75 @@ actor LiveSessionService {
304
266
}
305
267
}
306
268
269
+ /// Checks if an error should be propogated up, and maps it accordingly.
270
+ ///
271
+ /// Some errors have public api alternatives. This function will ensure they're mapped
272
+ /// accordingly.
273
+ private func mapWebsocketError( _ error: Error ) -> Error ? {
274
+ if let error = error as? WebSocketClosedError {
275
+ // only raise an error if the session didn't close normally (ie; the user calling close)
276
+ if error. closeCode == . goingAway {
277
+ return nil
278
+ }
279
+
280
+ let closureError : Error
281
+
282
+ if let error = error. underlyingError as? NSError , error. domain == NSURLErrorDomain,
283
+ error. code == NSURLErrorNetworkConnectionLost {
284
+ closureError = LiveSessionLostConnectionError ( underlyingError: error)
285
+ } else {
286
+ closureError = LiveSessionUnexpectedClosureError ( underlyingError: error)
287
+ }
288
+
289
+ return closureError
290
+ }
291
+
292
+ return error
293
+ }
294
+
295
+ /// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`.
296
+ ///
297
+ /// Will throw an error if decoding fails.
298
+ private func decodeServerMessage( _ message: Data ) throws -> BidiGenerateContentServerMessage {
299
+ do {
300
+ return try jsonDecoder. decode (
301
+ BidiGenerateContentServerMessage . self,
302
+ from: message
303
+ )
304
+ } catch {
305
+ // only log the json if it wasn't a decoding error, but an unsupported message type
306
+ if error is InvalidMessageTypeError {
307
+ AILog . error (
308
+ code: . liveSessionUnsupportedMessage,
309
+ " The server sent a message that we don't currently have a mapping for. "
310
+ )
311
+ AILog . debug (
312
+ code: . liveSessionUnsupportedMessagePayload,
313
+ message. encodeToJsonString ( ) ?? " \( message) "
314
+ )
315
+ }
316
+
317
+ throw LiveSessionUnsupportedMessageError ( underlyingError: error)
318
+ }
319
+ }
320
+
321
+ /// Encodes a message from the client into `Data` that can be sent through a websocket data frame.
322
+ ///
323
+ /// Will return `nil` if decoding fails, and log an error describing why.
324
+ private func encodeClientMessage( _ message: BidiGenerateContentClientMessage ) -> Data ? {
325
+ do {
326
+ return try jsonEncoder. encode ( message)
327
+ } catch {
328
+ AILog . error ( code: . liveSessionFailedToEncodeClientMessage, error. localizedDescription)
329
+ AILog . debug (
330
+ code: . liveSessionFailedToEncodeClientMessagePayload,
331
+ String ( describing: message)
332
+ )
333
+ }
334
+
335
+ return nil
336
+ }
337
+
307
338
/// Creates a websocket pointing to the backend.
308
339
///
309
340
/// Will apply the required app check and auth headers, as the backend expects them.
@@ -392,3 +423,8 @@ private extension String {
392
423
}
393
424
}
394
425
}
426
+
427
+ /// Helper alias for a compact mapped throwing stream.
428
+ ///
429
+ /// We use this to make signatures easier to read, since we can't support `AsyncSequence` quite yet.
430
+ private typealias MappedStream < T, V> = AsyncCompactMapSequence < AsyncThrowingStream < T , any Error > , V >
0 commit comments