Skip to content

Commit e783431

Browse files
committed
Emit responses from LiveSession
1 parent 05fc2ff commit e783431

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

FirebaseAI/Sources/Types/Internal/Live/BidiGenerateContentServerMessage.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
import Foundation
1616

1717
/// Response message for BidiGenerateContent RPC call.
18-
struct BidiGenerateContentServerMessage {
18+
public struct BidiGenerateContentServerMessage {
19+
// TODO: Make this type `internal`
20+
1921
/// The type of the message.
2022
enum MessageType {
2123
/// Sent in response to a `BidiGenerateContentSetup` message from the client.
@@ -56,7 +58,7 @@ extension BidiGenerateContentServerMessage: Decodable {
5658
case usageMetadata
5759
}
5860

59-
init(from decoder: any Decoder) throws {
61+
public init(from decoder: any Decoder) throws {
6062
let container = try decoder.container(keyedBy: CodingKeys.self)
6163

6264
if let setupComplete = try container.decodeIfPresent(

FirebaseAI/Sources/Types/Public/Live/LiveSession.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,26 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
3232
let generationConfig: LiveGenerationConfig?
3333
let webSocket: URLSessionWebSocketTask
3434

35+
// TODO: Refactor this property, potentially returning responses after `connect`.
36+
public let responses: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
37+
3538
private var state: State = .notConnected
3639
private var pendingMessages: [(String, CheckedContinuation<Void, Error>)] = []
3740
private let jsonEncoder = JSONEncoder()
3841
private let jsonDecoder = JSONDecoder()
3942

43+
// TODO: Properly wrap callback code using `withCheckedContinuation` or similar.
44+
private let responseContinuation: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
45+
.Continuation
46+
4047
init(modelResourceName: String,
4148
generationConfig: LiveGenerationConfig?,
4249
url: URL,
4350
urlSession: URLSession) {
4451
self.modelResourceName = modelResourceName
4552
self.generationConfig = generationConfig
4653
webSocket = urlSession.webSocketTask(with: url)
54+
(responses, responseContinuation) = AsyncThrowingStream.makeStream()
4755
}
4856

4957
func open() async throws {
@@ -64,6 +72,7 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
6472
continuation.resume(throwing: error)
6573
}
6674
pendingMessages.removeAll()
75+
responseContinuation.finish(throwing: error)
6776
}
6877

6978
private func processPendingMessages() {
@@ -144,6 +153,7 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
144153
print("Web Socket closed.")
145154
state = .closed
146155
failPendingMessages(with: WebSocketError.connectionClosed)
156+
responseContinuation.finish()
147157
}
148158

149159
func setReceiveHandler() {
@@ -172,7 +182,6 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
172182
self.state = .ready
173183
self.processPendingMessages()
174184
case .serverContent:
175-
// TODO: Return the serverContent to the developer
176185
print("Server Content: \(responseJSON)")
177186
case .toolCall:
178187
// TODO: Tool calls not yet implemented
@@ -188,6 +197,8 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
188197
}
189198
}
190199

200+
self.responseContinuation.yield(response)
201+
191202
if self.state == .closed {
192203
print("Web socket is closed, not listening for more messages.")
193204
} else {
@@ -201,6 +212,7 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
201212
// handle the error
202213
print(error)
203214
self.state = .closed
215+
self.responseContinuation.finish(throwing: error)
204216
}
205217
}
206218
}

0 commit comments

Comments
 (0)