diff --git a/FirebaseVertexAI/Sources/GenerationConfig.swift b/FirebaseVertexAI/Sources/GenerationConfig.swift
index 5c49e60f274..125dba31fd2 100644
--- a/FirebaseVertexAI/Sources/GenerationConfig.swift
+++ b/FirebaseVertexAI/Sources/GenerationConfig.swift
@@ -162,4 +162,17 @@ public struct GenerationConfig {
 // MARK: - Codable Conformances
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-extension GenerationConfig: Encodable {}
+extension GenerationConfig: Encodable {
+  enum CodingKeys: String, CodingKey {
+    case temperature
+    case topP
+    case topK
+    case candidateCount
+    case maxOutputTokens
+    case presencePenalty
+    case frequencyPenalty
+    case stopSequences
+    case responseMIMEType = "responseMimeType"
+    case responseSchema
+  }
+}
diff --git a/FirebaseVertexAI/Sources/GenerativeAIService.swift b/FirebaseVertexAI/Sources/GenerativeAIService.swift
index 667819c5c76..fc35c2b258a 100644
--- a/FirebaseVertexAI/Sources/GenerativeAIService.swift
+++ b/FirebaseVertexAI/Sources/GenerativeAIService.swift
@@ -203,7 +203,6 @@ struct GenerativeAIService {
     }
 
     let encoder = JSONEncoder()
-    encoder.keyEncodingStrategy = .convertToSnakeCase
     urlRequest.httpBody = try encoder.encode(request)
     urlRequest.timeoutInterval = request.options.timeout
 
diff --git a/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift b/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift
index 872f394abd1..d543fb80f38 100644
--- a/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift
+++ b/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift
@@ -34,6 +34,11 @@ struct FileData: Codable, Equatable, Sendable {
     self.fileURI = fileURI
     self.mimeType = mimeType
   }
+
+  enum CodingKeys: String, CodingKey {
+    case fileURI = "fileUri"
+    case mimeType
+  }
 }
 
 @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift
similarity index 97%
rename from FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift
rename to FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift
index 6cf0cce9111..92bfbb2f551 100644
--- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift
+++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift
@@ -33,7 +33,7 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
     guard container.contains(.predictions) else {
       images = []
       raiFilteredReason = nil
-      // TODO: Log warning if no predictions.
+      // TODO(#14221): Log warning if no predictions.
       return
     }
     var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)
diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
new file mode 100644
index 00000000000..15b58466386
--- /dev/null
+++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
@@ -0,0 +1,109 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import FirebaseAppCheckInterop
+import FirebaseAuthInterop
+import Foundation
+
+/// A model client that sends image generation requests for a single backend model through
+/// `GenerativeAIService`.
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+public final class ImagenModel {
+  /// The resource name of the model in the backend; has the format "models/model-name".
+  let modelResourceName: String
+
+  /// The backing service responsible for sending and receiving model requests to the backend.
+  let generativeAIService: GenerativeAIService
+
+  /// Configuration parameters for sending requests to the backend.
+  let requestOptions: RequestOptions
+
+  init(name: String,
+       projectID: String,
+       apiKey: String,
+       requestOptions: RequestOptions,
+       appCheck: AppCheckInterop?,
+       auth: AuthInterop?,
+       urlSession: URLSession = .shared) {
+    modelResourceName = name
+    generativeAIService = GenerativeAIService(
+      projectID: projectID,
+      apiKey: apiKey,
+      appCheck: appCheck,
+      auth: auth,
+      urlSession: urlSession
+    )
+    self.requestOptions = requestOptions
+  }
+
+  /// Generates images for the given text prompt using the default generation parameters.
+  ///
+  /// - Parameter prompt: The text prompt sent to the model.
+  /// - Returns: The decoded response to the image generation request.
+  /// - Throws: Any error thrown while sending the request.
+  public func generateImages(prompt: String) async throws
+    -> ImageGenerationResponse {
+    return try await generateImages(
+      prompt: prompt,
+      parameters: imageGenerationParameters(storageURI: nil)
+    )
+  }
+
+  /// Generates images for the given text prompt, forwarding `storageURI` to the backend.
+  ///
+  /// - Parameters:
+  ///   - prompt: The text prompt sent to the model.
+  ///   - storageURI: Passed through as the request's `storageURI` parameter; presumably the
+  ///     storage location for the generated images (TODO: confirm against the backend API).
+  /// - Returns: The decoded response to the image generation request.
+  /// - Throws: Any error thrown while sending the request.
+  public func generateImages(prompt: String, storageURI: String) async throws
+    -> ImageGenerationResponse {
+    return try await generateImages(
+      prompt: prompt,
+      parameters: imageGenerationParameters(storageURI: storageURI)
+    )
+  }
+
+  /// Builds an `ImageGenerationRequest` with a single instance for `prompt` and sends it.
+  func generateImages(prompt: String,
+                      parameters: ImageGenerationParameters) async throws
+    -> ImageGenerationResponse {
+    let request = ImageGenerationRequest(
+      model: modelResourceName,
+      options: requestOptions,
+      instances: [ImageGenerationInstance(prompt: prompt)],
+      parameters: parameters
+    )
+
+    return try await generativeAIService.loadRequest(request: request)
+  }
+
+  /// Returns the fixed parameters used for every request; only `storageURI` varies.
+  func imageGenerationParameters(storageURI: String?) -> ImageGenerationParameters {
+    // TODO(#14221): Add support for configuring these parameters.
+    return ImageGenerationParameters(
+      sampleCount: 1,
+      storageURI: storageURI,
+      seed: nil,
+      negativePrompt: nil,
+      aspectRatio: nil,
+      safetyFilterLevel: nil,
+      personGeneration: nil,
+      outputOptions: nil,
+      addWatermark: nil,
+      includeResponsibleAIFilterReason: true
+    )
+  }
+}
diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift
index c0cd2cb66a3..96df5b4abf5 100644
--- a/FirebaseVertexAI/Sources/VertexAI.swift
+++ b/FirebaseVertexAI/Sources/VertexAI.swift
@@ -104,6 +104,24 @@ public class VertexAI {
     )
   }
 
