Commit 1e63a55

[Vertex AI] Use struct instead of enum for HarmProbability (#13854)
1 parent 7107086 commit 1e63a55

4 files changed: +67 −31 lines changed

FirebaseVertexAI/CHANGELOG.md

Lines changed: 4 additions & 3 deletions
@@ -36,9 +36,10 @@
   as input. (#13767)
 - [changed] **Breaking Change**: All initializers for `ModelContent` now require
   the label `parts: `. (#13832)
-- [changed] **Breaking Change**: `HarmCategory` is now a struct instead of an
-  enum type and the `unknown` case has been removed; in a `switch` statement,
-  use the `default:` case to cover unknown or unhandled categories. (#13728)
+- [changed] **Breaking Change**: `HarmCategory` and `HarmProbability` are now
+  structs instead of enum types and the `unknown` cases have been removed; in a
+  `switch` statement, use the `default:` case to cover unknown or unhandled
+  categories or probabilities. (#13728, #13854)
 - [changed] The default request timeout is now 180 seconds instead of the
   platform-default value of 60 seconds for a `URLRequest`; this timeout may
   still be customized in `RequestOptions`. (#13722)
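
As a migration sketch (not part of the commit): code that previously matched the removed `unknown` case now needs a `default:` clause, because a struct cannot be matched exhaustively. The `describe(_:)` helper below is hypothetical, not SDK API.

```swift
import FirebaseVertexAI

// Hypothetical helper showing the migration described above: the old
// `case .unknown:` arm becomes `default:`, which also covers any raw
// values the backend adds after this SDK release.
func describe(_ probability: SafetyRating.HarmProbability) -> String {
  switch probability {
  case .negligible, .low:
    return "Likely benign"
  case .medium, .high:
    return "Potentially harmful"
  default:
    return "Unrecognized probability: \(probability.rawValue)"
  }
}
```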

FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift

Lines changed: 2 additions & 3 deletions
@@ -25,8 +25,7 @@ private extension HarmCategory {
     case .hateSpeech: "Hate speech"
     case .sexuallyExplicit: "Sexually explicit"
     case .civicIntegrity: "Civic integrity"
-    default:
-      "Unknown HarmCategory: \(rawValue)"
+    default: "Unknown HarmCategory: \(rawValue)"
     }
   }
 }

@@ -39,7 +38,7 @@ private extension SafetyRating.HarmProbability {
     case .low: "Low"
     case .medium: "Medium"
     case .negligible: "Negligible"
-    case .unknown: "Unknown"
+    default: "Unknown HarmProbability: \(rawValue)"
     }
   }
 }
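
Note that `case` patterns like `.negligible` still compile against the new struct types: a `switch` expression pattern matches through the `~=` operator, which falls back to `==` for `Equatable` types, so each named constant compares by its underlying `rawValue`. A minimal sketch:

```swift
import FirebaseVertexAI

let probability = SafetyRating.HarmProbability.high

switch probability {
case .high:
  // Matched via `==` on the synthesized Equatable conformance.
  print("High")
default:
  // Reached for any raw value without a named constant in this SDK version.
  print("Unknown HarmProbability: \(probability.rawValue)")
}
```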

FirebaseVertexAI/Sources/Safety.swift

Lines changed: 56 additions & 23 deletions
@@ -38,24 +38,66 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
     self.probability = probability
   }

-  /// The probability that a given model output falls under a harmful content category. This does
-  /// not indicate the severity of harm for a piece of content.
-  public enum HarmProbability: String, Sendable {
-    /// Unknown. A new server value that isn't recognized by the SDK.
-    case unknown = "UNKNOWN"
+  /// The probability that a given model output falls under a harmful content category.
+  ///
+  /// > Note: This does not indicate the severity of harm for a piece of content.
+  public struct HarmProbability: Sendable, Equatable, Hashable {
+    enum Kind: String {
+      case negligible = "NEGLIGIBLE"
+      case low = "LOW"
+      case medium = "MEDIUM"
+      case high = "HIGH"
+    }

-    /// The probability is zero or close to zero. For benign content, the probability across all
-    /// categories will be this value.
-    case negligible = "NEGLIGIBLE"
+    /// The probability is zero or close to zero.
+    ///
+    /// For benign content, the probability across all categories will be this value.
+    public static var negligible: HarmProbability {
+      return self.init(kind: .negligible)
+    }

     /// The probability is small but non-zero.
-    case low = "LOW"
+    public static var low: HarmProbability {
+      return self.init(kind: .low)
+    }

     /// The probability is moderate.
-    case medium = "MEDIUM"
+    public static var medium: HarmProbability {
+      return self.init(kind: .medium)
+    }
+
+    /// The probability is high.
+    ///
+    /// The content described is very likely harmful.
+    public static var high: HarmProbability {
+      return self.init(kind: .high)
+    }
+
+    /// Returns the raw string representation of the `HarmProbability` value.
+    ///
+    /// > Note: This value directly corresponds to the values in the [REST
+    /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#SafetyRating).
+    public let rawValue: String

-    /// The probability is high. The content described is very likely harmful.
-    case high = "HIGH"
+    init(kind: Kind) {
+      rawValue = kind.rawValue
+    }
+
+    init(rawValue: String) {
+      if Kind(rawValue: rawValue) == nil {
+        VertexLog.error(
+          code: .generateContentResponseUnrecognizedHarmProbability,
+          """
+          Unrecognized HarmProbability with value "\(rawValue)":
+          - Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
+          release notes at https://firebase.google.com/support/release-notes/ios
+          - Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
+          https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
+          """
+        )
+      }
+      self.rawValue = rawValue
+    }
   }
 }

@@ -163,17 +205,8 @@ public struct HarmCategory: Sendable, Equatable, Hashable {
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension SafetyRating.HarmProbability: Decodable {
   public init(from decoder: Decoder) throws {
-    let value = try decoder.singleValueContainer().decode(String.self)
-    guard let decodedProbability = SafetyRating.HarmProbability(rawValue: value) else {
-      VertexLog.error(
-        code: .generateContentResponseUnrecognizedHarmProbability,
-        "Unrecognized HarmProbability with value \"\(value)\"."
-      )
-      self = .unknown
-      return
-    }
-
-    self = decodedProbability
+    let rawValue = try decoder.singleValueContainer().decode(String.self)
+    self = SafetyRating.HarmProbability(rawValue: rawValue)
   }
 }
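
The net effect of routing decoding through `init(rawValue:)`: an unrecognized server value is logged but preserved, instead of collapsing to the removed `.unknown` case. A rough sketch of the observable behavior, assuming a hypothetical future value "EXTREMELY_HIGH" (the array wrapper just gives `JSONDecoder` a top-level container):

```swift
import FirebaseVertexAI
import Foundation

let json = Data(#"["NEGLIGIBLE", "EXTREMELY_HIGH"]"#.utf8)
let probabilities = try JSONDecoder().decode(
  [SafetyRating.HarmProbability].self,
  from: json
)

print(probabilities[0] == .negligible) // true: Equatable compares raw values
print(probabilities[1].rawValue)       // "EXTREMELY_HIGH" (an error is logged, value kept)
```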

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 5 additions & 2 deletions
@@ -162,7 +162,10 @@ final class GenerativeModelTests: XCTestCase {
   func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
     let expectedSafetyRatings = [
       SafetyRating(category: .harassment, probability: .medium),
-      SafetyRating(category: .dangerousContent, probability: .unknown),
+      SafetyRating(
+        category: .dangerousContent,
+        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
+      ),
       SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
     ]
     MockURLProtocol

@@ -974,7 +977,7 @@ final class GenerativeModelTests: XCTestCase {
     )
     let unknownSafetyRating = SafetyRating(
       category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
-      probability: .unknown
+      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
     )

     var foundUnknownSafetyRating = false
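
A narrower test of the same fallback could exercise the decoder directly instead of a full mocked response. The test below is a hypothetical sketch, not part of this commit:

```swift
import Foundation
import XCTest
@testable import FirebaseVertexAI

final class HarmProbabilityDecodingTests: XCTestCase {
  func testDecode_unrecognizedValue_preservesRawValue() throws {
    // Mirrors the "FAKE_NEW_HARM_PROBABILITY" style used in the suite above.
    let json = Data(#"["FAKE_NEW_HARM_PROBABILITY"]"#.utf8)

    let decoded = try JSONDecoder().decode([SafetyRating.HarmProbability].self, from: json)

    XCTAssertEqual(decoded.count, 1)
    XCTAssertEqual(decoded.first?.rawValue, "FAKE_NEW_HARM_PROBABILITY")
  }
}
```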
