Skip to content

Commit 20c9413

Browse files
authored
[Vertex AI] Refactor FinishReason as a struct and add new values (#13860)
1 parent 492e488 commit 20c9413

File tree

3 files changed

+68
-33
lines changed

3 files changed

+68
-33
lines changed

FirebaseVertexAI/CHANGELOG.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
as input. (#13767)
3737
- [changed] **Breaking Change**: All initializers for `ModelContent` now require
3838
the label `parts: `. (#13832)
39-
- [changed] **Breaking Change**: `HarmCategory` and `HarmProbability` are now
40-
structs instead of enums types and the `unknown` cases have been removed; in a
41-
`switch` statement, use the `default:` case to cover unknown or unhandled
42-
categories or probabilities. (#13728, #13854)
39+
- [changed] **Breaking Change**: `HarmCategory`, `HarmProbability`, and
40+
`FinishReason` are now structs instead of enums types and the `unknown` cases
41+
have been removed; in a `switch` statement, use the `default:` case to cover
42+
unknown or unhandled values. (#13728, #13854, #13860)
4343
- [changed] The default request timeout is now 180 seconds instead of the
4444
platform-default value of 60 seconds for a `URLRequest`; this timeout may
4545
still be customized in `RequestOptions`. (#13722)

FirebaseVertexAI/Sources/GenerateContentResponse.swift

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -145,26 +145,76 @@ public struct Citation: Sendable {
145145

146146
/// A value enumerating possible reasons for a model to terminate a content generation request.
147147
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
148-
public enum FinishReason: String, Sendable {
149-
/// The finish reason is unknown.
150-
case unknown = "FINISH_REASON_UNKNOWN"
148+
public struct FinishReason: DecodableProtoEnum, Hashable, Sendable {
149+
enum Kind: String {
150+
case stop = "STOP"
151+
case maxTokens = "MAX_TOKENS"
152+
case safety = "SAFETY"
153+
case recitation = "RECITATION"
154+
case other = "OTHER"
155+
case blocklist = "BLOCKLIST"
156+
case prohibitedContent = "PROHIBITED_CONTENT"
157+
case spii = "SPII"
158+
case malformedFunctionCall = "MALFORMED_FUNCTION_CALL"
159+
}
151160

152161
/// Natural stop point of the model or provided stop sequence.
153-
case stop = "STOP"
162+
public static var stop: FinishReason {
163+
return self.init(kind: .stop)
164+
}
154165

155166
/// The maximum number of tokens as specified in the request was reached.
156-
case maxTokens = "MAX_TOKENS"
167+
public static var maxTokens: FinishReason {
168+
return self.init(kind: .maxTokens)
169+
}
157170

158171
/// The token generation was stopped because the response was flagged for safety reasons.
159-
/// NOTE: When streaming, the Candidate.content will be empty if content filters blocked the
160-
/// output.
161-
case safety = "SAFETY"
172+
///
173+
/// > NOTE: When streaming, the ``CandidateResponse/content`` will be empty if content filters
174+
/// > blocked the output.
175+
public static var safety: FinishReason {
176+
return self.init(kind: .safety)
177+
}
162178

163179
/// The token generation was stopped because the response was flagged for unauthorized citations.
164-
case recitation = "RECITATION"
180+
public static var recitation: FinishReason {
181+
return self.init(kind: .recitation)
182+
}
165183

166184
/// All other reasons that stopped token generation.
167-
case other = "OTHER"
185+
public static var other: FinishReason {
186+
return self.init(kind: .other)
187+
}
188+
189+
/// Token generation was stopped because the response contained forbidden terms.
190+
public static var blocklist: FinishReason {
191+
return self.init(kind: .blocklist)
192+
}
193+
194+
/// Token generation was stopped because the response contained potentially prohibited content.
195+
public static var prohibitedContent: FinishReason {
196+
return self.init(kind: .prohibitedContent)
197+
}
198+
199+
/// Token generation was stopped because of Sensitive Personally Identifiable Information (SPII).
200+
public static var spii: FinishReason {
201+
return self.init(kind: .spii)
202+
}
203+
204+
/// Token generation was stopped because the function call generated by the model was invalid.
205+
public static var malformedFunctionCall: FinishReason {
206+
return self.init(kind: .malformedFunctionCall)
207+
}
208+
209+
/// Returns the raw string representation of the `FinishReason` value.
210+
///
211+
/// > Note: This value directly corresponds to the values in the [REST
212+
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#FinishReason).
213+
public let rawValue: String
214+
215+
var unrecognizedValueMessageCode: VertexLog.MessageCode {
216+
.generateContentResponseUnrecognizedFinishReason
217+
}
168218
}
169219

170220
/// A metadata struct containing any feedback the model had on the prompt it was provided.
@@ -333,23 +383,6 @@ extension Citation: Decodable {
333383
}
334384
}
335385

336-
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
337-
extension FinishReason: Decodable {
338-
public init(from decoder: Decoder) throws {
339-
let value = try decoder.singleValueContainer().decode(String.self)
340-
guard let decodedFinishReason = FinishReason(rawValue: value) else {
341-
VertexLog.error(
342-
code: .generateContentResponseUnrecognizedFinishReason,
343-
"Unrecognized FinishReason with value \"\(value)\"."
344-
)
345-
self = .unknown
346-
return
347-
}
348-
349-
self = decodedFinishReason
350-
}
351-
}
352-
353386
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
354387
extension PromptFeedback.BlockReason: Decodable {
355388
public init(from decoder: Decoder) throws {

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,12 +608,13 @@ final class GenerativeModelTests: XCTestCase {
608608
forResource: "unary-failure-unknown-enum-finish-reason",
609609
withExtension: "json"
610610
)
611+
let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")
611612

612613
do {
613614
_ = try await model.generateContent(testPrompt)
614615
XCTFail("Should throw")
615616
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
616-
XCTAssertEqual(reason, .unknown)
617+
XCTAssertEqual(reason, unknownFinishReason)
617618
XCTAssertEqual(response.text, "Some text")
618619
} catch {
619620
XCTFail("Should throw a responseStoppedEarly")
@@ -921,14 +922,15 @@ final class GenerativeModelTests: XCTestCase {
921922
forResource: "streaming-failure-unknown-finish-enum",
922923
withExtension: "txt"
923924
)
925+
let unknownFinishReason = FinishReason(rawValue: "FAKE_ENUM")
924926

925927
let stream = try model.generateContentStream("Hi")
926928
do {
927929
for try await content in stream {
928930
XCTAssertNotNil(content.text)
929931
}
930932
} catch let GenerateContentError.responseStoppedEarly(reason, _) {
931-
XCTAssertEqual(reason, .unknown)
933+
XCTAssertEqual(reason, unknownFinishReason)
932934
return
933935
}
934936

0 commit comments

Comments
 (0)