Skip to content

Commit dd89e20

Browse files
andrewhearddaymxn
authored andcommitted
Refactor to use async/await and remove URLSessionWebSocketDelegate
1 parent 91dd102 commit dd89e20

File tree

4 files changed

+35
-157
lines changed

4 files changed

+35
-157
lines changed

FirebaseAI/Sources/Types/Internal/Live/BidiGenerateContentServerMessage.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import Foundation
1616

1717
/// Response message for BidiGenerateContent RPC call.
18-
public struct BidiGenerateContentServerMessage {
18+
public struct BidiGenerateContentServerMessage: Sendable {
1919
// TODO: Make this type `internal`
2020

2121
/// The type of the message.

FirebaseAI/Sources/Types/Public/Live/LiveGenerativeModel.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ public final class LiveGenerativeModel {
4040
self.urlSession = urlSession
4141
}
4242

43-
public func connect() async throws -> LiveSession {
43+
public func connect() -> LiveSession {
4444
let liveSession = LiveSession(
4545
modelResourceName: modelResourceName,
4646
generationConfig: generationConfig,
4747
url: webSocketURL(),
4848
urlSession: urlSession
4949
)
5050
print("Opening Live Session...")
51-
try await liveSession.open()
51+
liveSession.openConnection()
5252
return liveSession
5353
}
5454

FirebaseAI/Sources/Types/Public/Live/LiveSession.swift

Lines changed: 31 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,18 @@
1515
import Foundation
1616

1717
// TODO: Extract most of this file into a service class similar to `GenerativeAIService`.
18-
public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate {
19-
private enum State {
20-
case notConnected
21-
case connecting
22-
case setupSent
23-
case ready
24-
case closed
25-
}
26-
27-
private enum WebSocketError: Error {
28-
case connectionClosed
29-
}
30-
18+
public final class LiveSession: Sendable {
3119
let modelResourceName: String
3220
let generationConfig: LiveGenerationConfig?
3321
let webSocket: URLSessionWebSocketTask
3422

35-
// TODO: Refactor this property, potentially returning responses after `connect`.
3623
public let responses: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
24+
private let responseContinuation: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
25+
.Continuation
3726

38-
private var state: State = .notConnected
39-
private var pendingMessages: [(String, CheckedContinuation<Void, Error>)] = []
4027
private let jsonEncoder = JSONEncoder()
4128
private let jsonDecoder = JSONDecoder()
4229

43-
// TODO: Properly wrap callback code using `withCheckedContinuation` or similar.
44-
private let responseContinuation: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
45-
.Continuation
46-
4730
init(modelResourceName: String,
4831
generationConfig: LiveGenerationConfig?,
4932
url: URL,
@@ -54,166 +37,61 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
5437
(responses, responseContinuation) = AsyncThrowingStream.makeStream()
5538
}
5639

57-
func open() async throws {
58-
guard state == .notConnected else {
59-
print("Web socket is not in a valid state to be opened: \(state)")
60-
return
61-
}
62-
63-
state = .connecting
64-
webSocket.delegate = self
65-
webSocket.resume()
66-
67-
print("Opening websocket")
68-
}
69-
70-
private func failPendingMessages(with error: Error) {
71-
for (_, continuation) in pendingMessages {
72-
continuation.resume(throwing: error)
73-
}
74-
pendingMessages.removeAll()
75-
responseContinuation.finish(throwing: error)
40+
deinit {
41+
webSocket.cancel(with: .goingAway, reason: nil)
7642
}
7743

78-
private func processPendingMessages() {
79-
for (message, continuation) in pendingMessages {
80-
Task {
81-
do {
82-
try await send(message)
83-
continuation.resume()
84-
} catch {
85-
continuation.resume(throwing: error)
86-
}
87-
}
88-
}
89-
pendingMessages.removeAll()
90-
}
91-
92-
private func send(_ message: String) async throws {
44+
public func sendMessage(_ message: String) async throws {
9345
let content = ModelContent(role: "user", parts: [message])
9446
let clientContent = BidiGenerateContentClientContent(turns: [content], turnComplete: true)
9547
let clientMessage = BidiGenerateContentClientMessage.clientContent(clientContent)
9648
let clientMessageData = try jsonEncoder.encode(clientMessage)
97-
let clientMessageJSON = String(data: clientMessageData, encoding: .utf8)
98-
print("Client Message JSON: \(clientMessageJSON)")
9949
try await webSocket.send(.data(clientMessageData))
100-
setReceiveHandler()
10150
}
10251

103-
public func sendMessage(_ message: String) async throws {
104-
if state == .ready {
105-
try await send(message)
106-
} else {
107-
try await withCheckedThrowingContinuation { continuation in
108-
pendingMessages.append((message, continuation))
109-
}
52+
func openConnection() {
53+
webSocket.resume()
54+
// TODO: Verify that this task gets cancelled on deinit
55+
Task {
56+
await startEventLoop()
11057
}
11158
}
11259

113-
public func urlSession(_ session: URLSession,
114-
webSocketTask: URLSessionWebSocketTask,
115-
didOpenWithProtocol protocol: String?) {
116-
print("Web Socket opened.")
117-
118-
guard state == .connecting else {
119-
print("Web socket is not in a valid state to be opened: \(state)")
120-
return
60+
private func startEventLoop() async {
61+
defer {
62+
webSocket.cancel(with: .goingAway, reason: nil)
12163
}
12264

12365
do {
124-
let setup = BidiGenerateContentSetup(
125-
model: modelResourceName, generationConfig: generationConfig
126-
)
127-
let message = BidiGenerateContentClientMessage.setup(setup)
128-
let messageData = try jsonEncoder.encode(message)
129-
let messageJSON = String(data: messageData, encoding: .utf8)
130-
print("JSON: \(messageJSON)")
131-
webSocketTask.send(.data(messageData)) { error in
132-
if let error {
133-
print("Send Error: \(error)")
134-
self.state = .closed
135-
self.failPendingMessages(with: error)
136-
return
137-
}
138-
139-
self.state = .setupSent
140-
self.setReceiveHandler()
141-
}
142-
} catch {
143-
print(error)
144-
state = .closed
145-
failPendingMessages(with: error)
146-
}
147-
}
148-
149-
public func urlSession(_ session: URLSession,
150-
webSocketTask: URLSessionWebSocketTask,
151-
didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
152-
reason: Data?) {
153-
print("Web Socket closed.")
154-
state = .closed
155-
failPendingMessages(with: WebSocketError.connectionClosed)
156-
responseContinuation.finish()
157-
}
158-
159-
func setReceiveHandler() {
160-
guard state == .setupSent || state == .ready else {
161-
print("Web socket is not in a valid state to receive messages: \(state)")
162-
return
163-
}
66+
try await sendSetupMessage()
16467

165-
webSocket.receive { result in
166-
do {
167-
let message = try result.get()
68+
while !Task.isCancelled {
69+
let message = try await webSocket.receive()
16870
switch message {
16971
case let .string(string):
17072
print("Unexpected string response: \(string)")
171-
self.setReceiveHandler()
17273
case let .data(data):
173-
let response = try self.jsonDecoder.decode(
74+
let response = try jsonDecoder.decode(
17475
BidiGenerateContentServerMessage.self,
17576
from: data
17677
)
177-
let responseJSON = String(data: data, encoding: .utf8)
178-
179-
switch response.messageType {
180-
case .setupComplete:
181-
print("Setup Complete: \(responseJSON)")
182-
self.state = .ready
183-
self.processPendingMessages()
184-
case .serverContent:
185-
print("Server Content: \(responseJSON)")
186-
case .toolCall:
187-
// TODO: Tool calls not yet implemented
188-
print("Tool Call: \(responseJSON)")
189-
case .toolCallCancellation:
190-
// TODO: Tool call cancellation not yet implemented
191-
print("Tool Call Cancellation: \(responseJSON)")
192-
case let .goAway(goAway):
193-
if let timeLeft = goAway.timeLeft {
194-
print("Server will disconnect in \(timeLeft) seconds.")
195-
} else {
196-
print("Server will disconnect soon.")
197-
}
198-
}
199-
200-
self.responseContinuation.yield(response)
201-
202-
if self.state == .closed {
203-
print("Web socket is closed, not listening for more messages.")
204-
} else {
205-
self.setReceiveHandler()
206-
}
78+
responseContinuation.yield(response)
20779
@unknown default:
20880
print("Unknown message received")
209-
self.setReceiveHandler()
21081
}
211-
} catch {
212-
// handle the error
213-
print(error)
214-
self.state = .closed
215-
self.responseContinuation.finish(throwing: error)
21682
}
83+
} catch {
84+
responseContinuation.finish(throwing: error)
21785
}
86+
responseContinuation.finish()
87+
}
88+
89+
private func sendSetupMessage() async throws {
90+
let setup = BidiGenerateContentSetup(
91+
model: modelResourceName, generationConfig: generationConfig
92+
)
93+
let message = BidiGenerateContentClientMessage.setup(setup)
94+
let messageData = try jsonEncoder.encode(message)
95+
try await webSocket.send(.data(messageData))
21896
}
21997
}

FirebaseAI/Tests/TestApp/Sources/ContentView.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct ContentView: View {
3737
.padding()
3838
.task {
3939
do {
40-
let liveSession = try await liveModel.connect()
40+
let liveSession = liveModel.connect()
4141
try await liveSession.sendMessage("Why is the sky blue?")
4242
for try await response in liveSession.responses {
4343
responses.append(String(describing: response))

0 commit comments

Comments
 (0)