Skip to content

Commit e1d1ad9

Browse files
Authored: [Vertex AI] Add blockReasonMessage to PromptFeedback (#13891)
1 parent 0ab96c6 commit e1d1ad9

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

FirebaseVertexAI/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070
`.spii` and `.malformedFunctionCall` that may be reported. (#13860)
7171
- [added] Added new `BlockReason` values `.blocklist` and `.prohibitedContent`
7272
that may be reported when a prompt is blocked. (#13861)
73+
- [added] Added the `PromptFeedback` property `blockReasonMessage` that *may* be
74+
provided alongside the `blockReason`. (#13891)
7375

7476
# 11.3.0
7577
- [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)

FirebaseVertexAI/Sources/GenerateContentResponse.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,17 @@ public struct PromptFeedback: Sendable {
235235
/// The reason a prompt was blocked, if it was blocked.
236236
public let blockReason: BlockReason?
237237

238+
/// A human-readable description of the ``blockReason``.
239+
public let blockReasonMessage: String?
240+
238241
/// The safety ratings of the prompt.
239242
public let safetyRatings: [SafetyRating]
240243

241244
/// Initializer for SwiftUI previews or tests.
242-
public init(blockReason: BlockReason?, safetyRatings: [SafetyRating]) {
245+
public init(blockReason: BlockReason?, blockReasonMessage: String? = nil,
246+
safetyRatings: [SafetyRating]) {
243247
self.blockReason = blockReason
248+
self.blockReasonMessage = blockReasonMessage
244249
self.safetyRatings = safetyRatings
245250
}
246251
}
@@ -387,6 +392,7 @@ extension Citation: Decodable {
387392
extension PromptFeedback: Decodable {
388393
enum CodingKeys: CodingKey {
389394
case blockReason
395+
case blockReasonMessage
390396
case safetyRatings
391397
}
392398

@@ -396,6 +402,7 @@ extension PromptFeedback: Decodable {
396402
PromptFeedback.BlockReason.self,
397403
forKey: .blockReason
398404
)
405+
blockReasonMessage = try container.decodeIfPresent(String.self, forKey: .blockReasonMessage)
399406
if let safetyRatings = try container.decodeIfPresent(
400407
[SafetyRating].self,
401408
forKey: .safetyRatings

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,29 @@ final class GenerativeModelTests: XCTestCase {
619619
XCTFail("Should throw")
620620
} catch let GenerateContentError.promptBlocked(response) {
621621
XCTAssertNil(response.text)
622+
let promptFeedback = try XCTUnwrap(response.promptFeedback)
623+
XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
624+
XCTAssertNil(promptFeedback.blockReasonMessage)
625+
} catch {
626+
XCTFail("Should throw a promptBlocked")
627+
}
628+
}
629+
630+
func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws {
631+
MockURLProtocol
632+
.requestHandler = try httpRequestHandler(
633+
forResource: "unary-failure-prompt-blocked-safety-with-message",
634+
withExtension: "json"
635+
)
636+
637+
do {
638+
_ = try await model.generateContent(testPrompt)
639+
XCTFail("Should throw")
640+
} catch let GenerateContentError.promptBlocked(response) {
641+
XCTAssertNil(response.text)
642+
let promptFeedback = try XCTUnwrap(response.promptFeedback)
643+
XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety)
644+
XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
622645
} catch {
623646
XCTFail("Should throw a promptBlocked")
624647
}
@@ -909,7 +932,31 @@ final class GenerativeModelTests: XCTestCase {
909932
XCTFail("Content shouldn't be shown, this shouldn't happen.")
910933
}
911934
} catch let GenerateContentError.promptBlocked(response) {
912-
XCTAssertEqual(response.promptFeedback?.blockReason, .safety)
935+
let promptFeedback = try XCTUnwrap(response.promptFeedback)
936+
XCTAssertEqual(promptFeedback.blockReason, .safety)
937+
XCTAssertNil(promptFeedback.blockReasonMessage)
938+
return
939+
}
940+
941+
XCTFail("Should have caught an error.")
942+
}
943+
944+
func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws {
945+
MockURLProtocol
946+
.requestHandler = try httpRequestHandler(
947+
forResource: "streaming-failure-prompt-blocked-safety-with-message",
948+
withExtension: "txt"
949+
)
950+
951+
do {
952+
let stream = try model.generateContentStream("Hi")
953+
for try await _ in stream {
954+
XCTFail("Content shouldn't be shown, this shouldn't happen.")
955+
}
956+
} catch let GenerateContentError.promptBlocked(response) {
957+
let promptFeedback = try XCTUnwrap(response.promptFeedback)
958+
XCTAssertEqual(promptFeedback.blockReason, .safety)
959+
XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons")
913960
return
914961
}
915962

0 commit comments

Comments (0)