Skip to content

Commit 1229559

Browse files
committed
[Vertex AI] Add ImagenSafetySettings
1 parent c5472fc commit 1229559

File tree

4 files changed

+80
-6
lines changed

4 files changed

+80
-6
lines changed

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@ public final class ImagenModel {
2424
/// The backing service responsible for sending and receiving model requests to the backend.
2525
let generativeAIService: GenerativeAIService
2626

27+
let safetySettings: ImagenSafetySettings?
28+
2729
/// Configuration parameters for sending requests to the backend.
2830
let requestOptions: RequestOptions
2931

3032
init(name: String,
3133
projectID: String,
3234
apiKey: String,
35+
safetySettings: ImagenSafetySettings?,
3336
requestOptions: RequestOptions,
3437
appCheck: AppCheckInterop?,
3538
auth: AuthInterop?,
@@ -42,6 +45,7 @@ public final class ImagenModel {
4245
auth: auth,
4346
urlSession: urlSession
4447
)
48+
self.safetySettings = safetySettings
4549
self.requestOptions = requestOptions
4650
}
4751

@@ -89,16 +93,16 @@ public final class ImagenModel {
8993
seed: nil,
9094
negativePrompt: generationConfig?.negativePrompt,
9195
aspectRatio: generationConfig?.aspectRatio?.rawValue,
92-
safetyFilterLevel: nil,
93-
personGeneration: nil,
96+
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
97+
personGeneration: safetySettings?.personGeneration?.rawValue,
9498
outputOptions: generationConfig?.imageFormat.map {
9599
ImageGenerationOutputOptions(
96100
mimeType: $0.mimeType,
97101
compressionQuality: $0.compressionQuality
98102
)
99103
},
100104
addWatermark: generationConfig?.addWatermark,
101-
includeResponsibleAIFilterReason: true
105+
includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true
102106
)
103107
}
104108
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
18+
public struct ImagenSafetySettings {
19+
let safetyFilterLevel: SafetyFilterLevel?
20+
let includeFilterReason: Bool?
21+
let personGeneration: PersonGeneration?
22+
23+
public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil,
24+
personGeneration: PersonGeneration? = nil) {
25+
self.safetyFilterLevel = safetyFilterLevel
26+
self.includeFilterReason = includeFilterReason
27+
self.personGeneration = personGeneration
28+
}
29+
}
30+
31+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
32+
public extension ImagenSafetySettings {
33+
struct SafetyFilterLevel: ProtoEnum {
34+
enum Kind: String {
35+
case blockLowAndAbove = "block_low_and_above"
36+
case blockMediumAndAbove = "block_medium_and_above"
37+
case blockOnlyHigh = "block_only_high"
38+
case blockNone = "block_none"
39+
}
40+
41+
public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove)
42+
public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove)
43+
public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh)
44+
public static let blockNone = SafetyFilterLevel(kind: .blockNone)
45+
46+
let rawValue: String
47+
}
48+
}
49+
50+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
51+
public extension ImagenSafetySettings {
52+
struct PersonGeneration: ProtoEnum {
53+
enum Kind: String {
54+
case blockAll = "dont_allow"
55+
case allowAdult = "allow_adult"
56+
case allowAll = "allow_all"
57+
}
58+
59+
public static let blockAll = PersonGeneration(kind: .blockAll)
60+
public static let allowAdult = PersonGeneration(kind: .allowAdult)
61+
public static let allowAll = PersonGeneration(kind: .allowAll)
62+
63+
let rawValue: String
64+
}
65+
}

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ public class VertexAI {
104104
)
105105
}
106106

107-
public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
108-
-> ImagenModel {
107+
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
108+
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
109109
return ImagenModel(
110110
name: modelResourceName(modelName: modelName),
111111
projectID: projectID,
112112
apiKey: apiKey,
113+
safetySettings: safetySettings,
113114
requestOptions: requestOptions,
114115
appCheck: appCheck,
115116
auth: auth

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ final class IntegrationTests: XCTestCase {
6262
systemInstruction: systemInstruction
6363
)
6464
imagenModel = vertex.imagenModel(
65-
modelName: "imagen-3.0-fast-generate-001"
65+
modelName: "imagen-3.0-fast-generate-001",
66+
safetySettings: ImagenSafetySettings(
67+
safetyFilterLevel: .blockLowAndAbove,
68+
personGeneration: .blockAll
69+
)
6670
)
6771

6872
storage = Storage.storage()

0 commit comments

Comments
 (0)