diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md
index c048c49ebc4..fc54c727370 100644
--- a/FirebaseVertexAI/CHANGELOG.md
+++ b/FirebaseVertexAI/CHANGELOG.md
@@ -1,5 +1,9 @@
-# Unreleased
+# 11.3.0
 - [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)
+- [changed] **Breaking Change**: Reverted the refactor of `GenerativeModel`
+  and `Chat` as Swift actors (#13545) introduced in 11.2; the methods
+  `generateContentStream`, `startChat` and `sendMessageStream` no longer need
+  to be called with `await`. (#13703)
 
 # 11.2.0
 - [fixed] Resolved a decoding error for citations without a `uri` and added
diff --git a/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift b/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift
index 43da223aa78..78c903e3412 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift
@@ -104,9 +104,7 @@ struct ConversationScreen: View {
   }
 
   private func newChat() {
-    Task {
-      await viewModel.startNewChat()
-    }
+    viewModel.startNewChat()
   }
 }
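For SDK users, the practical effect of the revert is that the three listed methods drop their `await`. A minimal before/after sketch (the model name and prompt are illustrative):

```swift
import FirebaseVertexAI

func migrationSketch() async throws {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")

  // 11.2 (actor): let chat = await model.startChat()
  let chat = model.startChat() // 11.3: no `await`

  // 11.2 (actor): let stream = try await chat.sendMessageStream("Hello")
  let stream = try chat.sendMessageStream("Hello") // 11.3: only `try`

  // Consuming the stream is still asynchronous.
  for try await chunk in stream {
    if let text = chunk.text { print(text) }
  }
}
```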
diff --git a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift
index 3db45eb8d19..65e1c940de9 100644
--- a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift
@@ -21,8 +21,8 @@ class ConversationViewModel: ObservableObject {
   /// This array holds both the user's and the system's chat messages
   @Published var messages = [ChatMessage]()
 
-  /// Indicates we're waiting for the model to finish or the UI is loading
-  @Published var busy = true
+  /// Indicates we're waiting for the model to finish
+  @Published var busy = false
   @Published var error: Error?
   var hasError: Bool {
@@ -30,20 +30,18 @@ class ConversationViewModel: ObservableObject {
   }
 
   private var model: GenerativeModel
-  private var chat: Chat? = nil
+  private var chat: Chat
   private var stopGenerating = false
 
   private var chatTask: Task<Void, Never>?
 
   init() {
     model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
-    Task {
-      await startNewChat()
-    }
+    chat = model.startChat()
   }
 
   func sendMessage(_ text: String, streaming: Bool = true) async {
-    stop()
+    error = nil
     if streaming {
       await internalSendMessageStreaming(text)
     } else {
@@ -51,14 +49,11 @@ class ConversationViewModel: ObservableObject {
     }
   }
 
-  func startNewChat() async {
-    busy = true
-    defer {
-      busy = false
-    }
+  func startNewChat() {
     stop()
+    error = nil
+    chat = model.startChat()
     messages.removeAll()
-    chat = await model.startChat()
   }
 
   func stop() {
@@ -67,6 +62,8 @@ class ConversationViewModel: ObservableObject {
   }
 
   private func internalSendMessageStreaming(_ text: String) async {
+    chatTask?.cancel()
+
     chatTask = Task {
       busy = true
       defer {
@@ -82,10 +79,7 @@ class ConversationViewModel: ObservableObject {
       messages.append(systemMessage)
 
       do {
-        guard let chat else {
-          throw ChatError.notInitialized
-        }
-        let responseStream = try await chat.sendMessageStream(text)
+        let responseStream = try chat.sendMessageStream(text)
         for try await chunk in responseStream {
           messages[messages.count - 1].pending = false
           if let text = chunk.text {
@@ -101,6 +95,8 @@ class ConversationViewModel: ObservableObject {
   }
 
   private func internalSendMessage(_ text: String) async {
+    chatTask?.cancel()
+
     chatTask = Task {
       busy = true
       defer {
@@ -116,12 +112,10 @@ class ConversationViewModel: ObservableObject {
       messages.append(systemMessage)
 
       do {
-        guard let chat = chat else {
-          throw ChatError.notInitialized
-        }
-        let response = try await chat.sendMessage(text)
+        var response: GenerateContentResponse?
+        response = try await chat.sendMessage(text)
 
-        if let responseText = response.text {
+        if let responseText = response?.text {
           // replace pending message with backend response
           messages[messages.count - 1].message = responseText
           messages[messages.count - 1].pending = false
@@ -133,8 +127,4 @@ class ConversationViewModel: ObservableObject {
       }
     }
   }
-
-  enum ChatError: Error {
-    case notInitialized
-  }
 }
diff --git a/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift b/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
index dbfd04eb52c..f16da39e22f 100644
--- a/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
+++ b/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
@@ -106,9 +106,7 @@ struct FunctionCallingScreen: View {
   }
 
   private func newChat() {
-    Task {
-      await viewModel.startNewChat()
-    }
+    viewModel.startNewChat()
   }
 }
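The view model now guards against overlapping requests by cancelling the previous `chatTask` before starting a new one, rather than calling `stop()` inside `sendMessage`. A minimal sketch of that cancel-before-restart pattern, with `performRequest` as a hypothetical stand-in for the streaming call:

```swift
import Foundation

final class RestartableWorker {
  private var chatTask: Task<Void, Never>?

  func send(_ text: String) {
    chatTask?.cancel() // drop any in-flight request first
    chatTask = Task {
      await performRequest(text)
    }
  }

  func stop() {
    chatTask?.cancel()
  }

  private func performRequest(_ text: String) async {
    // A real implementation would stream the model response here.
  }
}
```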
diff --git a/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
index 36b53f2e2da..ac2ea5a1fcc 100644
--- a/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
+++ b/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
@@ -33,7 +33,7 @@ class FunctionCallingViewModel: ObservableObject {
   private var functionCalls = [FunctionCall]()
 
   private var model: GenerativeModel
-  private var chat: Chat? = nil
+  private var chat: Chat
 
   private var chatTask: Task<Void, Never>?
 
@@ -62,13 +62,13 @@ class FunctionCallingViewModel: ObservableObject {
         ),
       ])]
     )
-    Task {
-      await startNewChat()
-    }
+    chat = model.startChat()
  }
 
   func sendMessage(_ text: String, streaming: Bool = true) async {
-    stop()
+    error = nil
+    chatTask?.cancel()
+
     chatTask = Task {
       busy = true
       defer {
@@ -100,14 +100,11 @@ class FunctionCallingViewModel: ObservableObject {
     }
   }
 
-  func startNewChat() async {
-    busy = true
-    defer {
-      busy = false
-    }
+  func startNewChat() {
     stop()
+    error = nil
+    chat = model.startChat()
     messages.removeAll()
-    chat = await model.startChat()
   }
 
   func stop() {
@@ -117,17 +114,14 @@
 
   private func internalSendMessageStreaming(_ text: String) async throws {
     let functionResponses = try await processFunctionCalls()
-    guard let chat else {
-      throw ChatError.notInitialized
-    }
     let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
     if functionResponses.isEmpty {
-      responseStream = try await chat.sendMessageStream(text)
+      responseStream = try chat.sendMessageStream(text)
     } else {
       for functionResponse in functionResponses {
         messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
       }
-      responseStream = try await chat.sendMessageStream(functionResponses.modelContent())
+      responseStream = try chat.sendMessageStream(functionResponses.modelContent())
     }
     for try await chunk in responseStream {
       processResponseContent(content: chunk)
@@ -136,9 +130,6 @@
 
   private func internalSendMessage(_ text: String) async throws {
     let functionResponses = try await processFunctionCalls()
-    guard let chat else {
-      throw ChatError.notInitialized
-    }
     let response: GenerateContentResponse
     if functionResponses.isEmpty {
       response = try await chat.sendMessage(text)
@@ -190,10 +181,6 @@ class FunctionCallingViewModel: ObservableObject {
     return functionResponses
   }
 
-  enum ChatError: Error {
-    case notInitialized
-  }
-
   // MARK: - Callable Functions
 
   func getExchangeRate(args: JSONObject) -> JSONObject {
diff --git a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift
index 4eee954a68d..3a85da05102 100644
--- a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift
+++ b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift
@@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
         }
       }
 
-      let outputContentStream = try await model.generateContentStream(prompt, images)
+      let outputContentStream = try model.generateContentStream(prompt, images)
 
       // stream response
       for try await outputContent in outputContentStream {
diff --git a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift
index 4467b85fe3d..540ff893ba8 100644
--- a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift
+++ b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift
@@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {
 
       let prompt = "Summarize the following text for me: \(inputText)"
 
-      let outputContentStream = try await model.generateContentStream(prompt)
+      let outputContentStream = try model.generateContentStream(prompt)
 
       // stream response
       for try await outputContent in outputContentStream {
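The sample updates above all share one shape: creating the stream is synchronous (`try`, no `await`), while consuming it stays asynchronous. A sketch of that pattern, assuming a configured Firebase app; the model name and prompt mirror the samples:

```swift
import FirebaseVertexAI

func summarize(_ inputText: String) async throws -> String {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
  let prompt = "Summarize the following text for me: \(inputText)"

  // Stream creation no longer hops onto an actor, so no `await` here...
  let outputContentStream = try model.generateContentStream(prompt)

  // ...but each chunk still arrives asynchronously.
  var output = ""
  for try await outputContent in outputContentStream {
    if let text = outputContent.text {
      output += text
    }
  }
  return output
}
```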
diff --git a/FirebaseVertexAI/Sources/Chat.swift b/FirebaseVertexAI/Sources/Chat.swift
index 0c0239ed54e..ff7a8aa9a09 100644
--- a/FirebaseVertexAI/Sources/Chat.swift
+++ b/FirebaseVertexAI/Sources/Chat.swift
@@ -17,7 +17,7 @@ import Foundation
 /// An object that represents a back-and-forth chat with a model, capturing the history and saving
 /// the context in memory between each message sent.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public actor Chat {
+public class Chat {
   private let model: GenerativeModel
 
   /// Initializes a new chat representing a 1:1 conversation between model and user.
@@ -116,7 +116,7 @@ public actor Chat {
 
     // Send the history alongside the new message as context.
     let request = history + newContent
-    let stream = try await model.generateContentStream(request)
+    let stream = try model.generateContentStream(request)
     do {
       for try await chunk in stream {
         // Capture any content that's streaming. This should be populated if there's no error.
diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift
index 64c97729177..28d3ca4ba88 100644
--- a/FirebaseVertexAI/Sources/GenerativeModel.swift
+++ b/FirebaseVertexAI/Sources/GenerativeModel.swift
@@ -19,7 +19,7 @@ import Foundation
 /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
 /// content based on various input types.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public final actor GenerativeModel {
+public final class GenerativeModel {
   /// The resource name of the model in the backend; has the format "models/model-name".
   let modelResourceName: String
 
@@ -212,31 +212,33 @@ public final actor GenerativeModel {
       isStreaming: true,
       options: requestOptions)
 
-    let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
-
+    var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
+      .makeAsyncIterator()
     return AsyncThrowingStream {
+      let response: GenerateContentResponse?
       do {
-        for try await response in responseStream {
-          // Check the prompt feedback to see if the prompt was blocked.
-          if response.promptFeedback?.blockReason != nil {
-            throw GenerateContentError.promptBlocked(response: response)
-          }
+        response = try await responseIterator.next()
+      } catch {
+        throw GenerativeModel.generateContentError(from: error)
+      }
 
-          // If the stream ended early unexpectedly, throw an error.
-          if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
-            throw GenerateContentError.responseStoppedEarly(
-              reason: finishReason,
-              response: response
-            )
-          } else {
-            // Response was valid content, pass it along and continue.
-            return response
-          }
-        }
+      // The responseIterator will return `nil` when it's done.
+      guard let response = response else {
         // This is the end of the stream! Signal it by sending `nil`.
         return nil
-      } catch {
-        throw GenerativeModel.generateContentError(from: error)
+      }
+
+      // Check the prompt feedback to see if the prompt was blocked.
+      if response.promptFeedback?.blockReason != nil {
+        throw GenerateContentError.promptBlocked(response: response)
+      }
+
+      // If the stream ended early unexpectedly, throw an error.
+      if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
+        throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
+      } else {
+        // Response was valid content, pass it along and continue.
+        return response
       }
     }
   }
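With `GenerativeModel` no longer an actor, the stream can't be built by looping inside an actor-isolated closure; instead the service's stream is unwrapped into an iterator that is polled once per element. A generic sketch of that rewrapping technique (names are illustrative, not SDK internals):

```swift
// Wrap any AsyncSequence in an AsyncThrowingStream by polling its
// iterator in the `unfolding` closure: each call produces one element,
// and returning nil terminates the stream.
func rewrap<S: AsyncSequence>(_ upstream: S) -> AsyncThrowingStream<S.Element, Error> {
  var iterator = upstream.makeAsyncIterator()
  return AsyncThrowingStream {
    guard let element = try await iterator.next() else {
      return nil // upstream finished; end the downstream stream too
    }
    // Per-element checks (like the blocked-prompt check above) go here.
    return element
  }
}
```

This is the same shape as the diff: validation and error mapping move from the body of a `for try await` loop into the unfolding closure, which runs once per pulled element.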
diff --git a/FirebaseVertexAI/Tests/Unit/ChatTests.swift b/FirebaseVertexAI/Tests/Unit/ChatTests.swift
index 2e55b73020f..614559fe011 100644
--- a/FirebaseVertexAI/Tests/Unit/ChatTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/ChatTests.swift
@@ -69,20 +69,19 @@ final class ChatTests: XCTestCase {
     )
     let chat = Chat(model: model, history: [])
     let input = "Test input"
-    let stream = try await chat.sendMessageStream(input)
+    let stream = try chat.sendMessageStream(input)
 
     // Ensure the values are parsed correctly
     for try await value in stream {
       XCTAssertNotNil(value.text)
     }
 
-    let history = await chat.history
-    XCTAssertEqual(history.count, 2)
-    XCTAssertEqual(history[0].parts[0].text, input)
+    XCTAssertEqual(chat.history.count, 2)
+    XCTAssertEqual(chat.history[0].parts[0].text, input)
 
     let finalText = "1 2 3 4 5 6 7 8"
     let assembledExpectation = ModelContent(role: "model", parts: finalText)
-    XCTAssertEqual(history[0].parts[0].text, input)
-    XCTAssertEqual(history[1], assembledExpectation)
+    XCTAssertEqual(chat.history[0].parts[0].text, input)
+    XCTAssertEqual(chat.history[1], assembledExpectation)
   }
 }
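Since `Chat` is a class again, `history` is read directly rather than through `await`, which is why the test no longer snapshots it into a local constant. A small sketch, not taken from the diff, assuming each element of `parts` exposes an optional `text` and `ModelContent` exposes an optional `role`:

```swift
import FirebaseVertexAI

func logHistory(of chat: Chat) {
  // No `await`: `history` is a plain property on the Chat class.
  print("Messages in history: \(chat.history.count)")
  for content in chat.history {
    if let text = content.parts.first?.text {
      print("\(content.role ?? "unknown"): \(text)")
    }
  }
}
```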
diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
index b01d62b90f0..dc76123d028 100644
--- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -760,7 +760,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     do {
-      let stream = try await model.generateContentStream("Hi")
+      let stream = try model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
@@ -784,7 +784,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     do {
-      let stream = try await model.generateContentStream(testPrompt)
+      let stream = try model.generateContentStream(testPrompt)
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
@@ -807,7 +807,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     do {
-      let stream = try await model.generateContentStream("Hi")
+      let stream = try model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
@@ -827,7 +827,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     do {
-      let stream = try await model.generateContentStream("Hi")
+      let stream = try model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
@@ -847,7 +847,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     do {
-      let stream = try await model.generateContentStream("Hi")
+      let stream = try model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
@@ -866,7 +866,7 @@ final class GenerativeModelTests: XCTestCase {
       withExtension: "txt"
     )
 
-    let stream = try await model.generateContentStream("Hi")
+    let stream = try model.generateContentStream("Hi")
     do {
       for try await content in stream {
         XCTAssertNotNil(content.text)
@@ -887,7 +887,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     var responses = 0
-    let stream = try await model.generateContentStream("Hi")
+    let stream = try model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       responses += 1
@@ -904,7 +904,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     var responses = 0
-    let stream = try await model.generateContentStream("Hi")
+    let stream = try model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       responses += 1
@@ -921,7 +921,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     var hadUnknown = false
-    let stream = try await model.generateContentStream("Hi")
+    let stream = try model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       if let ratings = content.candidates.first?.safetyRatings,
@@ -940,7 +940,7 @@ final class GenerativeModelTests: XCTestCase {
       withExtension: "txt"
     )
 
-    let stream = try await model.generateContentStream("Hi")
+    let stream = try model.generateContentStream("Hi")
     var citations = [Citation]()
     var responses = [GenerateContentResponse]()
     for try await content in stream {
@@ -996,7 +996,7 @@ final class GenerativeModelTests: XCTestCase {
       appCheckToken: appCheckToken
     )
 
-    let stream = try await model.generateContentStream(testPrompt)
+    let stream = try model.generateContentStream(testPrompt)
     for try await _ in stream {}
   }
 
@@ -1018,7 +1018,7 @@ final class GenerativeModelTests: XCTestCase {
       appCheckToken: AppCheckInteropFake.placeholderTokenValue
     )
 
-    let stream = try await model.generateContentStream(testPrompt)
+    let stream = try model.generateContentStream(testPrompt)
     for try await _ in stream {}
   }
 
@@ -1030,7 +1030,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     var responses = [GenerateContentResponse]()
-    let stream = try await model.generateContentStream(testPrompt)
+    let stream = try model.generateContentStream(testPrompt)
     for try await response in stream {
       responses.append(response)
     }
@@ -1056,7 +1056,7 @@ final class GenerativeModelTests: XCTestCase {
 
     var responseCount = 0
     do {
-      let stream = try await model.generateContentStream("Hi")
+      let stream = try model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
@@ -1076,7 +1076,7 @@ final class GenerativeModelTests: XCTestCase {
   func testGenerateContentStream_nonHTTPResponse() async throws {
     MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
 
-    let stream = try await model.generateContentStream("Hi")
+    let stream = try model.generateContentStream("Hi")
     do {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
@@ -1096,7 +1096,7 @@ final class GenerativeModelTests: XCTestCase {
       withExtension: "txt"
     )
 
-    let stream = try await model.generateContentStream(testPrompt)
+    let stream = try model.generateContentStream(testPrompt)
     do {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
@@ -1120,7 +1120,7 @@ final class GenerativeModelTests: XCTestCase {
       withExtension: "txt"
     )
 
-    let stream = try await model.generateContentStream(testPrompt)
+    let stream = try model.generateContentStream(testPrompt)
     do {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
@@ -1159,7 +1159,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     var responses = 0
-    let stream = try await model.generateContentStream(testPrompt)
+    let stream = try model.generateContentStream(testPrompt)
     for try await content in stream {
       XCTAssertNotNil(content.text)
       responses += 1
diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
index f2c38a03e61..c68b69b03ec 100644
--- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
+++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
@@ -170,8 +170,8 @@ final class VertexAIAPITests: XCTestCase {
     #endif
 
     // Chat
-    _ = await genAI.startChat()
-    _ = await genAI.startChat(history: [ModelContent(parts: "abc")])
+    _ = genAI.startChat()
+    _ = genAI.startChat(history: [ModelContent(parts: "abc")])
   }
 
   // Public API tests for GenerateContentResponse.
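The test churn above is mechanical: `try await` becomes `try` at stream creation, and failures now surface only inside the `for try await` loop. A hedged consumption sketch in the same spirit:

```swift
import FirebaseVertexAI

func consume(model: GenerativeModel) async {
  do {
    // Creating the stream can throw (e.g. invalid input) but no longer suspends.
    let stream = try model.generateContentStream("Hi")
    for try await content in stream {
      print(content.text ?? "")
    }
  } catch let error as GenerateContentError {
    // The cases exercised by these tests, e.g. promptBlocked or
    // responseStoppedEarly.
    print("Generation failed: \(error)")
  } catch {
    print("Unexpected error: \(error)")
  }
}
```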
diff --git a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
index a6f77467c24..64a3edcc6b8 100644
--- a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
@@ -107,20 +107,15 @@ class VertexComponentTests: XCTestCase {
     let app = try XCTUnwrap(VertexComponentTests.app)
     let vertex = VertexAI.vertexAI(app: app, location: location)
     let modelName = "test-model-name"
-    let expectedModelResourceName = vertex.modelResourceName(modelName: modelName)
-    let expectedSystemInstruction = ModelContent(
-      role: "system",
-      parts: "test-system-instruction-prompt"
-    )
+    let modelResourceName = vertex.modelResourceName(modelName: modelName)
+    let systemInstruction = ModelContent(role: "system", parts: "test-system-instruction-prompt")
 
     let generativeModel = vertex.generativeModel(
       modelName: modelName,
-      systemInstruction: expectedSystemInstruction
+      systemInstruction: systemInstruction
     )
 
-    let modelResourceName = await generativeModel.modelResourceName
-    let systemInstruction = await generativeModel.systemInstruction
-    XCTAssertEqual(modelResourceName, expectedModelResourceName)
-    XCTAssertEqual(systemInstruction, expectedSystemInstruction)
+    XCTAssertEqual(generativeModel.modelResourceName, modelResourceName)
+    XCTAssertEqual(generativeModel.systemInstruction, systemInstruction)
   }
 }