Skip to content

Commit 541ac34

Browse files
authored
fix(ai): Fix error propagation during setup (#15379)
1 parent d760a89 commit 541ac34

File tree

2 files changed

+171
-127
lines changed

2 files changed

+171
-127
lines changed

FirebaseAI/Sources/AILog.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,11 @@ enum AILog {
7474
case liveSessionFailedToSendClientMessage = 3023
7575
case liveSessionUnexpectedResponse = 3024
7676
case liveSessionGoingAwaySoon = 3025
77-
case decodedMissingProtoDurationSuffix = 3026
78-
case decodedInvalidProtoDurationString = 3027
79-
case decodedInvalidProtoDurationSeconds = 3028
80-
case decodedInvalidProtoDurationNanoseconds = 3029
77+
case liveSessionClosedDuringSetup = 3026
78+
case decodedMissingProtoDurationSuffix = 3027
79+
case decodedInvalidProtoDurationString = 3028
80+
case decodedInvalidProtoDurationSeconds = 3029
81+
case decodedInvalidProtoDurationNanoseconds = 3030
8182

8283
// SDK State Errors
8384
case generateContentResponseNoCandidates = 4000

FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift

Lines changed: 166 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,6 @@ actor LiveSessionService {
5454
private let jsonEncoder = JSONEncoder()
5555
private let jsonDecoder = JSONDecoder()
5656

57-
/// Task that doesn't complete until the server sends a setupComplete message.
58-
///
59-
/// Used to hold off on sending messages until the server is ready.
60-
private var setupTask: Task<Void, Error>
61-
6257
/// Long running task that wraps around the websocket, propagating messages through the
6358
/// public stream.
6459
private var responsesTask: Task<Void, Never>?
@@ -87,11 +82,9 @@ actor LiveSessionService {
8782
self.toolConfig = toolConfig
8883
self.systemInstruction = systemInstruction
8984
self.requestOptions = requestOptions
90-
setupTask = Task {}
9185
}
9286

9387
deinit {
94-
setupTask.cancel()
9588
responsesTask?.cancel()
9689
messageQueueTask?.cancel()
9790
webSocket?.disconnect()
@@ -114,29 +107,20 @@ actor LiveSessionService {
114107
///
115108
/// Separated into its own function to make it easier to surface a way to call it separately when
116109
/// resuming the same session.
110+
///
111+
/// This function will yield until the websocket is ready to communicate with the client.
117112
func connect() async throws {
118113
close()
119-
// we launch the setup task in a seperate task to allow us to cancel it via close
120-
setupTask = Task { [weak self] in
121-
// we need a continuation to surface that the setup is complete, while still allowing us to
122-
// listen to the server
123-
try await withCheckedThrowingContinuation { setupContinuation in
124-
// nested task so we can use await
125-
Task { [weak self] in
126-
guard let self else { return }
127-
await self.listenToServer(setupContinuation)
128-
}
129-
}
130-
}
131114

132-
try await setupTask.value
115+
let stream = try await setupWebsocket()
116+
try await waitForSetupComplete(stream: stream)
117+
spawnMessageTasks(stream: stream)
133118
}
134119

135120
/// Cancel any running tasks and close the websocket.
136121
///
137122
/// This method is idempotent; if it has already run once, it will effectively be a no-op.
138123
func close() {
139-
setupTask.cancel()
140124
responsesTask?.cancel()
141125
messageQueueTask?.cancel()
142126
webSocket?.disconnect()
@@ -146,38 +130,19 @@ actor LiveSessionService {
146130
messageQueueTask = nil
147131
}
148132

149-
/// Start a fresh websocket to the backend, and listen for responses.
133+
/// Performs the initial setup procedure for the model.
150134
///
151-
/// Will hold off on sending any messages until the server sends a setupComplete message.
135+
/// The setup handshake with the model proceeds as follows:
152136
///
153-
/// Will also close out the old websocket and the previous long running tasks.
154-
private func listenToServer(_ setupComplete: CheckedContinuation<Void, any Error>) async {
155-
do {
156-
webSocket = try await createWebsocket()
157-
} catch {
158-
let error = LiveSessionSetupError(underlyingError: error)
159-
close()
160-
setupComplete.resume(throwing: error)
161-
return
162-
}
163-
137+
/// - Client sends `BidiGenerateContentSetup`
138+
/// - Server sends back `BidiGenerateContentSetupComplete` when it's ready
139+
///
140+
/// This function will yield until the setup is complete.
141+
private func waitForSetupComplete(stream: MappedStream<
142+
URLSessionWebSocketTask.Message,
143+
Data
144+
>) async throws {
164145
guard let webSocket else { return }
165-
let stream = webSocket.connect()
166-
167-
var resumed = false
168-
169-
// remove the uncommon (and unexpected) responses from the stream, to make normal path cleaner
170-
let dataStream = stream.compactMap { (message: URLSessionWebSocketTask.Message) -> Data? in
171-
switch message {
172-
case let .string(string):
173-
AILog.error(code: .liveSessionUnexpectedResponse, "Unexpected string response: \(string)")
174-
case let .data(data):
175-
return data
176-
@unknown default:
177-
AILog.error(code: .liveSessionUnexpectedResponse, "Unknown message received: \(message)")
178-
}
179-
return nil
180-
}
181146

182147
do {
183148
let setup = BidiGenerateContentSetup(
@@ -194,54 +159,87 @@ actor LiveSessionService {
194159
} catch {
195160
let error = LiveSessionSetupError(underlyingError: error)
196161
close()
197-
setupComplete.resume(throwing: error)
198-
return
162+
throw error
199163
}
200164

201-
responsesTask = Task {
202-
do {
203-
for try await message in dataStream {
204-
let response: BidiGenerateContentServerMessage
205-
do {
206-
response = try jsonDecoder.decode(
207-
BidiGenerateContentServerMessage.self,
208-
from: message
209-
)
210-
} catch {
211-
// only log the json if it wasn't a decoding error, but an unsupported message type
212-
if error is InvalidMessageTypeError {
213-
AILog.error(
214-
code: .liveSessionUnsupportedMessage,
215-
"The server sent a message that we don't currently have a mapping for."
216-
)
165+
do {
166+
for try await message in stream {
167+
let response = try decodeServerMessage(message)
168+
if case .setupComplete = response.messageType {
169+
break
170+
} else {
171+
AILog.error(
172+
code: .liveSessionUnexpectedResponse,
173+
"The model sent us a message that wasn't a setup complete: \(response)"
174+
)
175+
}
176+
}
177+
} catch {
178+
if let error = mapWebsocketError(error) {
179+
close()
180+
throw error
181+
}
182+
// the user called close while setup was running
183+
// this can't currently happen, but could when we add automatic session resumption
184+
// in such a case, we don't want to raise an error. this log is more so to catch any edge cases
185+
AILog.debug(
186+
code: .liveSessionClosedDuringSetup,
187+
"The live session was closed before setup could complete: \(error.localizedDescription)"
188+
)
189+
}
190+
}
217191

218-
AILog.debug(
219-
code: .liveSessionUnsupportedMessagePayload,
220-
message.encodeToJsonString() ?? "\(message)"
221-
)
222-
}
192+
/// Performs the initial setup procedure for a websocket.
193+
///
194+
/// This includes creating the websocket url and connecting it.
195+
///
196+
/// - Returns: A stream of `Data` frames from the websocket.
197+
private func setupWebsocket() async throws
198+
-> MappedStream<URLSessionWebSocketTask.Message, Data> {
199+
do {
200+
let webSocket = try await createWebsocket()
201+
self.webSocket = webSocket
202+
203+
let stream = webSocket.connect()
204+
205+
// remove the uncommon (and unexpected) frames from the stream, to make normal path cleaner
206+
return stream.compactMap { message in
207+
switch message {
208+
case let .string(string):
209+
AILog.error(code: .liveSessionUnexpectedResponse, "Unexpected string response: \(string)")
210+
case let .data(data):
211+
return data
212+
@unknown default:
213+
AILog.error(code: .liveSessionUnexpectedResponse, "Unknown message received: \(message)")
214+
}
215+
return nil
216+
}
217+
} catch {
218+
let error = LiveSessionSetupError(underlyingError: error)
219+
close()
220+
throw error
221+
}
222+
}
223223

224-
let error = LiveSessionUnsupportedMessageError(underlyingError: error)
225-
// if we've already finished setting up, then only surface the error through responses
226-
// otherwise, make the setup task error as well
227-
if !resumed {
228-
setupComplete.resume(throwing: error)
229-
}
230-
throw error
231-
}
224+
/// Spawn tasks for interacting with the model.
225+
///
226+
/// The following tasks will be spawned:
227+
///
228+
/// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
229+
/// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
230+
private func spawnMessageTasks(stream: MappedStream<URLSessionWebSocketTask.Message, Data>) {
231+
guard let webSocket else { return }
232+
233+
responsesTask = Task {
234+
do {
235+
for try await message in stream {
236+
let response = try decodeServerMessage(message)
232237

233238
if case .setupComplete = response.messageType {
234-
if resumed {
235-
AILog.debug(
236-
code: .duplicateLiveSessionSetupComplete,
237-
"Setup complete was received multiple times; this may be a bug in the model."
238-
)
239-
} else {
240-
// calling resume multiple times is an error in swift, so we catch multiple calls
241-
// to avoid causing any issues due to model quirks
242-
resumed = true
243-
setupComplete.resume()
244-
}
239+
AILog.debug(
240+
code: .duplicateLiveSessionSetupComplete,
241+
"Setup complete was received multiple times; this may be a bug in the model."
242+
)
245243
} else if let liveMessage = LiveServerMessage(from: response) {
246244
if case let .goingAwayNotice(message) = liveMessage.payload {
247245
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
@@ -255,21 +253,7 @@ actor LiveSessionService {
255253
}
256254
}
257255
} catch {
258-
if let error = error as? WebSocketClosedError {
259-
// only raise an error if the session didn't close normally (ie; the user calling close)
260-
if error.closeCode != .goingAway {
261-
let closureError: Error
262-
if let error = error.underlyingError as? NSError, error.domain == NSURLErrorDomain,
263-
error.code == NSURLErrorNetworkConnectionLost {
264-
closureError = LiveSessionLostConnectionError(underlyingError: error)
265-
} else {
266-
closureError = LiveSessionUnexpectedClosureError(underlyingError: error)
267-
}
268-
close()
269-
responseContinuation.finish(throwing: closureError)
270-
}
271-
} else {
272-
// an error occurred outside the websocket, so it's likely not closed
256+
if let error = mapWebsocketError(error) {
273257
close()
274258
responseContinuation.finish(throwing: error)
275259
}
@@ -278,22 +262,7 @@ actor LiveSessionService {
278262

279263
messageQueueTask = Task {
280264
for await message in messageQueue {
281-
// we don't propogate errors, since those are surfaced in the responses stream
282-
guard let _ = try? await setupTask.value else {
283-
break
284-
}
285-
286-
let data: Data
287-
do {
288-
data = try jsonEncoder.encode(message)
289-
} catch {
290-
AILog.error(code: .liveSessionFailedToEncodeClientMessage, error.localizedDescription)
291-
AILog.debug(
292-
code: .liveSessionFailedToEncodeClientMessagePayload,
293-
String(describing: message)
294-
)
295-
continue
296-
}
265+
guard let data = encodeClientMessage(message) else { continue }
297266

298267
do {
299268
try await webSocket.send(.data(data))
@@ -304,6 +273,75 @@ actor LiveSessionService {
304273
}
305274
}
306275

276+
/// Checks if an error should be propagated up, and maps it accordingly.
277+
///
278+
/// Some errors have public api alternatives. This function will ensure they're mapped
279+
/// accordingly.
280+
private func mapWebsocketError(_ error: Error) -> Error? {
281+
if let error = error as? WebSocketClosedError {
282+
// only raise an error if the session didn't close normally (ie; the user calling close)
283+
if error.closeCode == .goingAway {
284+
return nil
285+
}
286+
287+
let closureError: Error
288+
289+
if let error = error.underlyingError as? NSError, error.domain == NSURLErrorDomain,
290+
error.code == NSURLErrorNetworkConnectionLost {
291+
closureError = LiveSessionLostConnectionError(underlyingError: error)
292+
} else {
293+
closureError = LiveSessionUnexpectedClosureError(underlyingError: error)
294+
}
295+
296+
return closureError
297+
}
298+
299+
return error
300+
}
301+
302+
/// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`.
303+
///
304+
/// Will throw an error if decoding fails.
305+
private func decodeServerMessage(_ message: Data) throws -> BidiGenerateContentServerMessage {
306+
do {
307+
return try jsonDecoder.decode(
308+
BidiGenerateContentServerMessage.self,
309+
from: message
310+
)
311+
} catch {
312+
// only log the json if it wasn't a decoding error, but an unsupported message type
313+
if error is InvalidMessageTypeError {
314+
AILog.error(
315+
code: .liveSessionUnsupportedMessage,
316+
"The server sent a message that we don't currently have a mapping for."
317+
)
318+
AILog.debug(
319+
code: .liveSessionUnsupportedMessagePayload,
320+
message.encodeToJsonString() ?? "\(message)"
321+
)
322+
}
323+
324+
throw LiveSessionUnsupportedMessageError(underlyingError: error)
325+
}
326+
}
327+
328+
/// Encodes a message from the client into `Data` that can be sent through a websocket data frame.
329+
///
330+
/// Will return `nil` if encoding fails, and log an error describing why.
331+
private func encodeClientMessage(_ message: BidiGenerateContentClientMessage) -> Data? {
332+
do {
333+
return try jsonEncoder.encode(message)
334+
} catch {
335+
AILog.error(code: .liveSessionFailedToEncodeClientMessage, error.localizedDescription)
336+
AILog.debug(
337+
code: .liveSessionFailedToEncodeClientMessagePayload,
338+
String(describing: message)
339+
)
340+
}
341+
342+
return nil
343+
}
344+
307345
/// Creates a websocket pointing to the backend.
308346
///
309347
/// Will apply the required app check and auth headers, as the backend expects them.
@@ -392,3 +430,8 @@ private extension String {
392430
}
393431
}
394432
}
433+
434+
/// Helper alias for a compact mapped throwing stream.
435+
///
436+
/// We use this to make signatures easier to read, since we can't support `AsyncSequence` quite yet.
437+
private typealias MappedStream<T, V> = AsyncCompactMapSequence<AsyncThrowingStream<T, any Error>, V>

0 commit comments

Comments
 (0)