Skip to content

Commit 0d50ac3

Browse files
committed
[Vertex AI] Add ImagenSafetySettings type and param (#14237)
1 parent 2cd3110 commit 0d50ac3

File tree

7 files changed

+260
-22
lines changed

7 files changed

+260
-22
lines changed

FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
struct ImageGenerationParameters {
1717
let sampleCount: Int?
1818
let storageURI: String?
19-
let seed: Int32?
2019
let negativePrompt: String?
2120
let aspectRatio: String?
2221
let safetyFilterLevel: String?
@@ -36,7 +35,6 @@ extension ImageGenerationParameters: Encodable {
3635
enum CodingKeys: String, CodingKey {
3736
case sampleCount
3837
case storageURI = "storageUri"
39-
case seed
4038
case negativePrompt
4139
case aspectRatio
4240
case safetyFilterLevel = "safetySetting"
@@ -50,7 +48,6 @@ extension ImageGenerationParameters: Encodable {
5048
var container = encoder.container(keyedBy: CodingKeys.self)
5149
try container.encodeIfPresent(sampleCount, forKey: .sampleCount)
5250
try container.encodeIfPresent(storageURI, forKey: .storageURI)
53-
try container.encodeIfPresent(seed, forKey: .seed)
5451
try container.encodeIfPresent(negativePrompt, forKey: .negativePrompt)
5552
try container.encodeIfPresent(aspectRatio, forKey: .aspectRatio)
5653
try container.encodeIfPresent(safetyFilterLevel, forKey: .safetyFilterLevel)

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@ public final class ImagenModel {
2424
/// The backing service responsible for sending and receiving model requests to the backend.
2525
let generativeAIService: GenerativeAIService
2626

27+
let safetySettings: ImagenSafetySettings?
28+
2729
/// Configuration parameters for sending requests to the backend.
2830
let requestOptions: RequestOptions
2931

3032
init(name: String,
3133
projectID: String,
3234
apiKey: String,
35+
safetySettings: ImagenSafetySettings?,
3336
requestOptions: RequestOptions,
3437
appCheck: AppCheckInterop?,
3538
auth: AuthInterop?,
@@ -42,6 +45,7 @@ public final class ImagenModel {
4245
auth: auth,
4346
urlSession: urlSession
4447
)
48+
self.safetySettings = safetySettings
4549
self.requestOptions = requestOptions
4650
}
4751

@@ -50,7 +54,11 @@ public final class ImagenModel {
5054
-> ImageGenerationResponse<ImagenInlineDataImage> {
5155
return try await generateImages(
5256
prompt: prompt,
53-
parameters: imageGenerationParameters(storageURI: nil, generationConfig: generationConfig)
57+
parameters: ImagenModel.imageGenerationParameters(
58+
storageURI: nil,
59+
generationConfig: generationConfig,
60+
safetySettings: safetySettings
61+
)
5462
)
5563
}
5664

@@ -59,9 +67,10 @@ public final class ImagenModel {
5967
-> ImageGenerationResponse<ImagenFileDataImage> {
6068
return try await generateImages(
6169
prompt: prompt,
62-
parameters: imageGenerationParameters(
70+
parameters: ImagenModel.imageGenerationParameters(
6371
storageURI: storageURI,
64-
generationConfig: generationConfig
72+
generationConfig: generationConfig,
73+
safetySettings: safetySettings
6574
)
6675
)
6776
}
@@ -79,26 +88,25 @@ public final class ImagenModel {
7988
return try await generativeAIService.loadRequest(request: request)
8089
}
8190

82-
func imageGenerationParameters(storageURI: String?,
83-
generationConfig: ImagenGenerationConfig? = nil)
91+
static func imageGenerationParameters(storageURI: String?,
92+
generationConfig: ImagenGenerationConfig?,
93+
safetySettings: ImagenSafetySettings?)
8494
-> ImageGenerationParameters {
85-
// TODO(#14221): Add support for configuring remaining parameters.
8695
return ImageGenerationParameters(
8796
sampleCount: generationConfig?.numberOfImages ?? 1,
8897
storageURI: storageURI,
89-
seed: nil,
9098
negativePrompt: generationConfig?.negativePrompt,
9199
aspectRatio: generationConfig?.aspectRatio?.rawValue,
92-
safetyFilterLevel: nil,
93-
personGeneration: nil,
100+
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
101+
personGeneration: safetySettings?.personGeneration?.rawValue,
94102
outputOptions: generationConfig?.imageFormat.map {
95103
ImageGenerationOutputOptions(
96104
mimeType: $0.mimeType,
97105
compressionQuality: $0.compressionQuality
98106
)
99107
},
100108
addWatermark: generationConfig?.addWatermark,
101-
includeResponsibleAIFilterReason: true
109+
includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true
102110
)
103111
}
104112
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
18+
public struct ImagenSafetySettings {
19+
let safetyFilterLevel: SafetyFilterLevel?
20+
let includeFilterReason: Bool?
21+
let personGeneration: PersonGeneration?
22+
23+
public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil,
24+
personGeneration: PersonGeneration? = nil) {
25+
self.safetyFilterLevel = safetyFilterLevel
26+
self.includeFilterReason = includeFilterReason
27+
self.personGeneration = personGeneration
28+
}
29+
}
30+
31+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
32+
public extension ImagenSafetySettings {
33+
struct SafetyFilterLevel: ProtoEnum {
34+
enum Kind: String {
35+
case blockLowAndAbove = "block_low_and_above"
36+
case blockMediumAndAbove = "block_medium_and_above"
37+
case blockOnlyHigh = "block_only_high"
38+
case blockNone = "block_none"
39+
}
40+
41+
public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove)
42+
public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove)
43+
public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh)
44+
public static let blockNone = SafetyFilterLevel(kind: .blockNone)
45+
46+
let rawValue: String
47+
}
48+
}
49+
50+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
51+
public extension ImagenSafetySettings {
52+
struct PersonGeneration: ProtoEnum {
53+
enum Kind: String {
54+
case blockAll = "dont_allow"
55+
case allowAdult = "allow_adult"
56+
case allowAll = "allow_all"
57+
}
58+
59+
public static let blockAll = PersonGeneration(kind: .blockAll)
60+
public static let allowAdult = PersonGeneration(kind: .allowAdult)
61+
public static let allowAll = PersonGeneration(kind: .allowAll)
62+
63+
let rawValue: String
64+
}
65+
}

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ public class VertexAI {
104104
)
105105
}
106106

107-
public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
108-
-> ImagenModel {
107+
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
108+
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
109109
return ImagenModel(
110110
name: modelResourceName(modelName: modelName),
111111
projectID: projectID,
112112
apiKey: apiKey,
113+
safetySettings: safetySettings,
113114
requestOptions: requestOptions,
114115
appCheck: appCheck,
115116
auth: auth

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ final class IntegrationTests: XCTestCase {
6262
systemInstruction: systemInstruction
6363
)
6464
imagenModel = vertex.imagenModel(
65-
modelName: "imagen-3.0-fast-generate-001"
65+
modelName: "imagen-3.0-fast-generate-001",
66+
safetySettings: ImagenSafetySettings(
67+
safetyFilterLevel: .blockLowAndAbove,
68+
personGeneration: .blockAll
69+
)
6670
)
6771

6872
storage = Storage.storage()

0 commit comments

Comments
 (0)