Skip to content

Commit 5fabd1d

Browse files
committed
[Vertex AI] Add ImageGenerationResponse for decoding predict response
1 parent 78fe33c commit 5fabd1d

File tree

2 files changed

+500
-0
lines changed

2 files changed

+500
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
18+
struct ImageGenerationResponse {
19+
let images: [Image]
20+
let raiFilteredReason: String?
21+
}
22+
23+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
24+
extension ImageGenerationResponse {
25+
struct Image: Equatable {
26+
let mimeType: String
27+
let bytesBase64Encoded: String?
28+
let gcsURI: String?
29+
}
30+
}
31+
32+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
33+
extension ImageGenerationResponse {
34+
struct RAIFilteredReason {
35+
let raiFilteredReason: String
36+
}
37+
}
38+
39+
// MARK: - Codable Conformances
40+
41+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
42+
extension ImageGenerationResponse.Image: Decodable {
43+
enum CodingKeys: String, CodingKey {
44+
case mimeType
45+
case bytesBase64Encoded
46+
case gcsURI = "gcsUri"
47+
}
48+
49+
init(from decoder: any Decoder) throws {
50+
let container = try decoder.container(keyedBy: CodingKeys.self)
51+
mimeType = try container.decode(String.self, forKey: .mimeType)
52+
bytesBase64Encoded = try container.decodeIfPresent(String.self, forKey: .bytesBase64Encoded)
53+
gcsURI = try container.decodeIfPresent(String.self, forKey: .gcsURI)
54+
guard bytesBase64Encoded != nil || gcsURI != nil else {
55+
throw DecodingError.dataCorrupted(
56+
DecodingError.Context(
57+
codingPath: [CodingKeys.bytesBase64Encoded, CodingKeys.gcsURI],
58+
debugDescription: """
59+
Expected one of \(CodingKeys.bytesBase64Encoded.rawValue) or \
60+
\(CodingKeys.gcsURI.rawValue); both are nil.
61+
"""
62+
)
63+
)
64+
}
65+
guard bytesBase64Encoded == nil || gcsURI == nil else {
66+
throw DecodingError.dataCorrupted(
67+
DecodingError.Context(
68+
codingPath: [CodingKeys.bytesBase64Encoded, CodingKeys.gcsURI],
69+
debugDescription: """
70+
Expected one of \(CodingKeys.bytesBase64Encoded.rawValue) or \
71+
\(CodingKeys.gcsURI.rawValue); both are specified.
72+
"""
73+
)
74+
)
75+
}
76+
}
77+
}
78+
79+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
80+
extension ImageGenerationResponse.RAIFilteredReason: Decodable {
81+
enum CodingKeys: CodingKey {
82+
case raiFilteredReason
83+
}
84+
}
85+
86+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
87+
extension ImageGenerationResponse: Decodable {
88+
enum CodingKeys: CodingKey {
89+
case predictions
90+
}
91+
92+
public init(from decoder: any Decoder) throws {
93+
let container = try decoder.container(keyedBy: CodingKeys.self)
94+
guard container.contains(.predictions) else {
95+
images = []
96+
raiFilteredReason = nil
97+
// TODO: Log warning if no predictions.
98+
return
99+
}
100+
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)
101+
102+
var images = [Image]()
103+
var raiFilteredReasons = [String]()
104+
while !predictionsContainer.isAtEnd {
105+
if let image = try? predictionsContainer.decode(Image.self) {
106+
images.append(image)
107+
} else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
108+
raiFilteredReasons.append(filterReason.raiFilteredReason)
109+
} else if let _ = try? predictionsContainer.decode(JSONObject.self) {
110+
// TODO: Log or throw unsupported prediction type
111+
} else {
112+
// This should never be thrown since JSONObject accepts any valid JSON.
113+
throw DecodingError.dataCorruptedError(
114+
in: predictionsContainer,
115+
debugDescription: "Failed to decode Prediction."
116+
)
117+
}
118+
}
119+
120+
self.images = images
121+
raiFilteredReason = raiFilteredReasons.first
122+
// TODO: Log if more than one RAI Filtered Reason; unexpected behaviour.
123+
}
124+
}

0 commit comments

Comments
 (0)