Skip to content

Commit 20ec9a3

Browse files
Authored commit: [Vertex AI] Make GenerativeModel and Chat into Swift actors (#13545)
1 parent 047856b commit 20ec9a3

File tree

13 files changed

+119
-85
lines changed

13 files changed

+119
-85
lines changed

FirebaseVertexAI/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# 11.2.0
22
- [fixed] Resolved a decoding error for citations without a `uri` and added
33
support for decoding `title` fields, which were previously ignored. (#13518)
4+
- [changed] **Breaking Change**: The methods for starting streaming requests
5+
(`generateContentStream` and `sendMessageStream`) and creating a chat instance
6+
(`startChat`) are now asynchronous and must be called with `await`. (#13545)
47

58
# 10.29.0
69
- [feature] Added community support for watchOS. (#13215)

FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ struct ConversationScreen: View {
104104
}
105105

106106
private func newChat() {
107-
viewModel.startNewChat()
107+
Task {
108+
await viewModel.startNewChat()
109+
}
108110
}
109111
}
110112

FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,44 @@ class ConversationViewModel: ObservableObject {
2121
/// This array holds both the user's and the system's chat messages
2222
@Published var messages = [ChatMessage]()
2323

24-
/// Indicates we're waiting for the model to finish
25-
@Published var busy = false
24+
/// Indicates we're waiting for the model to finish or the UI is loading
25+
@Published var busy = true
2626

2727
@Published var error: Error?
2828
var hasError: Bool {
2929
return error != nil
3030
}
3131

3232
private var model: GenerativeModel
33-
private var chat: Chat
33+
private var chat: Chat? = nil
3434
private var stopGenerating = false
3535

3636
private var chatTask: Task<Void, Never>?
3737

3838
init() {
3939
model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
40-
chat = model.startChat()
40+
Task {
41+
await startNewChat()
42+
}
4143
}
4244

4345
func sendMessage(_ text: String, streaming: Bool = true) async {
44-
error = nil
46+
stop()
4547
if streaming {
4648
await internalSendMessageStreaming(text)
4749
} else {
4850
await internalSendMessage(text)
4951
}
5052
}
5153

52-
func startNewChat() {
54+
func startNewChat() async {
55+
busy = true
56+
defer {
57+
busy = false
58+
}
5359
stop()
54-
error = nil
55-
chat = model.startChat()
5660
messages.removeAll()
61+
chat = await model.startChat()
5762
}
5863

5964
func stop() {
@@ -62,8 +67,6 @@ class ConversationViewModel: ObservableObject {
6267
}
6368

6469
private func internalSendMessageStreaming(_ text: String) async {
65-
chatTask?.cancel()
66-
6770
chatTask = Task {
6871
busy = true
6972
defer {
@@ -79,7 +82,10 @@ class ConversationViewModel: ObservableObject {
7982
messages.append(systemMessage)
8083

8184
do {
82-
let responseStream = chat.sendMessageStream(text)
85+
guard let chat else {
86+
throw ChatError.notInitialized
87+
}
88+
let responseStream = await chat.sendMessageStream(text)
8389
for try await chunk in responseStream {
8490
messages[messages.count - 1].pending = false
8591
if let text = chunk.text {
@@ -95,8 +101,6 @@ class ConversationViewModel: ObservableObject {
95101
}
96102

97103
private func internalSendMessage(_ text: String) async {
98-
chatTask?.cancel()
99-
100104
chatTask = Task {
101105
busy = true
102106
defer {
@@ -112,10 +116,12 @@ class ConversationViewModel: ObservableObject {
112116
messages.append(systemMessage)
113117

114118
do {
115-
var response: GenerateContentResponse?
116-
response = try await chat.sendMessage(text)
119+
guard let chat = chat else {
120+
throw ChatError.notInitialized
121+
}
122+
let response = try await chat.sendMessage(text)
117123

118-
if let responseText = response?.text {
124+
if let responseText = response.text {
119125
// replace pending message with backend response
120126
messages[messages.count - 1].message = responseText
121127
messages[messages.count - 1].pending = false
@@ -127,4 +133,8 @@ class ConversationViewModel: ObservableObject {
127133
}
128134
}
129135
}
136+
137+
enum ChatError: Error {
138+
case notInitialized
139+
}
130140
}

FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ struct FunctionCallingScreen: View {
106106
}
107107

108108
private func newChat() {
109-
viewModel.startNewChat()
109+
Task {
110+
await viewModel.startNewChat()
111+
}
110112
}
111113
}
112114

FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class FunctionCallingViewModel: ObservableObject {
3333
private var functionCalls = [FunctionCall]()
3434

3535
private var model: GenerativeModel
36-
private var chat: Chat
36+
private var chat: Chat? = nil
3737

3838
private var chatTask: Task<Void, Never>?
3939

@@ -62,13 +62,13 @@ class FunctionCallingViewModel: ObservableObject {
6262
),
6363
])]
6464
)
65-
chat = model.startChat()
65+
Task {
66+
await startNewChat()
67+
}
6668
}
6769

6870
func sendMessage(_ text: String, streaming: Bool = true) async {
69-
error = nil
70-
chatTask?.cancel()
71-
71+
stop()
7272
chatTask = Task {
7373
busy = true
7474
defer {
@@ -100,11 +100,14 @@ class FunctionCallingViewModel: ObservableObject {
100100
}
101101
}
102102

103-
func startNewChat() {
103+
func startNewChat() async {
104+
busy = true
105+
defer {
106+
busy = false
107+
}
104108
stop()
105-
error = nil
106-
chat = model.startChat()
107109
messages.removeAll()
110+
chat = await model.startChat()
108111
}
109112

110113
func stop() {
@@ -114,14 +117,17 @@ class FunctionCallingViewModel: ObservableObject {
114117

115118
private func internalSendMessageStreaming(_ text: String) async throws {
116119
let functionResponses = try await processFunctionCalls()
120+
guard let chat else {
121+
throw ChatError.notInitialized
122+
}
117123
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
118124
if functionResponses.isEmpty {
119-
responseStream = chat.sendMessageStream(text)
125+
responseStream = await chat.sendMessageStream(text)
120126
} else {
121127
for functionResponse in functionResponses {
122128
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
123129
}
124-
responseStream = chat.sendMessageStream(functionResponses.modelContent())
130+
responseStream = await chat.sendMessageStream(functionResponses.modelContent())
125131
}
126132
for try await chunk in responseStream {
127133
processResponseContent(content: chunk)
@@ -130,6 +136,9 @@ class FunctionCallingViewModel: ObservableObject {
130136

131137
private func internalSendMessage(_ text: String) async throws {
132138
let functionResponses = try await processFunctionCalls()
139+
guard let chat else {
140+
throw ChatError.notInitialized
141+
}
133142
let response: GenerateContentResponse
134143
if functionResponses.isEmpty {
135144
response = try await chat.sendMessage(text)
@@ -181,6 +190,10 @@ class FunctionCallingViewModel: ObservableObject {
181190
return functionResponses
182191
}
183192

193+
enum ChatError: Error {
194+
case notInitialized
195+
}
196+
184197
// MARK: - Callable Functions
185198

186199
func getExchangeRate(args: JSONObject) -> JSONObject {

FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
8484
}
8585
}
8686

87-
let outputContentStream = model.generateContentStream(prompt, images)
87+
let outputContentStream = await model.generateContentStream(prompt, images)
8888

8989
// stream response
9090
for try await outputContent in outputContentStream {

FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {
5050

5151
let prompt = "Summarize the following text for me: \(inputText)"
5252

53-
let outputContentStream = model.generateContentStream(prompt)
53+
let outputContentStream = await model.generateContentStream(prompt)
5454

5555
// stream response
5656
for try await outputContent in outputContentStream {

FirebaseVertexAI/Sources/Chat.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Foundation
1717
/// An object that represents a back-and-forth chat with a model, capturing the history and saving
1818
/// the context in memory between each message sent.
1919
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20-
public class Chat {
20+
public actor Chat {
2121
private let model: GenerativeModel
2222

2323
/// Initializes a new chat representing a 1:1 conversation between model and user.
@@ -121,7 +121,7 @@ public class Chat {
121121

122122
// Send the history alongside the new message as context.
123123
let request = history + newContent
124-
let stream = model.generateContentStream(request)
124+
let stream = await model.generateContentStream(request)
125125
do {
126126
for try await chunk in stream {
127127
// Capture any content that's streaming. This should be populated if there's no error.

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Foundation
1919
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
2020
/// content based on various input types.
2121
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
22-
public final class GenerativeModel {
22+
public final actor GenerativeModel {
2323
/// The resource name of the model in the backend; has the format "models/model-name".
2424
let modelResourceName: String
2525

@@ -217,33 +217,31 @@ public final class GenerativeModel {
217217
isStreaming: true,
218218
options: requestOptions)
219219

220-
var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
221-
.makeAsyncIterator()
220+
let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
221+
222222
return AsyncThrowingStream {
223-
let response: GenerateContentResponse?
224223
do {
225-
response = try await responseIterator.next()
226-
} catch {
227-
throw GenerativeModel.generateContentError(from: error)
228-
}
224+
for try await response in responseStream {
225+
// Check the prompt feedback to see if the prompt was blocked.
226+
if response.promptFeedback?.blockReason != nil {
227+
throw GenerateContentError.promptBlocked(response: response)
228+
}
229229

230-
// The responseIterator will return `nil` when it's done.
231-
guard let response = response else {
230+
// If the stream ended early unexpectedly, throw an error.
231+
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
232+
throw GenerateContentError.responseStoppedEarly(
233+
reason: finishReason,
234+
response: response
235+
)
236+
} else {
237+
// Response was valid content, pass it along and continue.
238+
return response
239+
}
240+
}
232241
// This is the end of the stream! Signal it by sending `nil`.
233242
return nil
234-
}
235-
236-
// Check the prompt feedback to see if the prompt was blocked.
237-
if response.promptFeedback?.blockReason != nil {
238-
throw GenerateContentError.promptBlocked(response: response)
239-
}
240-
241-
// If the stream ended early unexpectedly, throw an error.
242-
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
243-
throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
244-
} else {
245-
// Response was valid content, pass it along and continue.
246-
return response
243+
} catch {
244+
throw GenerativeModel.generateContentError(from: error)
247245
}
248246
}
249247
}

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,20 @@ final class ChatTests: XCTestCase {
6464
)
6565
let chat = Chat(model: model, history: [])
6666
let input = "Test input"
67-
let stream = chat.sendMessageStream(input)
67+
let stream = await chat.sendMessageStream(input)
6868

6969
// Ensure the values are parsed correctly
7070
for try await value in stream {
7171
XCTAssertNotNil(value.text)
7272
}
7373

74-
XCTAssertEqual(chat.history.count, 2)
75-
XCTAssertEqual(chat.history[0].parts[0].text, input)
74+
let history = await chat.history
75+
XCTAssertEqual(history.count, 2)
76+
XCTAssertEqual(history[0].parts[0].text, input)
7677

7778
let finalText = "1 2 3 4 5 6 7 8"
7879
let assembledExpectation = ModelContent(role: "model", parts: finalText)
79-
XCTAssertEqual(chat.history[0].parts[0].text, input)
80-
XCTAssertEqual(chat.history[1], assembledExpectation)
80+
XCTAssertEqual(history[0].parts[0].text, input)
81+
XCTAssertEqual(history[1], assembledExpectation)
8182
}
8283
}

0 commit comments

Comments (0)