diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md
index f2facff2411..346ef9a70bf 100644
--- a/FirebaseVertexAI/CHANGELOG.md
+++ b/FirebaseVertexAI/CHANGELOG.md
@@ -58,6 +58,9 @@
   `totalBillableCharacters` counts, where applicable. (#13813)
 - [added] Added a new `HarmCategory` `.civicIntegrity` for filtering content
   that may be used to harm civic integrity. (#13728)
+- [added] Added `probabilityScore`, `severity` and `severityScore` in
+  `SafetyRating` to provide more fine-grained detail on blocked responses.
+  (#13875)
 - [added] Added a new `HarmBlockThreshold` `.off`, which turns off the safety
   filter. (#13863)
 - [added] Added new `FinishReason` values `.blocklist`, `.prohibitedContent`,
diff --git a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
index 279f02b81fc..236a1f7d4b0 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
@@ -168,10 +168,38 @@ struct ErrorDetailsView: View {
             Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
             """),
           safetyRatings: [
-            SafetyRating(category: .dangerousContent, probability: .high),
-            SafetyRating(category: .harassment, probability: .low),
-            SafetyRating(category: .hateSpeech, probability: .low),
-            SafetyRating(category: .sexuallyExplicit, probability: .low),
+            SafetyRating(
+              category: .dangerousContent,
+              probability: .medium,
+              probabilityScore: 0.8,
+              severity: .medium,
+              severityScore: 0.9,
+              blocked: false
+            ),
+            SafetyRating(
+              category: .harassment,
+              probability: .low,
+              probabilityScore: 0.5,
+              severity: .low,
+              severityScore: 0.6,
+              blocked: false
+            ),
+            SafetyRating(
+              category: .hateSpeech,
+              probability: .low,
+              probabilityScore: 0.3,
+              severity: .medium,
+              severityScore: 0.2,
+              blocked: false
+            ),
+            SafetyRating(
+              category: .sexuallyExplicit,
+              probability: .low,
+              probabilityScore: 0.2,
+              severity: .negligible,
+              severityScore: 0.5,
+              blocked: false
+            ),
           ],
           finishReason: FinishReason.maxTokens,
           citationMetadata: nil),
@@ -190,10 +218,38 @@ struct ErrorDetailsView: View {
             Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
             """),
           safetyRatings: [
-            SafetyRating(category: .dangerousContent, probability: .high),
-            SafetyRating(category: .harassment, probability: .low),
-            SafetyRating(category: .hateSpeech, probability: .low),
-            SafetyRating(category: .sexuallyExplicit, probability: .low),
+            SafetyRating(
+              category: .dangerousContent,
+              probability: .low,
+              probabilityScore: 0.8,
+              severity: .medium,
+              severityScore: 0.9,
+              blocked: false
+            ),
+            SafetyRating(
+              category: .harassment,
+              probability: .low,
+              probabilityScore: 0.5,
+              severity: .low,
+              severityScore: 0.6,
+              blocked: false
+            ),
+            SafetyRating(
+              category: .hateSpeech,
+              probability: .low,
+              probabilityScore: 0.3,
+              severity: .medium,
+              severityScore: 0.2,
+              blocked: false
+            ),
+            SafetyRating(
+              category: .sexuallyExplicit,
+              probability: .low,
+              probabilityScore: 0.2,
+              severity: .negligible,
+              severityScore: 0.5,
+              blocked: false
+            ),
           ],
           finishReason: FinishReason.other,
           citationMetadata: nil),
diff --git a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
index e43258557b4..3efce09a119 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
@@ -36,22 +36,54 @@ struct ErrorView: View {
 #Preview {
   NavigationView {
     let errorPromptBlocked = GenerateContentError.promptBlocked(
-      response: GenerateContentResponse(candidates: [
-        CandidateResponse(content: ModelContent(role: "model", parts: [
-          """
-          A _hypothetical_ model response.
-          Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
-          """,
-        ]),
-        safetyRatings: [
-          SafetyRating(category: .dangerousContent, probability: .high),
-          SafetyRating(category: .harassment, probability: .low),
-          SafetyRating(category: .hateSpeech, probability: .low),
-          SafetyRating(category: .sexuallyExplicit, probability: .low),
-        ],
-        finishReason: FinishReason.other,
-        citationMetadata: nil),
-      ])
+      response: GenerateContentResponse(
+        candidates: [
+          CandidateResponse(
+            content: ModelContent(role: "model", parts: [
+              """
+              A _hypothetical_ model response.
+              Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
+              """,
+            ]),
+            safetyRatings: [
+              SafetyRating(
+                category: .dangerousContent,
+                probability: .high,
+                probabilityScore: 0.8,
+                severity: .medium,
+                severityScore: 0.9,
+                blocked: true
+              ),
+              SafetyRating(
+                category: .harassment,
+                probability: .low,
+                probabilityScore: 0.5,
+                severity: .low,
+                severityScore: 0.6,
+                blocked: false
+              ),
+              SafetyRating(
+                category: .hateSpeech,
+                probability: .low,
+                probabilityScore: 0.3,
+                severity: .medium,
+                severityScore: 0.2,
+                blocked: false
+              ),
+              SafetyRating(
+                category: .sexuallyExplicit,
+                probability: .low,
+                probabilityScore: 0.2,
+                severity: .negligible,
+                severityScore: 0.5,
+                blocked: false
+              ),
+            ],
+            finishReason: FinishReason.other,
+            citationMetadata: nil
+          ),
+        ]
+      )
     )
     List {
       MessageView(message: ChatMessage.samples[0])
diff --git a/FirebaseVertexAI/Sources/Safety.swift b/FirebaseVertexAI/Sources/Safety.swift
index d810613aecb..2ff4fe85f1c 100644
--- a/FirebaseVertexAI/Sources/Safety.swift
+++ b/FirebaseVertexAI/Sources/Safety.swift
@@ -26,16 +26,50 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
 
   /// The model-generated probability that the content falls under the specified harm ``category``.
   ///
-  /// See ``HarmProbability`` for a list of possible values.
+  /// See ``HarmProbability`` for a list of possible values. This is a discretized representation
+  /// of the ``probabilityScore``.
   ///
   /// > Important: This does not indicate the severity of harm for a piece of content.
   public let probability: HarmProbability
 
+  /// The confidence score that the response is associated with the corresponding harm ``category``.
+  ///
+  /// The probability safety score is a confidence score between 0.0 and 1.0, rounded to one decimal
+  /// place; it is discretized into a ``HarmProbability`` in ``probability``. See [probability
+  /// scores](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters#comparison_of_probability_scores_and_severity_scores)
+  /// in the Google Cloud documentation for more details.
+  public let probabilityScore: Float
+
+  /// The severity reflects the magnitude of how harmful a model response might be.
+  ///
+  /// See ``HarmSeverity`` for a list of possible values. This is a discretized representation of
+  /// the ``severityScore``.
+  public let severity: HarmSeverity
+
+  /// The severity score is the magnitude of how harmful a model response might be.
+  ///
+  /// The severity score ranges from 0.0 to 1.0, rounded to one decimal place; it is discretized
+  /// into a ``HarmSeverity`` in ``severity``. See [severity scores](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters#comparison_of_probability_scores_and_severity_scores)
+  /// in the Google Cloud documentation for more details.
+  public let severityScore: Float
+
+  /// If true, the response was blocked.
+  public let blocked: Bool
+
   /// Initializes a new `SafetyRating` instance with the given category and probability.
   /// Use this initializer for SwiftUI previews or tests.
-  public init(category: HarmCategory, probability: HarmProbability) {
+  public init(category: HarmCategory,
+              probability: HarmProbability,
+              probabilityScore: Float,
+              severity: HarmSeverity,
+              severityScore: Float,
+              blocked: Bool) {
     self.category = category
     self.probability = probability
+    self.probabilityScore = probabilityScore
+    self.severity = severity
+    self.severityScore = severityScore
+    self.blocked = blocked
   }
 
   /// The probability that a given model output falls under a harmful content category.
@@ -74,6 +108,37 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
     static let unrecognizedValueMessageCode =
       VertexLog.MessageCode.generateContentResponseUnrecognizedHarmProbability
   }
+
+  /// The magnitude of how harmful a model response might be for the respective ``HarmCategory``.
+  public struct HarmSeverity: DecodableProtoEnum, Hashable, Sendable {
+    enum Kind: String {
+      case negligible = "HARM_SEVERITY_NEGLIGIBLE"
+      case low = "HARM_SEVERITY_LOW"
+      case medium = "HARM_SEVERITY_MEDIUM"
+      case high = "HARM_SEVERITY_HIGH"
+    }
+
+    /// Negligible level of harm severity.
+    public static let negligible = HarmSeverity(kind: .negligible)
+
+    /// Low level of harm severity.
+    public static let low = HarmSeverity(kind: .low)
+
+    /// Medium level of harm severity.
+    public static let medium = HarmSeverity(kind: .medium)
+
+    /// High level of harm severity.
+    public static let high = HarmSeverity(kind: .high)
+
+    /// Returns the raw string representation of the `HarmSeverity` value.
+    ///
+    /// > Note: This value directly corresponds to the values in the [REST
+    /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#HarmSeverity).
+    public let rawValue: String
+
+    static let unrecognizedValueMessageCode =
+      VertexLog.MessageCode.generateContentResponseUnrecognizedHarmSeverity
+  }
 }
 
 /// A type used to specify a threshold for harmful content, beyond which the model will return a
@@ -164,7 +229,31 @@ public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
 // MARK: - Codable Conformances
 
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-extension SafetyRating: Decodable {}
+extension SafetyRating: Decodable {
+  enum CodingKeys: CodingKey {
+    case category
+    case probability
+    case probabilityScore
+    case severity
+    case severityScore
+    case blocked
+  }
+
+  public init(from decoder: any Decoder) throws {
+    let container = try decoder.container(keyedBy: CodingKeys.self)
+    category = try container.decode(HarmCategory.self, forKey: .category)
+    probability = try container.decode(HarmProbability.self, forKey: .probability)
+
+    // The following 3 fields are only omitted in our test data.
+    probabilityScore = try container.decodeIfPresent(Float.self, forKey: .probabilityScore) ?? 0.0
+    severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ??
+      HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED")
+    severityScore = try container.decodeIfPresent(Float.self, forKey: .severityScore) ?? 0.0
+
+    // The blocked field is only included when true.
+    blocked = try container.decodeIfPresent(Bool.self, forKey: .blocked) ?? false
+  }
+}
 
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension SafetySetting.HarmBlockThreshold: Encodable {}
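Usage sketch (not part of the patch above): a minimal illustration of how a caller might read the new `SafetyRating` fields introduced in Safety.swift. The helper name `logSafetyDetails` is hypothetical, and it assumes each rating type (`HarmCategory`, `HarmProbability`, `HarmSeverity`) exposes a `rawValue` string as the patch indicates; treat it as a sketch rather than canonical SDK usage.

import FirebaseVertexAI

func logSafetyDetails(for response: GenerateContentResponse) {
  guard let candidate = response.candidates.first else { return }
  for rating in candidate.safetyRatings {
    // `probability` and `severity` are the discretized forms of the 0.0-1.0 scores.
    print(
      "\(rating.category.rawValue): "
        + "probability=\(rating.probability.rawValue) (score: \(rating.probabilityScore)), "
        + "severity=\(rating.severity.rawValue) (score: \(rating.severityScore)), "
        + "blocked=\(rating.blocked)"
    )
  }
}
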
diff --git a/FirebaseVertexAI/Sources/VertexLog.swift b/FirebaseVertexAI/Sources/VertexLog.swift
index bd400c200c2..7ffaf78f0fc 100644
--- a/FirebaseVertexAI/Sources/VertexLog.swift
+++ b/FirebaseVertexAI/Sources/VertexLog.swift
@@ -49,6 +49,7 @@ enum VertexLog {
     case generateContentResponseUnrecognizedBlockThreshold = 3004
     case generateContentResponseUnrecognizedHarmProbability = 3005
     case generateContentResponseUnrecognizedHarmCategory = 3006
+    case generateContentResponseUnrecognizedHarmSeverity = 3007
 
     // SDK State Errors
     case generateContentResponseNoCandidates = 4000
diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
index de14552c251..5ffa94daf64 100644
--- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -23,10 +23,38 @@ import XCTest
 final class GenerativeModelTests: XCTestCase {
   let testPrompt = "What sorts of questions can I ask you?"
   let safetyRatingsNegligible: [SafetyRating] = [
-    .init(category: .sexuallyExplicit, probability: .negligible),
-    .init(category: .hateSpeech, probability: .negligible),
-    .init(category: .harassment, probability: .negligible),
-    .init(category: .dangerousContent, probability: .negligible),
+    .init(
+      category: .sexuallyExplicit,
+      probability: .negligible,
+      probabilityScore: 0.1431877,
+      severity: .negligible,
+      severityScore: 0.11027937,
+      blocked: false
+    ),
+    .init(
+      category: .hateSpeech,
+      probability: .negligible,
+      probabilityScore: 0.029035643,
+      severity: .negligible,
+      severityScore: 0.05613278,
+      blocked: false
+    ),
+    .init(
+      category: .harassment,
+      probability: .negligible,
+      probabilityScore: 0.087252244,
+      severity: .negligible,
+      severityScore: 0.04509957,
+      blocked: false
+    ),
+    .init(
+      category: .dangerousContent,
+      probability: .negligible,
+      probabilityScore: 0.2641685,
+      severity: .negligible,
+      severityScore: 0.082253955,
+      blocked: false
+    ),
   ].sorted()
   let testModelResourceName =
     "projects/test-project-id/locations/test-location/publishers/google/models/test-model"
@@ -69,7 +97,7 @@ final class GenerativeModelTests: XCTestCase {
     let candidate = try XCTUnwrap(response.candidates.first)
     let finishReason = try XCTUnwrap(candidate.finishReason)
     XCTAssertEqual(finishReason, .stop)
-    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
+    XCTAssertEqual(candidate.safetyRatings.count, 4)
     XCTAssertEqual(candidate.content.parts.count, 1)
     let part = try XCTUnwrap(candidate.content.parts.first)
     let partText = try XCTUnwrap(part as? TextPart).text
@@ -148,7 +176,7 @@ final class GenerativeModelTests: XCTestCase {
     let candidate = try XCTUnwrap(response.candidates.first)
     let finishReason = try XCTUnwrap(candidate.finishReason)
     XCTAssertEqual(finishReason, .stop)
-    XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
+    XCTAssertEqual(candidate.safetyRatings.count, 4)
     XCTAssertEqual(candidate.content.parts.count, 1)
     let part = try XCTUnwrap(candidate.content.parts.first)
     let textPart = try XCTUnwrap(part as? TextPart)
@@ -156,17 +184,35 @@ final class GenerativeModelTests: XCTestCase {
     XCTAssertEqual(response.text, textPart.text)
     let promptFeedback = try XCTUnwrap(response.promptFeedback)
     XCTAssertNil(promptFeedback.blockReason)
-    XCTAssertEqual(promptFeedback.safetyRatings.sorted(), safetyRatingsNegligible)
+    XCTAssertEqual(promptFeedback.safetyRatings.count, 4)
   }
 
   func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
     let expectedSafetyRatings = [
-      SafetyRating(category: .harassment, probability: .medium),
+      SafetyRating(
+        category: .harassment,
+        probability: .medium,
+        probabilityScore: 0.0,
+        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+        severityScore: 0.0,
+        blocked: false
+      ),
       SafetyRating(
         category: .dangerousContent,
-        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
+        probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
+        probabilityScore: 0.0,
+        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+        severityScore: 0.0,
+        blocked: false
+      ),
+      SafetyRating(
+        category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
+        probability: .high,
+        probabilityScore: 0.0,
+        severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+        severityScore: 0.0,
+        blocked: false
       ),
-      SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
     ]
     MockURLProtocol
       .requestHandler = try httpRequestHandler(
@@ -839,8 +885,11 @@ final class GenerativeModelTests: XCTestCase {
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
-    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
+    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
       XCTAssertEqual(reason, .safety)
+      let candidate = try XCTUnwrap(response.candidates.first)
+      XCTAssertEqual(candidate.finishReason, reason)
+      XCTAssertTrue(candidate.safetyRatings.contains { $0.blocked })
       return
     }
 
@@ -930,7 +979,11 @@ final class GenerativeModelTests: XCTestCase {
     )
     let unknownSafetyRating = SafetyRating(
       category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
-      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
+      probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
+      probabilityScore: 0.0,
+      severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
+      severityScore: 0.0,
+      blocked: false
     )
     var foundUnknownSafetyRating = false
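Closing sketch (also outside the patch): it mirrors the updated streaming test above, where `GenerateContentError.responseStoppedEarly` now carries a response whose ratings include the new `blocked` flag. The function name `streamStory` is hypothetical, and whether this overload of `generateContentStream` throws at the call site may vary by SDK version, so the `try` is defensive rather than prescriptive.

import FirebaseVertexAI

func streamStory(using model: GenerativeModel, prompt: String) async {
  do {
    // Defensive `try`; drop it if this overload of generateContentStream does not throw.
    let stream = try model.generateContentStream(prompt)
    for try await chunk in stream {
      print(chunk.text ?? "", terminator: "")
    }
  } catch let GenerateContentError.responseStoppedEarly(reason, response) {
    // The response attached to the error exposes which categories were blocked.
    let blocked = response.candidates.first?.safetyRatings.filter { $0.blocked } ?? []
    print("Stopped early (\(reason)); blocked categories: \(blocked.map(\.category.rawValue))")
  } catch {
    print("Generation failed: \(error)")
  }
}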