@@ -54,11 +54,6 @@ actor LiveSessionService {
54
54
private let jsonEncoder = JSONEncoder ( )
55
55
private let jsonDecoder = JSONDecoder ( )
56
56
57
- /// Task that doesn't complete until the server sends a setupComplete message.
58
- ///
59
- /// Used to hold off on sending messages until the server is ready.
60
- private var setupTask : Task < Void , Error >
61
-
62
57
/// Long running task that that wraps around the websocket, propogating messages through the
63
58
/// public stream.
64
59
private var responsesTask : Task < Void , Never > ?
@@ -87,11 +82,9 @@ actor LiveSessionService {
87
82
self . toolConfig = toolConfig
88
83
self . systemInstruction = systemInstruction
89
84
self . requestOptions = requestOptions
90
- setupTask = Task { }
91
85
}
92
86
93
87
deinit {
94
- setupTask. cancel ( )
95
88
responsesTask? . cancel ( )
96
89
messageQueueTask? . cancel ( )
97
90
webSocket? . disconnect ( )
@@ -114,29 +107,20 @@ actor LiveSessionService {
114
107
///
115
108
/// Seperated into its own function to make it easier to surface a way to call it seperately when
116
109
/// resuming the same session.
110
+ ///
111
+ /// This function will yield until the websocket is ready to communicate with the client.
117
112
func connect( ) async throws {
118
113
close ( )
119
- // we launch the setup task in a seperate task to allow us to cancel it via close
120
- setupTask = Task { [ weak self] in
121
- // we need a continuation to surface that the setup is complete, while still allowing us to
122
- // listen to the server
123
- try await withCheckedThrowingContinuation { setupContinuation in
124
- // nested task so we can use await
125
- Task { [ weak self] in
126
- guard let self else { return }
127
- await self . listenToServer ( setupContinuation)
128
- }
129
- }
130
- }
131
114
132
- try await setupTask. value
115
+ let stream = try await setupWebsocket ( )
116
+ try await waitForSetupComplete ( stream: stream)
117
+ spawnMessageTasks ( stream: stream)
133
118
}
134
119
135
120
/// Cancel any running tasks and close the websocket.
136
121
///
137
122
/// This method is idempotent; if it's already ran once, it will effectively be a no-op.
138
123
func close( ) {
139
- setupTask. cancel ( )
140
124
responsesTask? . cancel ( )
141
125
messageQueueTask? . cancel ( )
142
126
webSocket? . disconnect ( )
@@ -146,38 +130,19 @@ actor LiveSessionService {
146
130
messageQueueTask = nil
147
131
}
148
132
149
- /// Start a fresh websocket to the backend, and listen for responses .
133
+ /// Performs the initial setup procedure for the model .
150
134
///
151
- /// Will hold off on sending any messages until the server sends a setupComplete message.
135
+ /// The setup procedure with the model follows the procedure:
152
136
///
153
- /// Will also close out the old websocket and the previous long running tasks.
154
- private func listenToServer( _ setupComplete: CheckedContinuation < Void , any Error > ) async {
155
- do {
156
- webSocket = try await createWebsocket ( )
157
- } catch {
158
- let error = LiveSessionSetupError ( underlyingError: error)
159
- close ( )
160
- setupComplete. resume ( throwing: error)
161
- return
162
- }
163
-
137
+ /// - Client sends `BidiGenerateContentSetup`
138
+ /// - Server sends back `BidiGenerateContentSetupComplete` when it's ready
139
+ ///
140
+ /// This function will yield until the setup is complete.
141
+ private func waitForSetupComplete( stream: MappedStream <
142
+ URLSessionWebSocketTask . Message ,
143
+ Data
144
+ > ) async throws {
164
145
guard let webSocket else { return }
165
- let stream = webSocket. connect ( )
166
-
167
- var resumed = false
168
-
169
- // remove the uncommon (and unexpected) responses from the stream, to make normal path cleaner
170
- let dataStream = stream. compactMap { ( message: URLSessionWebSocketTask . Message ) -> Data ? in
171
- switch message {
172
- case let . string( string) :
173
- AILog . error ( code: . liveSessionUnexpectedResponse, " Unexpected string response: \( string) " )
174
- case let . data( data) :
175
- return data
176
- @unknown default :
177
- AILog . error ( code: . liveSessionUnexpectedResponse, " Unknown message received: \( message) " )
178
- }
179
- return nil
180
- }
181
146
182
147
do {
183
148
let setup = BidiGenerateContentSetup (
@@ -194,54 +159,87 @@ actor LiveSessionService {
194
159
} catch {
195
160
let error = LiveSessionSetupError ( underlyingError: error)
196
161
close ( )
197
- setupComplete. resume ( throwing: error)
198
- return
162
+ throw error
199
163
}
200
164
201
- responsesTask = Task {
202
- do {
203
- for try await message in dataStream {
204
- let response : BidiGenerateContentServerMessage
205
- do {
206
- response = try jsonDecoder. decode (
207
- BidiGenerateContentServerMessage . self,
208
- from: message
209
- )
210
- } catch {
211
- // only log the json if it wasn't a decoding error, but an unsupported message type
212
- if error is InvalidMessageTypeError {
213
- AILog . error (
214
- code: . liveSessionUnsupportedMessage,
215
- " The server sent a message that we don't currently have a mapping for. "
216
- )
165
+ do {
166
+ for try await message in stream {
167
+ let response = try decodeServerMessage ( message)
168
+ if case . setupComplete = response. messageType {
169
+ break
170
+ } else {
171
+ AILog . error (
172
+ code: . liveSessionUnexpectedResponse,
173
+ " The model sent us a message that wasn't a setup complete: \( response) "
174
+ )
175
+ }
176
+ }
177
+ } catch {
178
+ if let error = mapWebsocketError ( error) {
179
+ close ( )
180
+ throw error
181
+ }
182
+ // the user called close while setup was running
183
+ // this can't currently happen, but could when we add automatic session resumption
184
+ // in such case, we don't want to raise an error. this log is more-so to catch any edge cases
185
+ AILog . debug (
186
+ code: . liveSessionClosedDuringSetup,
187
+ " The live session was closed before setup could complete: \( error. localizedDescription) "
188
+ )
189
+ }
190
+ }
217
191
218
- AILog . debug (
219
- code: . liveSessionUnsupportedMessagePayload,
220
- message. encodeToJsonString ( ) ?? " \( message) "
221
- )
222
- }
192
+ /// Performs the initial setup procedure for a websocket.
193
+ ///
194
+ /// This includes creating the websocket url and connecting it.
195
+ ///
196
+ /// - Returns: A stream of `Data` frames from the websocket.
197
+ private func setupWebsocket( ) async throws
198
+ -> MappedStream < URLSessionWebSocketTask . Message , Data > {
199
+ do {
200
+ let webSocket = try await createWebsocket ( )
201
+ self . webSocket = webSocket
202
+
203
+ let stream = webSocket. connect ( )
204
+
205
+ // remove the uncommon (and unexpected) frames from the stream, to make normal path cleaner
206
+ return stream. compactMap { message in
207
+ switch message {
208
+ case let . string( string) :
209
+ AILog . error ( code: . liveSessionUnexpectedResponse, " Unexpected string response: \( string) " )
210
+ case let . data( data) :
211
+ return data
212
+ @unknown default :
213
+ AILog . error ( code: . liveSessionUnexpectedResponse, " Unknown message received: \( message) " )
214
+ }
215
+ return nil
216
+ }
217
+ } catch {
218
+ let error = LiveSessionSetupError ( underlyingError: error)
219
+ close ( )
220
+ throw error
221
+ }
222
+ }
223
223
224
- let error = LiveSessionUnsupportedMessageError ( underlyingError: error)
225
- // if we've already finished setting up, then only surface the error through responses
226
- // otherwise, make the setup task error as well
227
- if !resumed {
228
- setupComplete. resume ( throwing: error)
229
- }
230
- throw error
231
- }
224
+ /// Spawn tasks for interacting with the model.
225
+ ///
226
+ /// The following tasks will be spawned:
227
+ ///
228
+ /// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
229
+ /// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
230
+ private func spawnMessageTasks( stream: MappedStream < URLSessionWebSocketTask . Message , Data > ) {
231
+ guard let webSocket else { return }
232
+
233
+ responsesTask = Task {
234
+ do {
235
+ for try await message in stream {
236
+ let response = try decodeServerMessage ( message)
232
237
233
238
if case . setupComplete = response. messageType {
234
- if resumed {
235
- AILog . debug (
236
- code: . duplicateLiveSessionSetupComplete,
237
- " Setup complete was received multiple times; this may be a bug in the model. "
238
- )
239
- } else {
240
- // calling resume multiple times is an error in swift, so we catch multiple calls
241
- // to avoid causing any issues due to model quirks
242
- resumed = true
243
- setupComplete. resume ( )
244
- }
239
+ AILog . debug (
240
+ code: . duplicateLiveSessionSetupComplete,
241
+ " Setup complete was received multiple times; this may be a bug in the model. "
242
+ )
245
243
} else if let liveMessage = LiveServerMessage ( from: response) {
246
244
if case let . goingAwayNotice( message) = liveMessage. payload {
247
245
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
@@ -255,21 +253,7 @@ actor LiveSessionService {
255
253
}
256
254
}
257
255
} catch {
258
- if let error = error as? WebSocketClosedError {
259
- // only raise an error if the session didn't close normally (ie; the user calling close)
260
- if error. closeCode != . goingAway {
261
- let closureError : Error
262
- if let error = error. underlyingError as? NSError , error. domain == NSURLErrorDomain,
263
- error. code == NSURLErrorNetworkConnectionLost {
264
- closureError = LiveSessionLostConnectionError ( underlyingError: error)
265
- } else {
266
- closureError = LiveSessionUnexpectedClosureError ( underlyingError: error)
267
- }
268
- close ( )
269
- responseContinuation. finish ( throwing: closureError)
270
- }
271
- } else {
272
- // an error occurred outside the websocket, so it's likely not closed
256
+ if let error = mapWebsocketError ( error) {
273
257
close ( )
274
258
responseContinuation. finish ( throwing: error)
275
259
}
@@ -278,22 +262,7 @@ actor LiveSessionService {
278
262
279
263
messageQueueTask = Task {
280
264
for await message in messageQueue {
281
- // we don't propogate errors, since those are surfaced in the responses stream
282
- guard let _ = try ? await setupTask. value else {
283
- break
284
- }
285
-
286
- let data : Data
287
- do {
288
- data = try jsonEncoder. encode ( message)
289
- } catch {
290
- AILog . error ( code: . liveSessionFailedToEncodeClientMessage, error. localizedDescription)
291
- AILog . debug (
292
- code: . liveSessionFailedToEncodeClientMessagePayload,
293
- String ( describing: message)
294
- )
295
- continue
296
- }
265
+ guard let data = encodeClientMessage ( message) else { continue }
297
266
298
267
do {
299
268
try await webSocket. send ( . data( data) )
@@ -304,6 +273,75 @@ actor LiveSessionService {
304
273
}
305
274
}
306
275
276
+ /// Checks if an error should be propogated up, and maps it accordingly.
277
+ ///
278
+ /// Some errors have public api alternatives. This function will ensure they're mapped
279
+ /// accordingly.
280
+ private func mapWebsocketError( _ error: Error ) -> Error ? {
281
+ if let error = error as? WebSocketClosedError {
282
+ // only raise an error if the session didn't close normally (ie; the user calling close)
283
+ if error. closeCode == . goingAway {
284
+ return nil
285
+ }
286
+
287
+ let closureError : Error
288
+
289
+ if let error = error. underlyingError as? NSError , error. domain == NSURLErrorDomain,
290
+ error. code == NSURLErrorNetworkConnectionLost {
291
+ closureError = LiveSessionLostConnectionError ( underlyingError: error)
292
+ } else {
293
+ closureError = LiveSessionUnexpectedClosureError ( underlyingError: error)
294
+ }
295
+
296
+ return closureError
297
+ }
298
+
299
+ return error
300
+ }
301
+
302
+ /// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`.
303
+ ///
304
+ /// Will throw an error if decoding fails.
305
+ private func decodeServerMessage( _ message: Data ) throws -> BidiGenerateContentServerMessage {
306
+ do {
307
+ return try jsonDecoder. decode (
308
+ BidiGenerateContentServerMessage . self,
309
+ from: message
310
+ )
311
+ } catch {
312
+ // only log the json if it wasn't a decoding error, but an unsupported message type
313
+ if error is InvalidMessageTypeError {
314
+ AILog . error (
315
+ code: . liveSessionUnsupportedMessage,
316
+ " The server sent a message that we don't currently have a mapping for. "
317
+ )
318
+ AILog . debug (
319
+ code: . liveSessionUnsupportedMessagePayload,
320
+ message. encodeToJsonString ( ) ?? " \( message) "
321
+ )
322
+ }
323
+
324
+ throw LiveSessionUnsupportedMessageError ( underlyingError: error)
325
+ }
326
+ }
327
+
328
+ /// Encodes a message from the client into `Data` that can be sent through a websocket data frame.
329
+ ///
330
+ /// Will return `nil` if decoding fails, and log an error describing why.
331
+ private func encodeClientMessage( _ message: BidiGenerateContentClientMessage ) -> Data ? {
332
+ do {
333
+ return try jsonEncoder. encode ( message)
334
+ } catch {
335
+ AILog . error ( code: . liveSessionFailedToEncodeClientMessage, error. localizedDescription)
336
+ AILog . debug (
337
+ code: . liveSessionFailedToEncodeClientMessagePayload,
338
+ String ( describing: message)
339
+ )
340
+ }
341
+
342
+ return nil
343
+ }
344
+
307
345
/// Creates a websocket pointing to the backend.
308
346
///
309
347
/// Will apply the required app check and auth headers, as the backend expects them.
@@ -392,3 +430,8 @@ private extension String {
392
430
}
393
431
}
394
432
}
433
+
434
+ /// Helper alias for a compact mapped throwing stream.
435
+ ///
436
+ /// We use this to make signatures easier to read, since we can't support `AsyncSequence` quite yet.
437
+ private typealias MappedStream < T, V> = AsyncCompactMapSequence < AsyncThrowingStream < T , any Error > , V >
0 commit comments