Skip to content

Commit 60003c1

Browse files
andrewhearddaymxn
authored andcommitted
Add temporary state machine in LiveSession
1 parent aa3d148 commit 60003c1

File tree

2 files changed

+223
-4
lines changed

2 files changed

+223
-4
lines changed

FirebaseAI/Sources/Types/Public/Live/LiveGenerativeModel.swift

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,52 @@ import Foundation
1717
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
1818
public final class LiveGenerativeModel {
1919
let modelResourceName: String
20+
let firebaseInfo: FirebaseInfo
2021
let apiConfig: APIConfig
22+
let generationConfig: LiveGenerationConfig?
2123
let requestOptions: RequestOptions
24+
let urlSession: URLSession
2225

2326
init(modelResourceName: String,
2427
firebaseInfo: FirebaseInfo,
2528
apiConfig: APIConfig,
29+
generationConfig: LiveGenerationConfig? = nil,
2630
requestOptions: RequestOptions,
2731
urlSession: URLSession = GenAIURLSession.default) {
2832
self.modelResourceName = modelResourceName
33+
self.firebaseInfo = firebaseInfo
2934
self.apiConfig = apiConfig
30-
// TODO: Add LiveGenerationConfig
35+
self.generationConfig = generationConfig
3136
// TODO: Add tools
3237
// TODO: Add tool config
3338
// TODO: Add system instruction
3439
self.requestOptions = requestOptions
40+
self.urlSession = urlSession
3541
}
3642

3743
public func connect() async throws -> LiveSession {
38-
// TODO: Implement connection
39-
return LiveSession()
44+
let liveSession = LiveSession(
45+
modelResourceName: modelResourceName,
46+
generationConfig: generationConfig,
47+
url: webSocketURL(),
48+
urlSession: urlSession
49+
)
50+
print("Opening Live Session...")
51+
try await liveSession.open()
52+
return liveSession
53+
}
54+
55+
func webSocketURL() -> URL {
56+
let urlString = switch apiConfig.service {
57+
case .vertexAI:
58+
"wss://firebasevertexai.googleapis.com/ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations/us-central1?key=\(firebaseInfo.apiKey)"
59+
case .googleAI:
60+
"wss://firebasevertexai.googleapis.com/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent?key=\(firebaseInfo.apiKey)"
61+
}
62+
guard let url = URL(string: urlString) else {
63+
// TODO: Add error handling
64+
fatalError()
65+
}
66+
return url
4067
}
4168
}

