
Commit e1e9796

[Vertex AI] Add HarmSeverity enum and SafetyRating properties

1 parent 3eaa04d
5 files changed: +243 −37 lines

FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift

Lines changed: 64 additions & 8 deletions
@@ -168,10 +168,38 @@ struct ErrorDetailsView: View {
           Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
           """),
         safetyRatings: [
-          SafetyRating(category: .dangerousContent, probability: .high),
-          SafetyRating(category: .harassment, probability: .low),
-          SafetyRating(category: .hateSpeech, probability: .low),
-          SafetyRating(category: .sexuallyExplicit, probability: .low),
+          SafetyRating(
+            category: .dangerousContent,
+            probability: .medium,
+            probabilityScore: 0.8,
+            severity: .medium,
+            severityScore: 0.9,
+            blocked: false
+          ),
+          SafetyRating(
+            category: .harassment,
+            probability: .low,
+            probabilityScore: 0.5,
+            severity: .low,
+            severityScore: 0.6,
+            blocked: false
+          ),
+          SafetyRating(
+            category: .hateSpeech,
+            probability: .low,
+            probabilityScore: 0.3,
+            severity: .medium,
+            severityScore: 0.2,
+            blocked: false
+          ),
+          SafetyRating(
+            category: .sexuallyExplicit,
+            probability: .low,
+            probabilityScore: 0.2,
+            severity: .negligible,
+            severityScore: 0.5,
+            blocked: false
+          ),
         ],
         finishReason: FinishReason.maxTokens,
         citationMetadata: nil),
@@ -190,10 +218,38 @@ struct ErrorDetailsView: View {
           Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
           """),
         safetyRatings: [
-          SafetyRating(category: .dangerousContent, probability: .high),
-          SafetyRating(category: .harassment, probability: .low),
-          SafetyRating(category: .hateSpeech, probability: .low),
-          SafetyRating(category: .sexuallyExplicit, probability: .low),
+          SafetyRating(
+            category: .dangerousContent,
+            probability: .low,
+            probabilityScore: 0.8,
+            severity: .medium,
+            severityScore: 0.9,
+            blocked: false
+          ),
+          SafetyRating(
+            category: .harassment,
+            probability: .low,
+            probabilityScore: 0.5,
+            severity: .low,
+            severityScore: 0.6,
+            blocked: false
+          ),
+          SafetyRating(
+            category: .hateSpeech,
+            probability: .low,
+            probabilityScore: 0.3,
+            severity: .medium,
+            severityScore: 0.2,
+            blocked: false
+          ),
+          SafetyRating(
+            category: .sexuallyExplicit,
+            probability: .low,
+            probabilityScore: 0.2,
+            severity: .negligible,
+            severityScore: 0.5,
+            blocked: false
+          ),
         ],
         finishReason: FinishReason.other,
         citationMetadata: nil),

FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift

Lines changed: 48 additions & 16 deletions
@@ -36,22 +36,54 @@ struct ErrorView: View {
 #Preview {
   NavigationView {
     let errorPromptBlocked = GenerateContentError.promptBlocked(
-      response: GenerateContentResponse(candidates: [
-        CandidateResponse(content: ModelContent(role: "model", parts: [
-          """
-          A _hypothetical_ model response.
-          Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
-          """,
-        ]),
-        safetyRatings: [
-          SafetyRating(category: .dangerousContent, probability: .high),
-          SafetyRating(category: .harassment, probability: .low),
-          SafetyRating(category: .hateSpeech, probability: .low),
-          SafetyRating(category: .sexuallyExplicit, probability: .low),
-        ],
-        finishReason: FinishReason.other,
-        citationMetadata: nil),
-      ])
+      response: GenerateContentResponse(
+        candidates: [
+          CandidateResponse(
+            content: ModelContent(role: "model", parts: [
+              """
+              A _hypothetical_ model response.
+              Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
+              """,
+            ]),
+            safetyRatings: [
+              SafetyRating(
+                category: .dangerousContent,
+                probability: .high,
+                probabilityScore: 0.8,
+                severity: .medium,
+                severityScore: 0.9,
+                blocked: true
+              ),
+              SafetyRating(
+                category: .harassment,
+                probability: .low,
+                probabilityScore: 0.5,
+                severity: .low,
+                severityScore: 0.6,
+                blocked: false
+              ),
+              SafetyRating(
+                category: .hateSpeech,
+                probability: .low,
+                probabilityScore: 0.3,
+                severity: .medium,
+                severityScore: 0.2,
+                blocked: false
+              ),
+              SafetyRating(
+                category: .sexuallyExplicit,
+                probability: .low,
+                probabilityScore: 0.2,
+                severity: .negligible,
+                severityScore: 0.5,
+                blocked: false
+              ),
+            ],
+            finishReason: FinishReason.other,
+            citationMetadata: nil
+          ),
+        ]
+      )
     )
     List {
       MessageView(message: ChatMessage.samples[0])

FirebaseVertexAI/Sources/Safety.swift

Lines changed: 69 additions & 2 deletions
@@ -31,11 +31,28 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
   /// > Important: This does not indicate the severity of harm for a piece of content.
   public let probability: HarmProbability
 
+  public let probabilityScore: Float
+
+  public let severity: HarmSeverity
+
+  public let severityScore: Float
+
+  public let blocked: Bool
+
   /// Initializes a new `SafetyRating` instance with the given category and probability.
   /// Use this initializer for SwiftUI previews or tests.
-  public init(category: HarmCategory, probability: HarmProbability) {
+  public init(category: HarmCategory,
+              probability: HarmProbability,
+              probabilityScore: Float,
+              severity: HarmSeverity,
+              severityScore: Float,
+              blocked: Bool) {
     self.category = category
     self.probability = probability
+    self.probabilityScore = probabilityScore
+    self.severity = severity
+    self.severityScore = severityScore
+    self.blocked = blocked
   }
 
   /// The probability that a given model output falls under a harmful content category.
@@ -74,6 +91,32 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
     static let unrecognizedValueMessageCode =
       VertexLog.MessageCode.generateContentResponseUnrecognizedHarmProbability
   }
+
+  public struct HarmSeverity: DecodableProtoEnum, Hashable, Sendable {
+    enum Kind: String {
+      case negligible = "HARM_SEVERITY_NEGLIGIBLE"
+      case low = "HARM_SEVERITY_LOW"
+      case medium = "HARM_SEVERITY_MEDIUM"
+      case high = "HARM_SEVERITY_HIGH"
+    }
+
+    public static let negligible = HarmSeverity(kind: .negligible)
+
+    public static let low = HarmSeverity(kind: .low)
+
+    public static let medium = HarmSeverity(kind: .medium)
+
+    public static let high = HarmSeverity(kind: .high)
+
+    /// Returns the raw string representation of the `HarmSeverity` value.
+    ///
+    /// > Note: This value directly corresponds to the values in the [REST
+    /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#HarmSeverity).
+    public let rawValue: String
+
+    static let unrecognizedValueMessageCode =
+      VertexLog.MessageCode.generateContentResponseUnrecognizedHarmSeverity
+  }
 }
 
 /// A type used to specify a threshold for harmful content, beyond which the model will return a
@@ -164,7 +207,31 @@ public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
 // MARK: - Codable Conformances
 
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-extension SafetyRating: Decodable {}
+extension SafetyRating: Decodable {
+  enum CodingKeys: CodingKey {
+    case category
+    case probability
+    case probabilityScore
+    case severity
+    case severityScore
+    case blocked
+  }
+
+  public init(from decoder: any Decoder) throws {
+    let container = try decoder.container(keyedBy: CodingKeys.self)
+    category = try container.decode(HarmCategory.self, forKey: .category)
+    probability = try container.decode(HarmProbability.self, forKey: .probability)
+
+    // The following 3 fields are only omitted in our test data.
+    probabilityScore = try container.decodeIfPresent(Float.self, forKey: .probabilityScore) ?? 0.0
+    severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ??
+      HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED")
+    severityScore = try container.decodeIfPresent(Float.self, forKey: .severityScore) ?? 0.0
+
+    // The blocked field is only included when true.
+    blocked = try container.decodeIfPresent(Bool.self, forKey: .blocked) ?? false
+  }
+}
 
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension SafetySetting.HarmBlockThreshold: Encodable {}
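
The custom `init(from:)` above is what keeps partial payloads decodable: the score, severity, and blocked fields fall back to defaults when absent. A minimal sketch of that behavior (not part of this commit), assuming `FirebaseVertexAI` is imported and using the REST enum strings referenced in the diff:

```swift
import FirebaseVertexAI
import Foundation

// Hypothetical payload: only `category` and `probability` are present, mirroring
// the SDK's test data, so the remaining fields take the fallbacks from init(from:).
let json = Data("""
{
  "category": "HARM_CATEGORY_HARASSMENT",
  "probability": "NEGLIGIBLE"
}
""".utf8)

let rating = try JSONDecoder().decode(SafetyRating.self, from: json)

// Expected, per the decoder above:
//   rating.probabilityScore == 0.0
//   rating.severity.rawValue == "HARM_SEVERITY_UNSPECIFIED"
//   rating.severityScore == 0.0
//   rating.blocked == false
```

Note that the `blocked` fallback also applies to ordinary responses, since the backend only includes the field when it is true.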

FirebaseVertexAI/Sources/VertexLog.swift

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ enum VertexLog {
     case generateContentResponseUnrecognizedBlockThreshold = 3004
     case generateContentResponseUnrecognizedHarmProbability = 3005
     case generateContentResponseUnrecognizedHarmCategory = 3006
+    case generateContentResponseUnrecognizedHarmSeverity = 3007
 
     // SDK State Errors
     case generateContentResponseNoCandidates = 4000

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 61 additions & 11 deletions
@@ -23,10 +23,38 @@ import XCTest
 final class GenerativeModelTests: XCTestCase {
   let testPrompt = "What sorts of questions can I ask you?"
   let safetyRatingsNegligible: [SafetyRating] = [
-    .init(category: .sexuallyExplicit, probability: .negligible),
-    .init(category: .hateSpeech, probability: .negligible),
-    .init(category: .harassment, probability: .negligible),
-    .init(category: .dangerousContent, probability: .negligible),
+    .init(
+      category: .sexuallyExplicit,
+      probability: .negligible,
+      probabilityScore: 0.1431877,
+      severity: .negligible,
+      severityScore: 0.11027937,
+      blocked: false
+    ),
+    .init(
+      category: .hateSpeech,
+      probability: .negligible,
+      probabilityScore: 0.029035643,
+      severity: .negligible,
+      severityScore: 0.05613278,
+      blocked: false
+    ),
+    .init(
+      category: .harassment,
+      probability: .negligible,
+      probabilityScore: 0.087252244,
+      severity: .negligible,
+      severityScore: 0.04509957,
+      blocked: false
+    ),
+    .init(
+      category: .dangerousContent,
+      probability: .negligible,
+      probabilityScore: 0.2641685,
+      severity: .negligible,
+      severityScore: 0.082253955,
+      blocked: false
+    ),
   ].sorted()
   let testModelResourceName =
     "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
@@ -69,7 +97,7 @@ final class GenerativeModelTests: XCTestCase {
     let candidate = try XCTUnwrap(response.candidates.first)
     let finishReason = try XCTUnwrap(candidate.finishReason)
     XCTAssertEqual(finishReason, .stop)
-    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
+    XCTAssertEqual(candidate.safetyRatings.count, 4)
     XCTAssertEqual(candidate.content.parts.count, 1)
     let part = try XCTUnwrap(candidate.content.parts.first)
     let partText = try XCTUnwrap(part as? TextPart).text
@@ -148,25 +176,43 @@ final class GenerativeModelTests: XCTestCase {
     let candidate = try XCTUnwrap(response.candidates.first)
     let finishReason = try XCTUnwrap(candidate.finishReason)
     XCTAssertEqual(finishReason, .stop)
-    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
+    XCTAssertEqual(candidate.safetyRatings.count, 4)
     XCTAssertEqual(candidate.content.parts.count, 1)
     let part = try XCTUnwrap(candidate.content.parts.first)
     let textPart = try XCTUnwrap(part as? TextPart)
     XCTAssertTrue(textPart.text.hasPrefix("Google"))
     XCTAssertEqual(response.text, textPart.text)
     let promptFeedback = try XCTUnwrap(response.promptFeedback)
     XCTAssertNil(promptFeedback.blockReason)
-    XCTAssertEqual(promptFeedback.safetyRatings.sorted(), safetyRatingsNegligible)
+    XCTAssertEqual(promptFeedback.safetyRatings.count, 4)
   }
 
   func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
     let expectedSafetyRatings = [
-      SafetyRating(category: .harassment, probability: .medium),
+      SafetyRating(
+        category: .harassment,
+        probability: .medium,
+        probabilityScore: 0.0,
+        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+        severityScore: 0.0,
+        blocked: false
+      ),
       SafetyRating(
         category: .dangerousContent,
-        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
+        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
+        probabilityScore: 0.0,
+        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+        severityScore: 0.0,
+        blocked: false
+      ),
+      SafetyRating(
+        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
+        probability: .high,
+        probabilityScore: 0.0,
+        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+        severityScore: 0.0,
+        blocked: false
       ),
-      SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
     ]
     MockURLProtocol
       .requestHandler = try httpRequestHandler(
@@ -930,7 +976,11 @@ final class GenerativeModelTests: XCTestCase {
     )
     let unknownSafetyRating = SafetyRating(
      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
-      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
+      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
+      probabilityScore: 0.0,
+      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+      severityScore: 0.0,
+      blocked: false
     )
 
     var foundUnknownSafetyRating = false
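
For context on how the expanded SafetyRating surfaces to callers, here is a rough usage sketch. It is not part of this commit; the entry point, model name, and prompt are assumptions based on the SDK's existing GenerativeModel API, and it presumes FirebaseApp.configure() has already run.

```swift
import FirebaseVertexAI

// Sketch only: "gemini-1.5-flash" and the prompt are placeholders.
func logSafetyRatings() async throws {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
  let response = try await model.generateContent("Tell me about safety ratings.")

  guard let candidate = response.candidates.first else { return }
  for rating in candidate.safetyRatings {
    // New in this commit: calibrated scores, a HarmSeverity value, and a blocked flag.
    print(
      "\(rating.category): probability=\(rating.probability)",
      "(score: \(rating.probabilityScore)),",
      "severity=\(rating.severity.rawValue) (score: \(rating.severityScore)),",
      "blocked=\(rating.blocked)"
    )
  }
}
```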
