Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 159 additions & 123 deletions FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ actor LiveSessionService {
private let jsonEncoder = JSONEncoder()
private let jsonDecoder = JSONDecoder()

/// Task that doesn't complete until the server sends a setupComplete message.
///
/// Used to hold off on sending messages until the server is ready.
private var setupTask: Task<Void, Error>

/// Long running task that that wraps around the websocket, propogating messages through the
/// public stream.
private var responsesTask: Task<Void, Never>?
Expand Down Expand Up @@ -87,11 +82,9 @@ actor LiveSessionService {
self.toolConfig = toolConfig
self.systemInstruction = systemInstruction
self.requestOptions = requestOptions
setupTask = Task {}
}

deinit {
setupTask.cancel()
responsesTask?.cancel()
messageQueueTask?.cancel()
webSocket?.disconnect()
Expand All @@ -114,29 +107,20 @@ actor LiveSessionService {
///
/// Seperated into its own function to make it easier to surface a way to call it seperately when
/// resuming the same session.
///
/// This function will yield until the websocket is ready to communicate with the client.
func connect() async throws {
close()
// we launch the setup task in a seperate task to allow us to cancel it via close
setupTask = Task { [weak self] in
// we need a continuation to surface that the setup is complete, while still allowing us to
// listen to the server
try await withCheckedThrowingContinuation { setupContinuation in
// nested task so we can use await
Task { [weak self] in
guard let self else { return }
await self.listenToServer(setupContinuation)
}
}
}

try await setupTask.value
let stream = try await setupWebsocket()
try await waitForSetupComplete(stream: stream)
spawnMessageTasks(stream: stream)
}

/// Cancel any running tasks and close the websocket.
///
/// This method is idempotent; if it's already ran once, it will effectively be a no-op.
func close() {
setupTask.cancel()
responsesTask?.cancel()
messageQueueTask?.cancel()
webSocket?.disconnect()
Expand All @@ -146,38 +130,19 @@ actor LiveSessionService {
messageQueueTask = nil
}

/// Start a fresh websocket to the backend, and listen for responses.
/// Performs the initial setup procedure for the model.
///
/// Will hold off on sending any messages until the server sends a setupComplete message.
/// The setup procedure with the model follows the procedure:
///
/// Will also close out the old websocket and the previous long running tasks.
private func listenToServer(_ setupComplete: CheckedContinuation<Void, any Error>) async {
do {
webSocket = try await createWebsocket()
} catch {
let error = LiveSessionSetupError(underlyingError: error)
close()
setupComplete.resume(throwing: error)
return
}

/// - Client sends `BidiGenerateContentSetup`
/// - Server sends back `BidiGenerateContentSetupComplete` when it's ready
///
/// This function will yield until the setup is complete.
private func waitForSetupComplete(stream: MappedStream<
URLSessionWebSocketTask.Message,
Data
>) async throws {
guard let webSocket else { return }
let stream = webSocket.connect()

var resumed = false

// remove the uncommon (and unexpected) responses from the stream, to make normal path cleaner
let dataStream = stream.compactMap { (message: URLSessionWebSocketTask.Message) -> Data? in
switch message {
case let .string(string):
AILog.error(code: .liveSessionUnexpectedResponse, "Unexpected string response: \(string)")
case let .data(data):
return data
@unknown default:
AILog.error(code: .liveSessionUnexpectedResponse, "Unknown message received: \(message)")
}
return nil
}

do {
let setup = BidiGenerateContentSetup(
Expand All @@ -194,54 +159,80 @@ actor LiveSessionService {
} catch {
let error = LiveSessionSetupError(underlyingError: error)
close()
setupComplete.resume(throwing: error)
return
throw error
}

responsesTask = Task {
do {
for try await message in dataStream {
let response: BidiGenerateContentServerMessage
do {
response = try jsonDecoder.decode(
BidiGenerateContentServerMessage.self,
from: message
)
} catch {
// only log the json if it wasn't a decoding error, but an unsupported message type
if error is InvalidMessageTypeError {
AILog.error(
code: .liveSessionUnsupportedMessage,
"The server sent a message that we don't currently have a mapping for."
)
do {
for try await message in stream {
let response = try decodeServerMessage(message)
if case .setupComplete = response.messageType {
break
} else {
AILog.error(
code: .liveSessionUnexpectedResponse,
"The model sent us a message that wasn't a setup complete: \(response)"
)
}
}
} catch {
if let error = mapWebsocketError(error) {
close()
throw error
}
}
}

AILog.debug(
code: .liveSessionUnsupportedMessagePayload,
message.encodeToJsonString() ?? "\(message)"
)
}
/// Performs the initial setup procedure for a websocket.
///
/// This includes creating the websocket url and connecting it.
///
/// - Returns: A stream of `Data` frames from the websocket.
private func setupWebsocket() async throws
-> MappedStream<URLSessionWebSocketTask.Message, Data> {
do {
let webSocket = try await createWebsocket()
self.webSocket = webSocket

let stream = webSocket.connect()

// remove the uncommon (and unexpected) frames from the stream, to make normal path cleaner
return stream.compactMap { message in
switch message {
case let .string(string):
AILog.error(code: .liveSessionUnexpectedResponse, "Unexpected string response: \(string)")
case let .data(data):
return data
@unknown default:
AILog.error(code: .liveSessionUnexpectedResponse, "Unknown message received: \(message)")
}
return nil
}
} catch {
let error = LiveSessionSetupError(underlyingError: error)
close()
throw error
}
}

let error = LiveSessionUnsupportedMessageError(underlyingError: error)
// if we've already finished setting up, then only surface the error through responses
// otherwise, make the setup task error as well
if !resumed {
setupComplete.resume(throwing: error)
}
throw error
}
/// Spawn tasks for interacting with the model.
///
/// The following tasks will be spawned:
///
/// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
/// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
private func spawnMessageTasks(stream: MappedStream<URLSessionWebSocketTask.Message, Data>) {
guard let webSocket else { return }

responsesTask = Task {
do {
for try await message in stream {
let response = try decodeServerMessage(message)

if case .setupComplete = response.messageType {
if resumed {
AILog.debug(
code: .duplicateLiveSessionSetupComplete,
"Setup complete was received multiple times; this may be a bug in the model."
)
} else {
// calling resume multiple times is an error in swift, so we catch multiple calls
// to avoid causing any issues due to model quirks
resumed = true
setupComplete.resume()
}
AILog.debug(
code: .duplicateLiveSessionSetupComplete,
"Setup complete was received multiple times; this may be a bug in the model."
)
} else if let liveMessage = LiveServerMessage(from: response) {
if case let .goingAwayNotice(message) = liveMessage.payload {
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
Expand All @@ -255,21 +246,7 @@ actor LiveSessionService {
}
}
} catch {
if let error = error as? WebSocketClosedError {
// only raise an error if the session didn't close normally (ie; the user calling close)
if error.closeCode != .goingAway {
let closureError: Error
if let error = error.underlyingError as? NSError, error.domain == NSURLErrorDomain,
error.code == NSURLErrorNetworkConnectionLost {
closureError = LiveSessionLostConnectionError(underlyingError: error)
} else {
closureError = LiveSessionUnexpectedClosureError(underlyingError: error)
}
close()
responseContinuation.finish(throwing: closureError)
}
} else {
// an error occurred outside the websocket, so it's likely not closed
if let error = mapWebsocketError(error) {
close()
responseContinuation.finish(throwing: error)
}
Expand All @@ -278,22 +255,7 @@ actor LiveSessionService {

messageQueueTask = Task {
for await message in messageQueue {
// we don't propogate errors, since those are surfaced in the responses stream
guard let _ = try? await setupTask.value else {
break
}

let data: Data
do {
data = try jsonEncoder.encode(message)
} catch {
AILog.error(code: .liveSessionFailedToEncodeClientMessage, error.localizedDescription)
AILog.debug(
code: .liveSessionFailedToEncodeClientMessagePayload,
String(describing: message)
)
continue
}
guard let data = encodeClientMessage(message) else { continue }

do {
try await webSocket.send(.data(data))
Expand All @@ -304,6 +266,75 @@ actor LiveSessionService {
}
}

/// Checks if an error should be propogated up, and maps it accordingly.
///
/// Some errors have public api alternatives. This function will ensure they're mapped
/// accordingly.
private func mapWebsocketError(_ error: Error) -> Error? {
if let error = error as? WebSocketClosedError {
// only raise an error if the session didn't close normally (ie; the user calling close)
if error.closeCode == .goingAway {
return nil
}

let closureError: Error

if let error = error.underlyingError as? NSError, error.domain == NSURLErrorDomain,
error.code == NSURLErrorNetworkConnectionLost {
closureError = LiveSessionLostConnectionError(underlyingError: error)
} else {
closureError = LiveSessionUnexpectedClosureError(underlyingError: error)
}

return closureError
}

return error
}

/// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`.
///
/// Will throw an error if decoding fails.
private func decodeServerMessage(_ message: Data) throws -> BidiGenerateContentServerMessage {
do {
return try jsonDecoder.decode(
BidiGenerateContentServerMessage.self,
from: message
)
} catch {
// only log the json if it wasn't a decoding error, but an unsupported message type
if error is InvalidMessageTypeError {
AILog.error(
code: .liveSessionUnsupportedMessage,
"The server sent a message that we don't currently have a mapping for."
)
AILog.debug(
code: .liveSessionUnsupportedMessagePayload,
message.encodeToJsonString() ?? "\(message)"
)
}

throw LiveSessionUnsupportedMessageError(underlyingError: error)
}
}

/// Encodes a message from the client into `Data` that can be sent through a websocket data frame.
///
/// Will return `nil` if decoding fails, and log an error describing why.
private func encodeClientMessage(_ message: BidiGenerateContentClientMessage) -> Data? {
do {
return try jsonEncoder.encode(message)
} catch {
AILog.error(code: .liveSessionFailedToEncodeClientMessage, error.localizedDescription)
AILog.debug(
code: .liveSessionFailedToEncodeClientMessagePayload,
String(describing: message)
)
}

return nil
}

/// Creates a websocket pointing to the backend.
///
/// Will apply the required app check and auth headers, as the backend expects them.
Expand Down Expand Up @@ -392,3 +423,8 @@ private extension String {
}
}
}

/// Helper alias for a compact mapped throwing stream.
///
/// We use this to make signatures easier to read, since we can't support `AsyncSequence` quite yet.
private typealias MappedStream<T, V> = AsyncCompactMapSequence<AsyncThrowingStream<T, any Error>, V>
Loading