Skip to content

Commit 2bff4d4

Browse files
authored
Revert "Make GenerativeModel and Chat into Swift actors (#13545)" (#13703)
1 parent e084de1 commit 2bff4d4

File tree

13 files changed

+90
-117
lines changed

13 files changed

+90
-117
lines changed

FirebaseVertexAI/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
# Unreleased
1+
# 11.3.0
22
- [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)
3+
- [changed] **Breaking Change**: Reverted refactor of `GenerativeModel` and
4+
`Chat` as Swift actors (#13545) introduced in 11.2; the methods
5+
`generateContentStream`, `startChat` and `sendMessageStream` no longer need to
6+
be called with `await`. (#13703)
37

48
# 11.2.0
59
- [fixed] Resolved a decoding error for citations without a `uri` and added

FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,7 @@ struct ConversationScreen: View {
104104
}
105105

106106
private func newChat() {
107-
Task {
108-
await viewModel.startNewChat()
109-
}
107+
viewModel.startNewChat()
110108
}
111109
}
112110

FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,44 +21,39 @@ class ConversationViewModel: ObservableObject {
2121
/// This array holds both the user's and the system's chat messages
2222
@Published var messages = [ChatMessage]()
2323

24-
/// Indicates we're waiting for the model to finish or the UI is loading
25-
@Published var busy = true
24+
/// Indicates we're waiting for the model to finish
25+
@Published var busy = false
2626

2727
@Published var error: Error?
2828
var hasError: Bool {
2929
return error != nil
3030
}
3131

3232
private var model: GenerativeModel
33-
private var chat: Chat? = nil
33+
private var chat: Chat
3434
private var stopGenerating = false
3535

3636
private var chatTask: Task<Void, Never>?
3737

3838
init() {
3939
model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
40-
Task {
41-
await startNewChat()
42-
}
40+
chat = model.startChat()
4341
}
4442

4543
func sendMessage(_ text: String, streaming: Bool = true) async {
46-
stop()
44+
error = nil
4745
if streaming {
4846
await internalSendMessageStreaming(text)
4947
} else {
5048
await internalSendMessage(text)
5149
}
5250
}
5351

54-
func startNewChat() async {
55-
busy = true
56-
defer {
57-
busy = false
58-
}
52+
func startNewChat() {
5953
stop()
54+
error = nil
55+
chat = model.startChat()
6056
messages.removeAll()
61-
chat = await model.startChat()
6257
}
6358

6459
func stop() {
@@ -67,6 +62,8 @@ class ConversationViewModel: ObservableObject {
6762
}
6863

6964
private func internalSendMessageStreaming(_ text: String) async {
65+
chatTask?.cancel()
66+
7067
chatTask = Task {
7168
busy = true
7269
defer {
@@ -82,10 +79,7 @@ class ConversationViewModel: ObservableObject {
8279
messages.append(systemMessage)
8380

8481
do {
85-
guard let chat else {
86-
throw ChatError.notInitialized
87-
}
88-
let responseStream = try await chat.sendMessageStream(text)
82+
let responseStream = try chat.sendMessageStream(text)
8983
for try await chunk in responseStream {
9084
messages[messages.count - 1].pending = false
9185
if let text = chunk.text {
@@ -101,6 +95,8 @@ class ConversationViewModel: ObservableObject {
10195
}
10296

10397
private func internalSendMessage(_ text: String) async {
98+
chatTask?.cancel()
99+
104100
chatTask = Task {
105101
busy = true
106102
defer {
@@ -116,12 +112,10 @@ class ConversationViewModel: ObservableObject {
116112
messages.append(systemMessage)
117113

118114
do {
119-
guard let chat = chat else {
120-
throw ChatError.notInitialized
121-
}
122-
let response = try await chat.sendMessage(text)
115+
var response: GenerateContentResponse?
116+
response = try await chat.sendMessage(text)
123117

124-
if let responseText = response.text {
118+
if let responseText = response?.text {
125119
// replace pending message with backend response
126120
messages[messages.count - 1].message = responseText
127121
messages[messages.count - 1].pending = false
@@ -133,8 +127,4 @@ class ConversationViewModel: ObservableObject {
133127
}
134128
}
135129
}
136-
137-
enum ChatError: Error {
138-
case notInitialized
139-
}
140130
}

FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,7 @@ struct FunctionCallingScreen: View {
106106
}
107107

108108
private func newChat() {
109-
Task {
110-
await viewModel.startNewChat()
111-
}
109+
viewModel.startNewChat()
112110
}
113111
}
114112

FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class FunctionCallingViewModel: ObservableObject {
3333
private var functionCalls = [FunctionCall]()
3434

3535
private var model: GenerativeModel
36-
private var chat: Chat? = nil
36+
private var chat: Chat
3737

3838
private var chatTask: Task<Void, Never>?
3939

@@ -62,13 +62,13 @@ class FunctionCallingViewModel: ObservableObject {
6262
),
6363
])]
6464
)
65-
Task {
66-
await startNewChat()
67-
}
65+
chat = model.startChat()
6866
}
6967

7068
func sendMessage(_ text: String, streaming: Bool = true) async {
71-
stop()
69+
error = nil
70+
chatTask?.cancel()
71+
7272
chatTask = Task {
7373
busy = true
7474
defer {
@@ -100,14 +100,11 @@ class FunctionCallingViewModel: ObservableObject {
100100
}
101101
}
102102

103-
func startNewChat() async {
104-
busy = true
105-
defer {
106-
busy = false
107-
}
103+
func startNewChat() {
108104
stop()
105+
error = nil
106+
chat = model.startChat()
109107
messages.removeAll()
110-
chat = await model.startChat()
111108
}
112109

113110
func stop() {
@@ -117,17 +114,14 @@ class FunctionCallingViewModel: ObservableObject {
117114

118115
private func internalSendMessageStreaming(_ text: String) async throws {
119116
let functionResponses = try await processFunctionCalls()
120-
guard let chat else {
121-
throw ChatError.notInitialized
122-
}
123117
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
124118
if functionResponses.isEmpty {
125-
responseStream = try await chat.sendMessageStream(text)
119+
responseStream = try chat.sendMessageStream(text)
126120
} else {
127121
for functionResponse in functionResponses {
128122
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
129123
}
130-
responseStream = try await chat.sendMessageStream(functionResponses.modelContent())
124+
responseStream = try chat.sendMessageStream(functionResponses.modelContent())
131125
}
132126
for try await chunk in responseStream {
133127
processResponseContent(content: chunk)
@@ -136,9 +130,6 @@ class FunctionCallingViewModel: ObservableObject {
136130

137131
private func internalSendMessage(_ text: String) async throws {
138132
let functionResponses = try await processFunctionCalls()
139-
guard let chat else {
140-
throw ChatError.notInitialized
141-
}
142133
let response: GenerateContentResponse
143134
if functionResponses.isEmpty {
144135
response = try await chat.sendMessage(text)
@@ -190,10 +181,6 @@ class FunctionCallingViewModel: ObservableObject {
190181
return functionResponses
191182
}
192183

193-
enum ChatError: Error {
194-
case notInitialized
195-
}
196-
197184
// MARK: - Callable Functions
198185

199186
func getExchangeRate(args: JSONObject) -> JSONObject {

FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
8484
}
8585
}
8686

87-
let outputContentStream = try await model.generateContentStream(prompt, images)
87+
let outputContentStream = try model.generateContentStream(prompt, images)
8888

8989
// stream response
9090
for try await outputContent in outputContentStream {

FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {
5050

5151
let prompt = "Summarize the following text for me: \(inputText)"
5252

53-
let outputContentStream = try await model.generateContentStream(prompt)
53+
let outputContentStream = try model.generateContentStream(prompt)
5454

5555
// stream response
5656
for try await outputContent in outputContentStream {

FirebaseVertexAI/Sources/Chat.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Foundation
1717
/// An object that represents a back-and-forth chat with a model, capturing the history and saving
1818
/// the context in memory between each message sent.
1919
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20-
public actor Chat {
20+
public class Chat {
2121
private let model: GenerativeModel
2222

2323
/// Initializes a new chat representing a 1:1 conversation between model and user.
@@ -116,7 +116,7 @@ public actor Chat {
116116

117117
// Send the history alongside the new message as context.
118118
let request = history + newContent
119-
let stream = try await model.generateContentStream(request)
119+
let stream = try model.generateContentStream(request)
120120
do {
121121
for try await chunk in stream {
122122
// Capture any content that's streaming. This should be populated if there's no error.

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Foundation
1919
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
2020
/// content based on various input types.
2121
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
22-
public final actor GenerativeModel {
22+
public final class GenerativeModel {
2323
/// The resource name of the model in the backend; has the format "models/model-name".
2424
let modelResourceName: String
2525

@@ -212,31 +212,33 @@ public final actor GenerativeModel {
212212
isStreaming: true,
213213
options: requestOptions)
214214

215-
let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
216-
215+
var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
216+
.makeAsyncIterator()
217217
return AsyncThrowingStream {
218+
let response: GenerateContentResponse?
218219
do {
219-
for try await response in responseStream {
220-
// Check the prompt feedback to see if the prompt was blocked.
221-
if response.promptFeedback?.blockReason != nil {
222-
throw GenerateContentError.promptBlocked(response: response)
223-
}
220+
response = try await responseIterator.next()
221+
} catch {
222+
throw GenerativeModel.generateContentError(from: error)
223+
}
224224

225-
// If the stream ended early unexpectedly, throw an error.
226-
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
227-
throw GenerateContentError.responseStoppedEarly(
228-
reason: finishReason,
229-
response: response
230-
)
231-
} else {
232-
// Response was valid content, pass it along and continue.
233-
return response
234-
}
235-
}
225+
// The responseIterator will return `nil` when it's done.
226+
guard let response = response else {
236227
// This is the end of the stream! Signal it by sending `nil`.
237228
return nil
238-
} catch {
239-
throw GenerativeModel.generateContentError(from: error)
229+
}
230+
231+
// Check the prompt feedback to see if the prompt was blocked.
232+
if response.promptFeedback?.blockReason != nil {
233+
throw GenerateContentError.promptBlocked(response: response)
234+
}
235+
236+
// If the stream ended early unexpectedly, throw an error.
237+
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
238+
throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
239+
} else {
240+
// Response was valid content, pass it along and continue.
241+
return response
240242
}
241243
}
242244
}

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,19 @@ final class ChatTests: XCTestCase {
6969
)
7070
let chat = Chat(model: model, history: [])
7171
let input = "Test input"
72-
let stream = try await chat.sendMessageStream(input)
72+
let stream = try chat.sendMessageStream(input)
7373

7474
// Ensure the values are parsed correctly
7575
for try await value in stream {
7676
XCTAssertNotNil(value.text)
7777
}
7878

79-
let history = await chat.history
80-
XCTAssertEqual(history.count, 2)
81-
XCTAssertEqual(history[0].parts[0].text, input)
79+
XCTAssertEqual(chat.history.count, 2)
80+
XCTAssertEqual(chat.history[0].parts[0].text, input)
8281

8382
let finalText = "1 2 3 4 5 6 7 8"
8483
let assembledExpectation = ModelContent(role: "model", parts: finalText)
85-
XCTAssertEqual(history[0].parts[0].text, input)
86-
XCTAssertEqual(history[1], assembledExpectation)
84+
XCTAssertEqual(chat.history[0].parts[0].text, input)
85+
XCTAssertEqual(chat.history[1], assembledExpectation)
8786
}
8887
}

0 commit comments

Comments (0)