Skip to content

Commit fdfaea2

Browse files
committed
[Vertex AI] Add ImagenModel with generateImages functions (#14226)
1 parent d7e960c commit fdfaea2

File tree

9 files changed

+155
-6
lines changed

9 files changed

+155
-6
lines changed

FirebaseVertexAI/Sources/GenerationConfig.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,17 @@ public struct GenerationConfig {
162162
// MARK: - Codable Conformances
163163

164164
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
165-
extension GenerationConfig: Encodable {}
165+
extension GenerationConfig: Encodable {
166+
/// Keys used when encoding a `GenerationConfig`; a case without an explicit raw value is
/// encoded under its own name.
enum CodingKeys: String, CodingKey {
167+
case temperature
168+
case topP
169+
case topK
170+
case candidateCount
171+
case maxOutputTokens
172+
case presencePenalty
173+
case frequencyPenalty
174+
case stopSequences
175+
// Encoded as "responseMimeType" instead of the Swift property spelling `responseMIMEType`.
case responseMIMEType = "responseMimeType"
176+
case responseSchema
177+
}
178+
}

FirebaseVertexAI/Sources/GenerativeAIService.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ struct GenerativeAIService {
203203
}
204204

205205
let encoder = JSONEncoder()
206-
encoder.keyEncodingStrategy = .convertToSnakeCase
207206
urlRequest.httpBody = try encoder.encode(request)
208207
urlRequest.timeoutInterval = request.options.timeout
209208

FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ struct FileData: Codable, Equatable, Sendable {
3434
self.fileURI = fileURI
3535
self.mimeType = mimeType
3636
}
37+
38+
// Coding keys for this type: `fileURI` is coded under the key "fileUri", while `mimeType`
// keeps its own name.
enum CodingKeys: String, CodingKey {
39+
case fileURI = "fileUri"
40+
case mimeType
41+
}
3742
}
3843

3944
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)

FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift renamed to FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
3333
guard container.contains(.predictions) else {
3434
images = []
3535
raiFilteredReason = nil
36-
// TODO: Log warning if no predictions.
36+
// TODO(#14221): Log warning if no predictions.
3737
return
3838
}
3939
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import FirebaseAppCheckInterop
16+
import FirebaseAuthInterop
17+
import Foundation
18+
19+
/// A model object that builds image generation requests from text prompts and sends them to the
/// backend through a `GenerativeAIService`.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20+
public final class ImagenModel {
21+
/// The resource name of the model in the backend; has the format "models/model-name".
22+
let modelResourceName: String
23+
24+
/// The backing service responsible for sending and receiving model requests to the backend.
25+
let generativeAIService: GenerativeAIService
26+
27+
/// Configuration parameters for sending requests to the backend.
28+
let requestOptions: RequestOptions
29+
30+
/// Creates a model for the given resource name.
///
/// The `projectID`, `apiKey`, `appCheck`, `auth` and `urlSession` arguments are forwarded
/// unchanged to the `GenerativeAIService` initializer; `name` and `requestOptions` are stored
/// on the model.
init(name: String,
31+
projectID: String,
32+
apiKey: String,
33+
requestOptions: RequestOptions,
34+
appCheck: AppCheckInterop?,
35+
auth: AuthInterop?,
36+
urlSession: URLSession = .shared) {
37+
modelResourceName = name
38+
generativeAIService = GenerativeAIService(
39+
projectID: projectID,
40+
apiKey: apiKey,
41+
appCheck: appCheck,
42+
auth: auth,
43+
urlSession: urlSession
44+
)
45+
self.requestOptions = requestOptions
46+
}
47+
48+
/// Generates images for `prompt` and returns them as `ImagenInlineDataImage` values.
///
/// The request uses the parameters from `imageGenerationParameters(storageURI:)` with a `nil`
/// storage URI.
///
/// - Parameter prompt: The text prompt sent as the single request instance.
/// - Returns: The decoded `ImageGenerationResponse`.
/// - Throws: Any error thrown while loading the request through the backing service.
public func generateImages(prompt: String) async throws
49+
-> ImageGenerationResponse<ImagenInlineDataImage> {
50+
return try await generateImages(
51+
prompt: prompt,
52+
parameters: imageGenerationParameters(storageURI: nil)
53+
)
54+
}
55+
56+
/// Generates images for `prompt` and returns them as `ImagenFileDataImage` values.
///
/// - Parameters:
///   - prompt: The text prompt sent as the single request instance.
///   - storageURI: Passed through as the `storageURI` request parameter.
///     NOTE(review): presumably the storage location the backend writes images to — confirm
///     against `ImageGenerationParameters`.
/// - Returns: The decoded `ImageGenerationResponse`.
/// - Throws: Any error thrown while loading the request through the backing service.
public func generateImages(prompt: String, storageURI: String) async throws
57+
-> ImageGenerationResponse<ImagenFileDataImage> {
58+
return try await generateImages(
59+
prompt: prompt,
60+
parameters: imageGenerationParameters(storageURI: storageURI)
61+
)
62+
}
63+
64+
/// Shared implementation behind the public `generateImages` overloads: wraps `prompt` in one
/// `ImageGenerationInstance`, builds an `ImageGenerationRequest` for this model with the
/// stored request options, and loads it through `generativeAIService`.
func generateImages<T: Decodable>(prompt: String,
65+
parameters: ImageGenerationParameters) async throws
66+
-> ImageGenerationResponse<T> {
67+
let request = ImageGenerationRequest<T>(
68+
model: modelResourceName,
69+
options: requestOptions,
70+
instances: [ImageGenerationInstance(prompt: prompt)],
71+
parameters: parameters
72+
)
73+
74+
return try await generativeAIService.loadRequest(request: request)
75+
}
76+
77+
/// Returns the fixed parameter set currently used for every request: one sample, the given
/// `storageURI`, the Responsible AI filter reason included, and all other options `nil`.
func imageGenerationParameters(storageURI: String?) -> ImageGenerationParameters {
78+
// TODO(#14221): Add support for configuring these parameters.
79+
return ImageGenerationParameters(
80+
sampleCount: 1,
81+
storageURI: storageURI,
82+
seed: nil,
83+
negativePrompt: nil,
84+
aspectRatio: nil,
85+
safetyFilterLevel: nil,
86+
personGeneration: nil,
87+
outputOptions: nil,
88+
addWatermark: nil,
89+
includeResponsibleAIFilterReason: true
90+
)
91+
}
92+
}

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ public class VertexAI {
104104
)
105105
}
106106