FirebaseAI/Sources/Types/Public/Live/LiveSession.swift

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,196 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
public final class LiveSession {}
15+
import Foundation
16+
17+
// TODO: Extract most of this file into a service class similar to `GenerativeAIService`.
18+
public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate {
19+
private enum State {
20+
case notConnected
21+
case connecting
22+
case setupSent
23+
case ready
24+
case closed
25+
}
26+
27+
private enum WebSocketError: Error {
28+
case connectionClosed
29+
}
30+
31+
let modelResourceName: String
32+
let generationConfig: LiveGenerationConfig?
33+
let webSocket: URLSessionWebSocketTask
34+
35+
private var state: State = .notConnected
36+
private var pendingMessages: [(String, CheckedContinuation<Void, Error>)] = []
37+
private let jsonEncoder = JSONEncoder()
38+
private let jsonDecoder = JSONDecoder()
39+
40+
init(modelResourceName: String,
41+
generationConfig: LiveGenerationConfig?,
42+
url: URL,
43+
urlSession: URLSession) {
44+
self.modelResourceName = modelResourceName
45+
self.generationConfig = generationConfig
46+
webSocket = urlSession.webSocketTask(with: url)
47+
}
48+
49+
func open() async throws {
50+
guard state == .notConnected else {
51+
print("Web socket is not in a valid state to be opened: \(state)")
52+
return
53+
}
54+
55+
state = .connecting
56+
webSocket.delegate = self
57+
webSocket.resume()
58+
59+
print("Opening websocket")
60+
}
61+
62+
private func failPendingMessages(with error: Error) {
63+
for (_, continuation) in pendingMessages {
64+
continuation.resume(throwing: error)
65+
}
66+
pendingMessages.removeAll()
67+
}
68+
69+
private func processPendingMessages() {
70+
for (message, continuation) in pendingMessages {
71+
Task {
72+
do {
73+
try await send(message)
74+
continuation.resume()
75+
} catch {
76+
continuation.resume(throwing: error)
77+
}
78+
}
79+
}
80+
pendingMessages.removeAll()
81+
}
82+
83+
private func send(_ message: String) async throws {
84+
let content = ModelContent(role: "user", parts: [message])
85+
let clientContent = BidiGenerateContentClientContent(turns: [content], turnComplete: true)
86+
let clientMessage = BidiGenerateContentClientMessage.clientContent(clientContent)
87+
let clientMessageData = try jsonEncoder.encode(clientMessage)
88+
let clientMessageJSON = String(data: clientMessageData, encoding: .utf8)
89+
print("Client Message JSON: \(clientMessageJSON)")
90+
try await webSocket.send(.data(clientMessageData))
91+
setReceiveHandler()
92+
}
93+
94+
public func sendMessage(_ message: String) async throws {
95+
if state == .ready {
96+
try await send(message)
97+
} else {
98+
try await withCheckedThrowingContinuation { continuation in
99+
pendingMessages.append((message, continuation))
100+
}
101+
}
102+
}
103+
104+
public func urlSession(_ session: URLSession,
105+
webSocketTask: URLSessionWebSocketTask,
106+
didOpenWithProtocol protocol: String?) {
107+
print("Web Socket opened.")
108+
109+
guard state == .connecting else {
110+
print("Web socket is not in a valid state to be opened: \(state)")
111+
return
112+
}
113+
114+
do {
115+
let setup = BidiGenerateContentSetup(
116+
model: modelResourceName, generationConfig: generationConfig
117+
)
118+
let message = BidiGenerateContentClientMessage.setup(setup)
119+
let messageData = try jsonEncoder.encode(message)
120+
let messageJSON = String(data: messageData, encoding: .utf8)
121+
print("JSON: \(messageJSON)")
122+
webSocketTask.send(.data(messageData)) { error in
123+
if let error {
124+
print("Send Error: \(error)")
125+
self.state = .closed
126+
self.failPendingMessages(with: error)
127+
return
128+
}
129+
130+
self.state = .setupSent
131+
self.setReceiveHandler()
132+
}
133+
} catch {
134+
print(error)
135+
state = .closed
136+
failPendingMessages(with: error)
137+
}
138+
}
139+
140+
public func urlSession(_ session: URLSession,
141+
webSocketTask: URLSessionWebSocketTask,
142+
didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
143+
reason: Data?) {
144+
print("Web Socket closed.")
145+
state = .closed
146+
failPendingMessages(with: WebSocketError.connectionClosed)
147+
}
148+
149+
func setReceiveHandler() {
150+
guard state == .setupSent || state == .ready else {
151+
print("Web socket is not in a valid state to receive messages: \(state)")
152+
return
153+
}
154+
155+
webSocket.receive { result in
156+
do {
157+
let message = try result.get()
158+
switch message {
159+
case let .string(string):
160+
print("Unexpected string response: \(string)")
161+
self.setReceiveHandler()
162+
case let .data(data):
163+
let response = try self.jsonDecoder.decode(
164+
BidiGenerateContentServerMessage.self,
165+
from: data
166+
)
167+
let responseJSON = String(data: data, encoding: .utf8)
168+
169+
switch response.messageType {
170+
case .setupComplete:
171+
print("Setup Complete: \(responseJSON)")
172+
self.state = .ready
173+
self.processPendingMessages()
174+
case .serverContent:
175+
// TODO: Return the serverContent to the developer
176+
print("Server Content: \(responseJSON)")
177+
case .toolCall:
178+
// TODO: Tool calls not yet implemented
179+
print("Tool Call: \(responseJSON)")
180+
case .toolCallCancellation:
181+
// TODO: Tool call cancellation not yet implemented
182+
print("Tool Call Cancellation: \(responseJSON)")
183+
case let .goAway(goAway):
184+
if let timeLeft = goAway.timeLeft {
185+
print("Server will disconnect in \(timeLeft) seconds.")
186+
} else {
187+
print("Server will disconnect soon.")
188+
}
189+
}
190+
191+
if self.state == .closed {
192+
print("Web socket is closed, not listening for more messages.")
193+
} else {
194+
self.setReceiveHandler()
195+
}
196+
@unknown default:
197+
print("Unknown message received")
198+
self.setReceiveHandler()
199+
}
200+
} catch {
201+
// handle the error
202+
print(error)
203+
self.state = .closed
204+
}
205+
}
206+
}
207+
}

0 commit comments

Comments
 (0)