
Commit 7bd6887

[Vertex AI] Use struct instead of enum for HarmCategory (#13728)
1 parent 289a1b9 commit 7bd6887

4 files changed: +68 -24 lines changed

FirebaseVertexAI/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@
   as input. (#13767)
 - [changed] **Breaking Change**: All initializers for `ModelContent` now require
   the label `parts: `. (#13832)
+- [changed] **Breaking Change**: `HarmCategory` is now a struct instead of an
+  enum type and the `unknown` case has been removed; in a `switch` statement,
+  use the `default:` case to cover unknown or unhandled categories. (#13728)
 - [changed] The default request timeout is now 180 seconds instead of the
   platform-default value of 60 seconds for a `URLRequest`; this timeout may
   still be customized in `RequestOptions`. (#13722)
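
For illustration only (not part of this commit): a minimal sketch of how a caller's previously exhaustive `switch` over `HarmCategory` might be migrated, assuming `import FirebaseVertexAI`; the `displayName(for:)` helper and its display strings are hypothetical.

import FirebaseVertexAI

// Hypothetical helper. Before this change, `case .unknown:` made the enum switch
// exhaustive; with the struct, a `default:` branch covers unrecognized categories.
func displayName(for category: HarmCategory) -> String {
  switch category {
  case .harassment: "Harassment"
  case .hateSpeech: "Hate speech"
  case .sexuallyExplicit: "Sexually explicit"
  case .dangerousContent: "Dangerous content"
  case .civicIntegrity: "Civic integrity"
  default: "Unknown HarmCategory: \(category.rawValue)"
  }
}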

FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift

Lines changed: 3 additions & 1 deletion
@@ -24,7 +24,9 @@ private extension HarmCategory
     case .harassment: "Harassment"
     case .hateSpeech: "Hate speech"
     case .sexuallyExplicit: "Sexually explicit"
-    case .unknown: "Unknown"
+    case .civicIntegrity: "Civic integrity"
+    default:
+      "Unknown HarmCategory: \(rawValue)"
     }
   }
 }

FirebaseVertexAI/Sources/Safety.swift

Lines changed: 53 additions & 18 deletions
@@ -97,21 +97,65 @@ public struct SafetySetting
 }

 /// Categories describing the potential harm a piece of content may pose.
-public enum HarmCategory: String, Sendable {
-  /// Unknown. A new server value that isn't recognized by the SDK.
-  case unknown = "HARM_CATEGORY_UNKNOWN"
+public struct HarmCategory: Sendable, Equatable, Hashable {
+  enum Kind: String {
+    case harassment = "HARM_CATEGORY_HARASSMENT"
+    case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
+    case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+    case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
+    case civicIntegrity = "HARM_CATEGORY_CIVIC_INTEGRITY"
+  }

   /// Harassment content.
-  case harassment = "HARM_CATEGORY_HARASSMENT"
+  public static var harassment: HarmCategory {
+    return self.init(kind: .harassment)
+  }

   /// Negative or harmful comments targeting identity and/or protected attributes.
-  case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
+  public static var hateSpeech: HarmCategory {
+    return self.init(kind: .hateSpeech)
+  }

   /// Contains references to sexual acts or other lewd content.
-  case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+  public static var sexuallyExplicit: HarmCategory {
+    return self.init(kind: .sexuallyExplicit)
+  }

   /// Promotes or enables access to harmful goods, services, or activities.
-  case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
+  public static var dangerousContent: HarmCategory {
+    return self.init(kind: .dangerousContent)
+  }
+
+  /// Content that may be used to harm civic integrity.
+  public static var civicIntegrity: HarmCategory {
+    return self.init(kind: .civicIntegrity)
+  }
+
+  /// Returns the raw string representation of the `HarmCategory` value.
+  ///
+  /// > Note: This value directly corresponds to the values in the
+  /// > [REST API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/HarmCategory).
+  public let rawValue: String
+
+  init(kind: Kind) {
+    rawValue = kind.rawValue
+  }
+
+  init(rawValue: String) {
+    if Kind(rawValue: rawValue) == nil {
+      VertexLog.error(
+        code: .generateContentResponseUnrecognizedHarmCategory,
+        """
+        Unrecognized HarmCategory with value "\(rawValue)":
+        - Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
+        release notes at https://firebase.google.com/support/release-notes/ios
+        - Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
+        https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
+        """
+      )
+    }
+    self.rawValue = rawValue
+  }
 }

 // MARK: - Codable Conformances
@@ -139,17 +183,8 @@ extension SafetyRating: Decodable {}
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension HarmCategory: Codable {
   public init(from decoder: Decoder) throws {
-    let value = try decoder.singleValueContainer().decode(String.self)
-    guard let decodedCategory = HarmCategory(rawValue: value) else {
-      VertexLog.error(
-        code: .generateContentResponseUnrecognizedHarmCategory,
-        "Unrecognized HarmCategory with value \"\(value)\"."
-      )
-      self = .unknown
-      return
-    }
-
-    self = decodedCategory
+    let rawValue = try decoder.singleValueContainer().decode(String.self)
+    self = HarmCategory(rawValue: rawValue)
   }
 }
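
For illustration only (not part of the diff): a small usage sketch assuming the surface introduced above — the static category values, the public `rawValue` property, and the `Codable` conformance; the `HARM_CATEGORY_SOMETHING_NEW` string is invented.

import Foundation
import FirebaseVertexAI

// Known categories are static values that compare with == (HarmCategory is Equatable).
let known: HarmCategory = .civicIntegrity
print(known.rawValue)       // "HARM_CATEGORY_CIVIC_INTEGRITY"
print(known == .harassment) // false

// A server value the SDK doesn't recognize no longer collapses to an `.unknown` case;
// it decodes into a HarmCategory carrying the raw string (and logs an error internally).
let json = Data(#"["HARM_CATEGORY_SOMETHING_NEW"]"#.utf8)
if let decoded = try? JSONDecoder().decode([HarmCategory].self, from: json).first {
  print(decoded.rawValue)   // "HARM_CATEGORY_SOMETHING_NEW"
}

Compared with the old enum, an unrecognized category now round-trips its server string through `rawValue` rather than collapsing into a single `unknown` value.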

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 9 additions & 5 deletions
@@ -163,7 +163,7 @@ final class GenerativeModelTests: XCTestCase {
     let expectedSafetyRatings = [
       SafetyRating(category: .harassment, probability: .medium),
       SafetyRating(category: .dangerousContent, probability: .unknown),
-      SafetyRating(category: .unknown, probability: .high),
+      SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
     ]
     MockURLProtocol
       .requestHandler = try httpRequestHandler(
@@ -972,18 +972,22 @@ final class GenerativeModelTests: XCTestCase {
       forResource: "streaming-success-unknown-safety-enum",
       withExtension: "txt"
     )
+    let unknownSafetyRating = SafetyRating(
+      category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
+      probability: .unknown
+    )

-    var hadUnknown = false
+    var foundUnknownSafetyRating = false
     let stream = try model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       if let ratings = content.candidates.first?.safetyRatings,
-         ratings.contains(where: { $0.category == .unknown }) {
-        hadUnknown = true
+         ratings.contains(where: { $0 == unknownSafetyRating }) {
+        foundUnknownSafetyRating = true
       }
     }

-    XCTAssertTrue(hadUnknown)
+    XCTAssertTrue(foundUnknownSafetyRating)
   }

   func testGenerateContentStream_successWithCitations() async throws {
