Skip to content

Commit 2cd3110

Browse files
committed
[Vertex AI] Add ImagenGenerationConfig to generateImages() (#14234)
1 parent 36d608e commit 2cd3110

File tree

5 files changed

+137
-15
lines changed

5 files changed

+137
-15
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// An aspect ratio that can be requested for images generated by Imagen.
///
/// Instances are exposed as static constants. Each one wraps the string form of the ratio
/// (for example `"16:9"`) taken from the matching `Kind` case.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenAspectRatio {
  /// Square images; raw value `"1:1"`.
  public static let square1x1 = Self(kind: .square1x1)

  /// Portrait images; raw value `"9:16"`.
  public static let portrait9x16 = Self(kind: .portrait9x16)

  /// Landscape images; raw value `"16:9"`.
  public static let landscape16x9 = Self(kind: .landscape16x9)

  /// Portrait images; raw value `"3:4"`.
  public static let portrait3x4 = Self(kind: .portrait3x4)

  /// Landscape images; raw value `"4:3"`.
  public static let landscape4x3 = Self(kind: .landscape4x3)

  /// The string form of the ratio, e.g. `"4:3"`.
  let rawValue: String
}

// NOTE(review): `init(kind:)` is not declared in this file; presumably it is supplied by
// `ProtoEnum` and builds the value from `Kind.rawValue` -- confirm against that protocol.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenAspectRatio: ProtoEnum {
  /// The ratios this type exposes, keyed to their `"width:height"` string form.
  enum Kind: String {
    /// 1:1
    case square1x1 = "1:1"
    /// 9:16
    case portrait9x16 = "9:16"
    /// 16:9
    case landscape16x9 = "16:9"
    /// 3:4
    case portrait3x4 = "3:4"
    /// 4:3
    case landscape4x3 = "4:3"
  }
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/// Optional settings that customize an Imagen image generation request.
///
/// Every field is optional; a field left as `nil` is passed along as `nil` rather than
/// being replaced with a value here.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenGenerationConfig {
  /// How many images to request.
  public var numberOfImages: Int?

  /// Text for the request's negative prompt.
  /// NOTE(review): presumably describes what should be left out of the images -- confirm.
  public var negativePrompt: String?

  /// The aspect ratio requested for the generated images.
  public var aspectRatio: ImagenAspectRatio?

  /// The image format (MIME type and optional compression quality) requested for the output.
  public var imageFormat: ImagenImageFormat?

  /// Whether a watermark should be added to the generated images.
  public var addWatermark: Bool?

  /// Creates a configuration; any argument that is omitted defaults to `nil`.
  ///
  /// - Parameters:
  ///   - numberOfImages: How many images to request.
  ///   - negativePrompt: Text for the request's negative prompt.
  ///   - aspectRatio: The aspect ratio requested for the generated images.
  ///   - imageFormat: The image format requested for the output.
  ///   - addWatermark: Whether a watermark should be added.
  public init(
    numberOfImages: Int? = nil,
    negativePrompt: String? = nil,
    aspectRatio: ImagenAspectRatio? = nil,
    imageFormat: ImagenImageFormat? = nil,
    addWatermark: Bool? = nil
  ) {
    self.addWatermark = addWatermark
    self.imageFormat = imageFormat
    self.aspectRatio = aspectRatio
    self.negativePrompt = negativePrompt
    self.numberOfImages = numberOfImages
  }
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// An image format (MIME type plus optional compression quality) for Imagen output.
///
/// Create values with the `png()` and `jpeg(compressionQuality:)` factory methods.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenImageFormat {
  /// The MIME type of the requested format, e.g. `"image/png"`.
  let mimeType: String

  /// The requested compression quality, or `nil` when none was specified.
  let compressionQuality: Int?

  /// Returns the PNG format (`"image/png"`), which carries no compression quality.
  public static func png() -> ImagenImageFormat {
    return ImagenImageFormat(mimeType: "image/png", compressionQuality: nil)
  }

  /// Returns the JPEG format (`"image/jpeg"`) with an optional compression quality.
  ///
  /// - Parameter compressionQuality: The JPEG quality setting. The Vertex AI Imagen API
  ///   accepts values from 0 to 100, so values outside that range are clamped to the nearest
  ///   bound instead of being forwarded unchanged. Pass `nil` (the default) to leave the
  ///   quality unspecified.
  /// - Returns: A format value describing JPEG output.
  public static func jpeg(compressionQuality: Int? = nil) -> ImagenImageFormat {
    // Clamp rather than trap: an out-of-range quality is a recoverable caller mistake.
    let clampedQuality = compressionQuality.map { min(max($0, 0), 100) }
    return ImagenImageFormat(mimeType: "image/jpeg", compressionQuality: clampedQuality)
  }
}

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,24 @@ public final class ImagenModel {
4545
self.requestOptions = requestOptions
4646
}
4747