+  /// Initializes an `ImagenModel` with the given parameters.
+  ///
+  /// - Parameters:
+  ///   - modelName: The name of the model to use, for example `"imagen-3.0-fast-generate-001"`.
+  ///   - requestOptions: Configuration parameters for sending requests to the backend.
+  /// - Returns: A model that shares this instance's project ID, API key, App Check, and Auth.
+  public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
+    -> ImagenModel {
+    return ImagenModel(
+      name: modelResourceName(modelName: modelName),
+      projectID: projectID,
+      apiKey: apiKey,
+      requestOptions: requestOptions,
+      appCheck: appCheck,
+      auth: auth
+    )
+  }
+
   /// Class to enable VertexAI to register via the Objective-C based Firebase component system
   /// to include VertexAI in the userAgent.
   @objc(FIRVertexAIComponent) class FirebaseVertexAIComponent: NSObject {}
diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift
index 8e6e6c8d601..fcb7274670f 100644
--- a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift
+++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift
@@ -41,6 +41,7 @@ final class IntegrationTests: XCTestCase {
 
   var vertex: VertexAI!
   var model: GenerativeModel!
+  var imagenModel: ImagenModel!
   var storage: Storage!
   var userID1 = ""
 
@@ -60,6 +61,9 @@ final class IntegrationTests: XCTestCase {
       toolConfig: .init(functionCallingConfig: .none()),
       systemInstruction: systemInstruction
     )
+    imagenModel = vertex.imagenModel(
+      modelName: "imagen-3.0-fast-generate-001"
+    )
 
     storage = Storage.storage()
   }
@@ -235,6 +239,30 @@ final class IntegrationTests: XCTestCase {
       XCTAssertTrue(String(describing: error).contains("Firebase App Check token is invalid"))
     }
   }
+
+  // MARK: - Imagen
+
+  func testGenerateImage_inlineData() async throws {
+    let imagePrompt = """
+    A realistic photo of a male lion, mane thick and dark, standing proudly on a rocky outcrop
+    overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
+    the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
+    """
+
+    let imageResponse = try await imagenModel.generateImages(prompt: imagePrompt)
+
+    XCTAssertNil(imageResponse.raiFilteredReason)
+    XCTAssertEqual(imageResponse.images.count, 1)
+    let image = try XCTUnwrap(imageResponse.images.first)
+
+    let textResponse = try await model.generateContent(
+      InlineDataPart(data: image.data, mimeType: "image/png"),
+      "What is the name of this animal? Answer with the animal name only."
+    )
+
+    let text = try XCTUnwrap(textResponse.text).trimmingCharacters(in: .whitespacesAndNewlines)
+    XCTAssertEqual(text, "Lion")
+  }
 }
 
 extension StorageReference {
diff --git a/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift b/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift
index e6bfe7cf09b..23f85e8bdbd 100644
--- a/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift
@@ -73,7 +73,7 @@ final class GenerationConfigTests: XCTestCase {
       "frequencyPenalty" : \(frequencyPenalty),
       "maxOutputTokens" : \(maxOutputTokens),
       "presencePenalty" : \(presencePenalty),
-      "responseMIMEType" : "\(responseMIMEType)",
+      "responseMimeType" : "\(responseMIMEType)",
       "responseSchema" : {
         "items" : {
           "nullable" : false,
@@ -109,7 +109,7 @@ final class GenerationConfigTests: XCTestCase {
     let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
     XCTAssertEqual(json, """
     {
-      "responseMIMEType" : "\(mimeType)",
+      "responseMimeType" : "\(mimeType)",
       "responseSchema" : {
         "nullable" : false,
         "properties" : {
diff --git a/FirebaseVertexAI/Tests/Unit/PartTests.swift b/FirebaseVertexAI/Tests/Unit/PartTests.swift
index d48600e9013..aea3c1b5d92 100644
--- a/FirebaseVertexAI/Tests/Unit/PartTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/PartTests.swift
@@ -132,7 +132,7 @@ final class PartTests: XCTestCase {
     XCTAssertEqual(json, """
     {
       "fileData" : {
-        "fileURI" : "\(fileURI)",
+        "fileUri" : "\(fileURI)",
         "mimeType" : "\(mimeType)"
       }
     }