Skip to content

Commit 422939d

Browse files
committed
Update LiveSessionService.swift
1 parent 962bb60 commit 422939d

File tree

1 file changed

+160
-124
lines changed

1 file changed

+160
-124
lines changed

FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift

Lines changed: 160 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import Foundation
2828
///
2929
/// This mainly comes into play when we don't want to block developers from sending messages while a
3030
/// session is being reloaded.
31-
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
31+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
3232
@available(watchOS, unavailable)
3333
actor LiveSessionService {
3434
let responses: AsyncThrowingStream<LiveServerMessage, Error>
@@ -54,11 +54,6 @@ actor LiveSessionService {
5454
private let jsonEncoder = JSONEncoder()
5555
private let jsonDecoder = JSONDecoder()
5656

57-
/// Task that doesn't complete until the server sends a setupComplete message.
58-
///
59-
/// Used to hold off on sending messages until the server is ready.
60-
private var setupTask: Task<Void, Error>
61-
6257
/// Long running task that that wraps around the websocket, propogating messages through the
6358
/// public stream.
6459
private var responsesTask: Task<Void, Never>?
@@ -87,11 +82,9 @@ actor LiveSessionService {
8782
self.toolConfig = toolConfig
8883
self.systemInstruction = systemInstruction
8984
self.requestOptions = requestOptions
90-
setupTask = Task {}
9185
}
9286

9387
deinit {
94-
setupTask.cancel()
9588
responsesTask?.cancel()
9689
messageQueueTask?.cancel()
9790
webSocket?.disconnect()
@@ -114,29 +107,20 @@ actor LiveSessionService {
114107
///
115108
/// Seperated into its own function to make it easier to surface a way to call it seperately when
116109
/// resuming the same session.
110+
///
111+
/// This function will yield until the websocket is ready to communicate with the client.
117112
func connect() async throws {
118113
close()
119-
// we launch the setup task in a seperate task to allow us to cancel it via close
120-
setupTask = Task { [weak self] in
121-
// we need a continuation to surface that the setup is complete, while still allowing us to
122-
// listen to the server
123-
try await withCheckedThrowingContinuation { setupContinuation in
124-
// nested task so we can use await
125-
Task { [weak self] in
126-
guard let self else { return }
127-
await self.listenToServer(setupContinuation)
128-
}
129-
}
130-
}
131114

132-
try await setupTask.value
115+
let stream = try await setupWebsocket()
116+
try await waitForSetupComplete(stream: stream)
117+
spawnMessageTasks(stream: stream)
133118
}
134119

135120
/// Cancel any running tasks and close the websocket.
136121
///
137122
/// This method is idempotent; if it's already ran once, it will effectively be a no-op.
138123
func close() {
139-
setupTask.cancel()
140124
responsesTask?.cancel()
141125
messageQueueTask?.cancel()
142126
webSocket?.disconnect()
@@ -146,38 +130,19 @@ actor LiveSessionService {
146130
messageQueueTask = nil
147131
}
148132

149-
/// Start a fresh websocket to the backend, and listen for responses.
133+
/// Performs the initial setup procedure for the model.
150134
///
151-
/// Will hold off on sending any messages until the server sends a setupComplete message.
135+
/// The setup procedure with the model follows the procedure:
152136
///
153-
/// Will also close out the old websocket and the previous long running tasks.
154-
private func listenToServer(_ setupComplete: CheckedContinuation<Void, any Error>) async {
155-
do {
156-
webSocket = try await createWebsocket()
157-
} catch {
158-
let error = LiveSessionSetupError(underlyingError: error)
159-
close()
160-
setupComplete.resume(throwing: error)
161-
return
162-
}
163-
137+
/// - Client sends `BidiGenerateContentSetup`
138+
/// - Server sends back `BidiGenerateContentSetupComplete` when it's ready
139+
///
140+
/// This function will yield until the setup is complete.
141+
private func waitForSetupComplete(stream: MappedStream<
142+
URLSessionWebSocketTask.Message,
143+
Data
144+
>) async throws {
164145
guard let webSocket else { return }
165-
let stream = webSocket.connect()
166-
167-
var resumed = false
168-
169-
// remove the uncommon (and unexpected) responses from the stream, to make normal path cleaner
170-
let dataStream = stream.compactMap { (message: URLSessionWebSocketTask.Message) -> Data? in
171-
switch message {
172-
case let .string(string):
173-
AILog.error(code: .liveSessionUnexpectedResponse, "Unexpected string response: \(string)")
174-
case let .data(data):
175-
return data
176-
@unknown default:
177-
AILog.error(code: .liveSessionUnexpectedResponse, "Unknown message received: \(message)")
178-
}
179-
return nil
180-
}
181146

182147
do {
183148
let setup = BidiGenerateContentSetup(
@@ -194,54 +159,80 @@ actor LiveSessionService {
194159
} catch {
195160
let error = LiveSessionSetupError(underlyingError: error)
196161
close()
197-
setupComplete.resume(throwing: error)
198-
return
162+
throw error
199163
}
200164

201-
responsesTask = Task {
202-
do {
203-
for try await message in dataStream {
204-
let response: BidiGenerateContentServerMessage
205-
do {
206-
response = try jsonDecoder.decode(
207-
BidiGenerateContentServerMessage.self,
208-
from: message
209-
)
210-
} catch {
211-
// only log the json if it wasn't a decoding error, but an unsupported message type
212-
if error is InvalidMessageTypeError {
213-
AILog.error(
214-
code: .liveSessionUnsupportedMessage,
215-
"The server sent a message that we don't currently have a mapping for."
216-
)
165+
do {
166+
for try await message in stream {
167+
let response = try decodeServerMessage(message)
168+
if case .setupComplete = response.messageType {
169+
break
170+
} else {
171+
AILog.error(
172+
code: .liveSessionUnexpectedResponse,
173+
"The model sent us a message that wasn't a setup complete: \(response)"
174+
)
175+
}
176+
}
177+
} catch {
178+
if let error = mapWebsocketError(error) {
179+
close()
180+
throw error
181+
}
182+
}
183+
}
217184

218-
AILog.debug(
219-
code: .liveSessionUnsupportedMessagePayload,
220-
message.encodeToJsonString() ?? "\(message)"
221-
)
222-
}
185+
/// Performs the initial setup procedure for a websocket.
186+
///
187+
/// This includes creating the websocket url and connecting it.
188+
///
189+
/// - Returns: A stream of `Data` frames from the websocket.
190+
private func setupWebsocket() async throws
191+
-> MappedStream<URLSessionWebSocketTask.Message, Data> {
192+
do {
193+
let webSocket = try await createWebsocket()
194+
self.webSocket = webSocket
195+
196+
let stream = webSocket.connect()
197+
198+
// remove the uncommon (and unexpected) frames from the stream, to make normal path cleaner
199+
return stream.compactMap { message in
200+
switch message {
201+
case let .string(string):
202+
AILog.error(code: .liveSessionUnexpectedResponse, "Unexpected string response: \(string)")
203+
case let .data(data):
204+
return data
205+
@unknown default:
206+
AILog.error(code: .liveSessionUnexpectedResponse, "Unknown message received: \(message)")
207+
}
208+
return nil
209+
}
210+
} catch {
211+
let error = LiveSessionSetupError(underlyingError: error)
212+
close()
213+
throw error
214+
}
215+
}
223216

224-
let error = LiveSessionUnsupportedMessageError(underlyingError: error)
225-
// if we've already finished setting up, then only surface the error through responses
226-
// otherwise, make the setup task error as well
227-
if !resumed {
228-
setupComplete.resume(throwing: error)
229-
}
230-
throw error
231-
}
217+
/// Spawn tasks for interacting with the model.
218+
///
219+
/// The following tasks will be spawned:
220+
///
221+
/// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
222+
/// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
223+
private func spawnMessageTasks(stream: MappedStream<URLSessionWebSocketTask.Message, Data>) {
224+
guard let webSocket else { return }
225+
226+
responsesTask = Task {
227+
do {
228+
for try await message in stream {
229+
let response = try decodeServerMessage(message)
232230

233231
if case .setupComplete = response.messageType {
234-
if resumed {
235-
AILog.debug(
236-
code: .duplicateLiveSessionSetupComplete,
237-
"Setup complete was received multiple times; this may be a bug in the model."
238-
)
239-
} else {
240-
// calling resume multiple times is an error in swift, so we catch multiple calls
241-
// to avoid causing any issues due to model quirks
242-
resumed = true
243-
setupComplete.resume()
244-
}
232+
AILog.debug(
233+
code: .duplicateLiveSessionSetupComplete,
234+
"Setup complete was received multiple times; this may be a bug in the model."
235+
)
245236
} else if let liveMessage = LiveServerMessage(from: response) {
246237
if case let .goingAwayNotice(message) = liveMessage.payload {
247238
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
@@ -255,21 +246,7 @@ actor LiveSessionService {
255246
}
256247
}
257248
} catch {
258-
if let error = error as? WebSocketClosedError {
259-
// only raise an error if the session didn't close normally (ie; the user calling close)
260-
if error.closeCode != .goingAway {
261-
let closureError: Error
262-
if let error = error.underlyingError as? NSError, error.domain == NSURLErrorDomain,
263-
error.code == NSURLErrorNetworkConnectionLost {
264-
closureError = LiveSessionLostConnectionError(underlyingError: error)
265-
} else {
266-
closureError = LiveSessionUnexpectedClosureError(underlyingError: error)
267-
}
268-
close()
269-
responseContinuation.finish(throwing: closureError)
270-
}
271-
} else {
272-
// an error occurred outside the websocket, so it's likely not closed
249+
if let error = mapWebsocketError(error) {
273250
close()
274251
responseContinuation.finish(throwing: error)
275252
}
@@ -278,22 +255,7 @@ actor LiveSessionService {
278255

279256
messageQueueTask = Task {
280257
for await message in messageQueue {
281-
// we don't propogate errors, since those are surfaced in the responses stream
282-
guard let _ = try? await setupTask.value else {
283-
break
284-
}
285-
286-
let data: Data
287-
do {
288-
data = try jsonEncoder.encode(message)
289-
} catch {
290-
AILog.error(code: .liveSessionFailedToEncodeClientMessage, error.localizedDescription)
291-
AILog.debug(
292-
code: .liveSessionFailedToEncodeClientMessagePayload,
293-
String(describing: message)
294-
)
295-
continue
296-
}
258+
guard let data = encodeClientMessage(message) else { continue }
297259

298260
do {
299261
try await webSocket.send(.data(data))
@@ -304,6 +266,75 @@ actor LiveSessionService {
304266
}
305267
}
306268

269+
/// Checks if an error should be propogated up, and maps it accordingly.
270+
///
271+
/// Some errors have public api alternatives. This function will ensure they're mapped
272+
/// accordingly.
273+
private func mapWebsocketError(_ error: Error) -> Error? {
274+
if let error = error as? WebSocketClosedError {
275+
// only raise an error if the session didn't close normally (ie; the user calling close)
276+
if error.closeCode == .goingAway {
277+
return nil
278+
}
279+
280+
let closureError: Error
281+
282+
if let error = error.underlyingError as? NSError, error.domain == NSURLErrorDomain,
283+
error.code == NSURLErrorNetworkConnectionLost {
284+
closureError = LiveSessionLostConnectionError(underlyingError: error)
285+
} else {
286+
closureError = LiveSessionUnexpectedClosureError(underlyingError: error)
287+
}
288+
289+
return closureError
290+
}
291+
292+
return error
293+
}
294+
295+
/// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`.
296+
///
297+
/// Will throw an error if decoding fails.
298+
private func decodeServerMessage(_ message: Data) throws -> BidiGenerateContentServerMessage {
299+
do {
300+
return try jsonDecoder.decode(
301+
BidiGenerateContentServerMessage.self,
302+
from: message
303+
)
304+
} catch {
305+
// only log the json if it wasn't a decoding error, but an unsupported message type
306+
if error is InvalidMessageTypeError {
307+
AILog.error(
308+
code: .liveSessionUnsupportedMessage,
309+
"The server sent a message that we don't currently have a mapping for."
310+
)
311+
AILog.debug(
312+
code: .liveSessionUnsupportedMessagePayload,
313+
message.encodeToJsonString() ?? "\(message)"
314+
)
315+
}
316+
317+
throw LiveSessionUnsupportedMessageError(underlyingError: error)
318+
}
319+
}
320+
321+
/// Encodes a message from the client into `Data` that can be sent through a websocket data frame.
322+
///
323+
/// Will return `nil` if decoding fails, and log an error describing why.
324+
private func encodeClientMessage(_ message: BidiGenerateContentClientMessage) -> Data? {
325+
do {
326+
return try jsonEncoder.encode(message)
327+
} catch {
328+
AILog.error(code: .liveSessionFailedToEncodeClientMessage, error.localizedDescription)
329+
AILog.debug(
330+
code: .liveSessionFailedToEncodeClientMessagePayload,
331+
String(describing: message)
332+
)
333+
}
334+
335+
return nil
336+
}
337+
307338
/// Creates a websocket pointing to the backend.
308339
///
309340
/// Will apply the required app check and auth headers, as the backend expects them.
@@ -392,3 +423,8 @@ private extension String {
392423
}
393424
}
394425
}
426+
427+
/// Helper alias for a compact mapped throwing stream.
428+
///
429+
/// We use this to make signatures easier to read, since we can't support `AsyncSequence` quite yet.
430+
private typealias MappedStream<T, V> = AsyncCompactMapSequence<AsyncThrowingStream<T, any Error>, V>

0 commit comments

Comments
 (0)