Commit 4b263b6

[Vertex AI] Refactor BlockReason as a struct and add new values (#13861)
1 parent 27cffd9 commit 4b263b6

3 files changed: +37 -23 lines changed
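This commit replaces the `BlockReason` enum nested in `PromptFeedback` with a struct that exposes the known reasons as static members and keeps the server's raw string, so values added to the backend later decode without being coerced to a catch-all `.unknown` case. The sketch below illustrates the general pattern in isolation; the type and member names (`Flavor`, `Kind`) are hypothetical and not part of the Firebase SDK — see the real `BlockReason` diff further down.

```swift
// A minimal sketch of the "struct as forward-compatible enum" pattern this commit adopts.
// `Flavor` and its members are made up for illustration only.
public struct Flavor: Hashable, Sendable {
  /// Values known to this version of the library.
  enum Kind: String {
    case sweet = "SWEET"
    case sour = "SOUR"
  }

  /// The raw string received from the backend, preserved even when this
  /// version of the library does not recognize the value.
  public let rawValue: String

  init(kind: Kind) { rawValue = kind.rawValue }
  init(rawValue: String) { self.rawValue = rawValue }

  // Enum-like spelling at call sites: `.sweet`, `.sour`.
  public static var sweet: Flavor { Flavor(kind: .sweet) }
  public static var sour: Flavor { Flavor(kind: .sour) }
}
```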

FirebaseVertexAI/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -52,6 +52,8 @@
   filter. (#13863)
 - [added] Added new `FinishReason` values `.blocklist`, `.prohibitedContent`,
   `.spii` and `.malformedFunctionCall` that may be reported. (#13860)
+- [added] Added new `BlockReason` values `.blocklist` and `.prohibitedContent`
+  that may be reported when a prompt is blocked. (#13861)

 # 11.3.0
 - [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)

FirebaseVertexAI/Sources/GenerateContentResponse.swift

Lines changed: 33 additions & 22 deletions

@@ -221,15 +221,43 @@ public struct FinishReason: DecodableProtoEnum, Hashable, Sendable {
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 public struct PromptFeedback: Sendable {
   /// A type describing possible reasons to block a prompt.
-  public enum BlockReason: String, Sendable {
-    /// The block reason is unknown.
-    case unknown = "UNKNOWN"
+  public struct BlockReason: DecodableProtoEnum, Hashable, Sendable {
+    enum Kind: String {
+      case safety = "SAFETY"
+      case other = "OTHER"
+      case blocklist = "BLOCKLIST"
+      case prohibitedContent = "PROHIBITED_CONTENT"
+    }

     /// The prompt was blocked because it was deemed unsafe.
-    case safety = "SAFETY"
+    public static var safety: BlockReason {
+      return self.init(kind: .safety)
+    }

     /// All other block reasons.
-    case other = "OTHER"
+    public static var other: BlockReason {
+      return self.init(kind: .other)
+    }
+
+    /// The prompt was blocked because it contained terms from the terminology blocklist.
+    public static var blocklist: BlockReason {
+      return self.init(kind: .blocklist)
+    }
+
+    /// The prompt was blocked due to prohibited content.
+    public static var prohibitedContent: BlockReason {
+      return self.init(kind: .prohibitedContent)
+    }
+
+    /// Returns the raw string representation of the `BlockReason` value.
+    ///
+    /// > Note: This value directly corresponds to the values in the [REST
+    /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#BlockedReason).
+    public let rawValue: String
+
+    var unrecognizedValueMessageCode: VertexLog.MessageCode {
+      .generateContentResponseUnrecognizedBlockReason
+    }
   }

   /// The reason a prompt was blocked, if it was blocked.

@@ -383,23 +411,6 @@ extension Citation: Decodable {
   }
 }

-@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-extension PromptFeedback.BlockReason: Decodable {
-  public init(from decoder: Decoder) throws {
-    let value = try decoder.singleValueContainer().decode(String.self)
-    guard let decodedBlockReason = PromptFeedback.BlockReason(rawValue: value) else {
-      VertexLog.error(
-        code: .generateContentResponseUnrecognizedBlockReason,
-        "Unrecognized BlockReason with value \"\(value)\"."
-      )
-      self = .unknown
-      return
-    }
-
-    self = decodedBlockReason
-  }
-}
-
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension PromptFeedback: Decodable {
   enum CodingKeys: CodingKey {
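For callers, the practical difference is that `BlockReason` is no longer an exhaustive enum: comparisons against the static members still read the same, but a fallback on `rawValue` is needed for values this SDK version does not know about. Below is a minimal usage sketch, not code from this commit, assuming a configured `GenerativeModel` and the `GenerateContentError.promptBlocked` flow exercised in the unit test that follows.

```swift
import FirebaseVertexAI

// Sketch of handling a blocked prompt with the struct-based `BlockReason`;
// `model` is assumed to be an already configured `GenerativeModel`.
func reportBlockReason(model: GenerativeModel, prompt: String) async {
  do {
    _ = try await model.generateContent(prompt)
  } catch let GenerateContentError.promptBlocked(response) {
    guard let reason = response.promptFeedback?.blockReason else { return }
    switch reason {
    case .safety:
      print("Prompt blocked by the safety filters.")
    case .blocklist, .prohibitedContent:
      print("Prompt blocked because it contained disallowed content.")
    case .other:
      print("Prompt blocked for another documented reason.")
    default:
      // The struct is intentionally non-exhaustive: reasons added to the backend
      // later still decode, so keep a fallback on the raw string.
      print("Prompt blocked for unrecognized reason: \(reason.rawValue)")
    }
  } catch {
    print("Request failed: \(error)")
  }
}
```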

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 2 additions & 1 deletion

@@ -627,13 +627,14 @@ final class GenerativeModelTests: XCTestCase {
       forResource: "unary-failure-unknown-enum-prompt-blocked",
       withExtension: "json"
     )
+    let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw")
     } catch let GenerateContentError.promptBlocked(response) {
       let promptFeedback = try XCTUnwrap(response.promptFeedback)
-      XCTAssertEqual(promptFeedback.blockReason, .unknown)
+      XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
     } catch {
       XCTFail("Should throw a promptBlocked")
     }
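The updated assertion captures the behavioral change under test: an unrecognized value such as `FAKE_NEW_BLOCK_REASON` no longer collapses to `.unknown` but decodes into a `BlockReason` carrying the raw string, and since `rawValue` is the struct's only stored property, the synthesized `Equatable`/`Hashable` conformance compares that string. A small sketch of the equality the test relies on, assuming the `rawValue` initializer used above is visible to the test target:

```swift
// Equality for the struct-based BlockReason compares raw strings, so two values
// built from the same backend string are interchangeable in assertions.
let decoded = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")
XCTAssertEqual(decoded, PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON"))
XCTAssertNotEqual(decoded, .safety)
```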