107+
/// Creates an `ImagenModel` configured with this instance's project ID, API key, App Check and
/// Auth values.
///
/// - Parameters:
///   - modelName: The model name; converted with `modelResourceName(modelName:)` before being
///     handed to the model.
///   - requestOptions: Options stored on the returned model; defaults to `RequestOptions()`.
/// - Returns: A new `ImagenModel`.
public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
108+
-> ImagenModel {
109+
return ImagenModel(
110+
name: modelResourceName(modelName: modelName),
111+
projectID: projectID,
112+
apiKey: apiKey,
113+
requestOptions: requestOptions,
114+
appCheck: appCheck,
115+
auth: auth
116+
)
117+
}
118+
107119
/// Class to enable VertexAI to register via the Objective-C based Firebase component system
108120
/// to include VertexAI in the userAgent.
109121
@objc(FIRVertexAIComponent) class FirebaseVertexAIComponent: NSObject {}

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ final class IntegrationTests: XCTestCase {
4141

4242
var vertex: VertexAI!
4343
var model: GenerativeModel!
44+
var imagenModel: ImagenModel!
4445
var storage: Storage!
4546
var userID1 = ""
4647

@@ -60,6 +61,9 @@ final class IntegrationTests: XCTestCase {
6061
toolConfig: .init(functionCallingConfig: .none()),
6162
systemInstruction: systemInstruction
6263
)
64+
imagenModel = vertex.imagenModel(
65+
modelName: "imagen-3.0-fast-generate-001"
66+
)
6367

6468
storage = Storage.storage()
6569
}
@@ -235,6 +239,30 @@ final class IntegrationTests: XCTestCase {
235239
XCTAssertTrue(String(describing: error).contains("Firebase App Check token is invalid"))
236240
}
237241
}
242+
243+
// MARK: - Imagen
244+
245+
/// Round-trip check: generates an image from a lion prompt with `imagenModel`, then asks `model`
/// to name the animal in that image and expects the answer "Lion".
func testGenerateImage_inlineData() async throws {
246+
let imagePrompt = """
247+
A realistic photo of a male lion, mane thick and dark, standing proudly on a rocky outcrop
248+
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
249+
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
250+
"""
251+
252+
let imageResponse = try await imagenModel.generateImages(prompt: imagePrompt)
253+
254+
// Exactly one unfiltered image is expected.
XCTAssertNil(imageResponse.raiFilteredReason)
255+
XCTAssertEqual(imageResponse.images.count, 1)
256+
let image = try XCTUnwrap(imageResponse.images.first)
257+
258+
// Feed the generated image back into the text model.
// NOTE(review): the MIME type is hard-coded to "image/png" here — confirm it matches the type
// of the generated image.
let textResponse = try await model.generateContent(
259+
InlineDataPart(data: image.data, mimeType: "image/png"),
260+
"What is the name of this animal? Answer with the animal name only."
261+
)
262+
263+
// Surrounding whitespace and newlines are trimmed before the exact comparison.
let text = try XCTUnwrap(textResponse.text).trimmingCharacters(in: .whitespacesAndNewlines)
264+
XCTAssertEqual(text, "Lion")
265+
}
238266
}
239267

240268
extension StorageReference {

FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ final class GenerationConfigTests: XCTestCase {
7373
"frequencyPenalty" : \(frequencyPenalty),
7474
"maxOutputTokens" : \(maxOutputTokens),
7575
"presencePenalty" : \(presencePenalty),
76-
"responseMIMEType" : "\(responseMIMEType)",
76+
"responseMimeType" : "\(responseMIMEType)",
7777
"responseSchema" : {
7878
"items" : {
7979
"nullable" : false,
@@ -109,7 +109,7 @@ final class GenerationConfigTests: XCTestCase {
109109
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
110110
XCTAssertEqual(json, """
111111
{
112-
"responseMIMEType" : "\(mimeType)",
112+
"responseMimeType" : "\(mimeType)",
113113
"responseSchema" : {
114114
"nullable" : false,
115115
"properties" : {

FirebaseVertexAI/Tests/Unit/PartTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ final class PartTests: XCTestCase {
132132
XCTAssertEqual(json, """
133133
{
134134
"fileData" : {
135-
"fileURI" : "\(fileURI)",
135+
"fileUri" : "\(fileURI)",
136136
"mimeType" : "\(mimeType)"
137137
}
138138
}

0 commit comments

Comments
 (0)