48-
/// Generates images from a text prompt and returns them as inline data.
///
/// No storage URI is supplied for this variant (`storageURI` is passed as `nil`).
///
/// - Parameters:
///   - prompt: The text prompt describing the images to generate.
///   - generationConfig: Optional settings for the request; defaults to `nil`.
/// - Returns: The response containing `ImagenInlineDataImage` values.
/// - Throws: Any error thrown by the underlying `generateImages(prompt:parameters:)` call.
public func generateImages(prompt: String,
                           generationConfig: ImagenGenerationConfig? = nil) async throws
  -> ImageGenerationResponse<ImagenInlineDataImage> {
  let inlineParameters = imageGenerationParameters(
    storageURI: nil,
    generationConfig: generationConfig
  )
  return try await generateImages(prompt: prompt, parameters: inlineParameters)
}
5556

56-
/// Generates images from a text prompt, passing a storage URI along with the request.
///
/// - Parameters:
///   - prompt: The text prompt describing the images to generate.
///   - storageURI: The storage location forwarded in the request parameters.
///     NOTE(review): presumably a Cloud Storage URI where the images are written -- confirm.
///   - generationConfig: Optional settings for the request; defaults to `nil`.
/// - Returns: The response containing `ImagenFileDataImage` values.
/// - Throws: Any error thrown by the underlying `generateImages(prompt:parameters:)` call.
public func generateImages(prompt: String, storageURI: String,
                           generationConfig: ImagenGenerationConfig? = nil) async throws
  -> ImageGenerationResponse<ImagenFileDataImage> {
  let storedParameters = imageGenerationParameters(
    storageURI: storageURI,
    generationConfig: generationConfig
  )
  return try await generateImages(prompt: prompt, parameters: storedParameters)
}
6368

@@ -74,18 +79,25 @@ public final class ImagenModel {
7479
return try await generativeAIService.loadRequest(request: request)
7580
}
7681

77-
/// Builds the request parameters for an image generation call.
///
/// Values come from `generationConfig` where it provides them. `sampleCount` falls back to 1
/// when no number of images is given; `seed`, `safetyFilterLevel` and `personGeneration` are
/// always `nil`; `includeResponsibleAIFilterReason` is always `true`.
///
/// - Parameters:
///   - storageURI: The storage URI to forward, or `nil` for none.
///   - generationConfig: Optional caller-supplied settings; defaults to `nil`.
/// - Returns: The assembled `ImageGenerationParameters`.
func imageGenerationParameters(storageURI: String?,
                               generationConfig: ImagenGenerationConfig? = nil)
  -> ImageGenerationParameters {
  // Request a single image unless the caller asked for a specific count.
  let sampleCount = generationConfig?.numberOfImages ?? 1

  // Output options are only built when the caller chose an image format.
  let outputOptions = generationConfig?.imageFormat.map { format in
    ImageGenerationOutputOptions(
      mimeType: format.mimeType,
      compressionQuality: format.compressionQuality
    )
  }

  // TODO(#14221): Add support for configuring remaining parameters.
  return ImageGenerationParameters(
    sampleCount: sampleCount,
    storageURI: storageURI,
    seed: nil,
    negativePrompt: generationConfig?.negativePrompt,
    aspectRatio: generationConfig?.aspectRatio?.rawValue,
    safetyFilterLevel: nil,
    personGeneration: nil,
    outputOptions: outputOptions,
    addWatermark: generationConfig?.addWatermark,
    includeResponsibleAIFilterReason: true
  )
}

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,22 +249,28 @@ final class IntegrationTests: XCTestCase {
249249
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
250250
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
251251
"""
252+
var generationConfig = ImagenGenerationConfig()
253+
generationConfig.aspectRatio = .landscape16x9
254+
generationConfig.imageFormat = .jpeg(compressionQuality: 70)
252255

253-
let response = try await imagenModel.generateImages(prompt: imagePrompt)
256+
let response = try await imagenModel.generateImages(
257+
prompt: imagePrompt,
258+
generationConfig: generationConfig
259+
)
254260

255261
XCTAssertNil(response.raiFilteredReason)
256262
XCTAssertEqual(response.images.count, 1)
257263
let image = try XCTUnwrap(response.images.first)
258-
XCTAssertEqual(image.mimeType, "image/png")
264+
XCTAssertEqual(image.mimeType, "image/jpeg")
259265
XCTAssertGreaterThan(image.data.count, 0)
260266
let imagenImage = image.imagenImage
261267
XCTAssertEqual(imagenImage.mimeType, image.mimeType)
262268
XCTAssertEqual(imagenImage.bytesBase64Encoded, image.data.base64EncodedString())
263269
XCTAssertNil(imagenImage.gcsURI)
264270
#if canImport(UIKit)
265271
let uiImage = try XCTUnwrap(UIImage(data: image.data))
266-
XCTAssertEqual(uiImage.size.width, 1024.0)
267-
XCTAssertEqual(uiImage.size.height, 1024.0)
272+
XCTAssertEqual(uiImage.size.width, 1408.0)
273+
XCTAssertEqual(uiImage.size.height, 768.0)
268274
#endif
269275
}
270276
}

0 commit comments

Comments
 (0)