Skip to content

Commit 492e488

Browse files
authored
[Vertex AI] Extract common protobuf enum to struct decoding logic (#13859)
1 parent 0e3e20d commit 492e488

File tree

2 files changed

+92
-54
lines changed

2 files changed

+92
-54
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/// A type that can be decoded from a Protocol Buffer raw enum value.
16+
///
17+
/// Protobuf enums are represented as strings in JSON. A default `Decodable` implementation is
18+
/// provided when conforming to this type.
19+
protocol DecodableProtoEnum: Decodable {
20+
/// The type representing the valid values for the protobuf enum.
21+
///
22+
/// > Important: This type must conform to `RawRepresentable` with the `RawValue == String`.
23+
///
24+
/// This is typically a Swift enum, e.g.:
25+
/// ```
26+
/// enum Kind: String {
27+
/// case north = "WIND_DIRECTION_NORTH"
28+
/// case south = "WIND_DIRECTION_SOUTH"
29+
/// case east = "WIND_DIRECTION_EAST"
30+
/// case west = "WIND_DIRECTION_WEST"
31+
/// }
32+
/// ```
33+
associatedtype Kind: RawRepresentable<String>
34+
35+
/// Returns the ``VertexLog/MessageCode`` associated with unrecognized (unknown) enum values.
36+
var unrecognizedValueMessageCode: VertexLog.MessageCode { get }
37+
38+
/// Create a new instance of the specified type from a raw enum value.
39+
init(rawValue: String)
40+
41+
/// Creates a new instance from the ``Kind``'s raw value.
42+
///
43+
/// > Important: A default implementation is provided.
44+
init(kind: Kind)
45+
46+
/// Creates a new instance by decoding from the given decoder.
47+
///
48+
/// > Important: A default implementation is provided.
49+
init(from decoder: Decoder) throws
50+
}
51+
52+
/// Default `Decodable` implementation for types conforming to `DecodableProtoEnum`.
53+
extension DecodableProtoEnum {
54+
// Note: Initializer 'init(from:)' must be declared public because it matches a requirement in
55+
// public protocol 'Decodable'.
56+
public init(from decoder: Decoder) throws {
57+
let rawValue = try decoder.singleValueContainer().decode(String.self)
58+
59+
self = Self(rawValue: rawValue)
60+
61+
if Kind(rawValue: rawValue) == nil {
62+
VertexLog.error(
63+
code: unrecognizedValueMessageCode,
64+
"""
65+
Unrecognized \(Self.self) with value "\(rawValue)":
66+
- Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
67+
release notes at https://firebase.google.com/support/release-notes/ios
68+
- Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
69+
https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
70+
"""
71+
)
72+
}
73+
}
74+
}
75+
76+
/// Default implementation of `init(kind: Kind)` for types conforming to `DecodableProtoEnum`.
77+
extension DecodableProtoEnum {
78+
init(kind: Kind) {
79+
self = Self(rawValue: kind.rawValue)
80+
}
81+
}
82+
83+
/// A type that can be decoded and encoded from a Protocol Buffer raw enum value.
84+
///
85+
/// See ``DecodableProtoEnum`` for more details.
86+
protocol CodableProtoEnum: DecodableProtoEnum, Encodable {}

FirebaseVertexAI/Sources/Safety.swift

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
4141
/// The probability that a given model output falls under a harmful content category.
4242
///
4343
/// > Note: This does not indicate the severity of harm for a piece of content.
44-
public struct HarmProbability: Sendable, Equatable, Hashable {
44+
public struct HarmProbability: DecodableProtoEnum, Hashable, Sendable {
4545
enum Kind: String {
4646
case negligible = "NEGLIGIBLE"
4747
case low = "LOW"
@@ -79,24 +79,8 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
7979
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#SafetyRating).
8080
public let rawValue: String
8181

82-
init(kind: Kind) {
83-
rawValue = kind.rawValue
84-
}
85-
86-
init(rawValue: String) {
87-
if Kind(rawValue: rawValue) == nil {
88-
VertexLog.error(
89-
code: .generateContentResponseUnrecognizedHarmProbability,
90-
"""
91-
Unrecognized HarmProbability with value "\(rawValue)":
92-
- Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
93-
release notes at https://firebase.google.com/support/release-notes/ios
94-
- Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
95-
https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
96-
"""
97-
)
98-
}
99-
self.rawValue = rawValue
82+
var unrecognizedValueMessageCode: VertexLog.MessageCode {
83+
.generateContentResponseUnrecognizedHarmProbability
10084
}
10185
}
10286
}
@@ -139,7 +123,7 @@ public struct SafetySetting {
139123
}
140124

141125
/// Categories describing the potential harm a piece of content may pose.
142-
public struct HarmCategory: Sendable, Equatable, Hashable {
126+
public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
143127
enum Kind: String {
144128
case harassment = "HARM_CATEGORY_HARASSMENT"
145129
case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
@@ -179,48 +163,16 @@ public struct HarmCategory: Sendable, Equatable, Hashable {
179163
/// > [REST API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/HarmCategory).
180164
public let rawValue: String
181165

182-
init(kind: Kind) {
183-
rawValue = kind.rawValue
184-
}
185-
186-
init(rawValue: String) {
187-
if Kind(rawValue: rawValue) == nil {
188-
VertexLog.error(
189-
code: .generateContentResponseUnrecognizedHarmCategory,
190-
"""
191-
Unrecognized HarmCategory with value "\(rawValue)":
192-
- Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
193-
release notes at https://firebase.google.com/support/release-notes/ios
194-
- Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
195-
https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
196-
"""
197-
)
198-
}
199-
self.rawValue = rawValue
166+
var unrecognizedValueMessageCode: VertexLog.MessageCode {
167+
.generateContentResponseUnrecognizedHarmCategory
200168
}
201169
}
202170

203171
// MARK: - Codable Conformances
204172

205-
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
206-
extension SafetyRating.HarmProbability: Decodable {
207-
public init(from decoder: Decoder) throws {
208-
let rawValue = try decoder.singleValueContainer().decode(String.self)
209-
self = SafetyRating.HarmProbability(rawValue: rawValue)
210-
}
211-
}
212-
213173
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
214174
extension SafetyRating: Decodable {}
215175

216-
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
217-
extension HarmCategory: Codable {
218-
public init(from decoder: Decoder) throws {
219-
let rawValue = try decoder.singleValueContainer().decode(String.self)
220-
self = HarmCategory(rawValue: rawValue)
221-
}
222-
}
223-
224176
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
225177
extension SafetySetting.HarmBlockThreshold: Encodable {}
226178

0 commit comments

Comments
 (0)