Skip to content

Commit d7e960c

Browse files
committed
[Vertex AI] Add ImageGenerationRequest for Imagen (#14225)
1 parent 3a04d77 commit d7e960c

File tree

6 files changed

+206
-0
lines changed

6 files changed

+206
-0
lines changed

FirebaseVertexAI/Sources/GenerativeAIRequest.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,6 @@ public struct RequestOptions {
4141
self.timeout = timeout
4242
}
4343
}
44+
45+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
46+
extension RequestOptions: Equatable {}

FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ struct ImageGenerationInstance {
1717
let prompt: String
1818
}
1919

20+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
21+
extension ImageGenerationInstance: Equatable {}
22+
2023
// MARK: - Codable Conformance
2124

2225
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)

FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationOutputOptions.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ struct ImageGenerationOutputOptions {
2020
let compressionQuality: Int?
2121
}
2222

23+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
24+
extension ImageGenerationOutputOptions: Equatable {}
25+
2326
// MARK: - Codable Conformance
2427

2528
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)

FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ struct ImageGenerationParameters {
2626
let includeResponsibleAIFilterReason: Bool?
2727
}
2828

29+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
30+
extension ImageGenerationParameters: Equatable {}
31+
2932
// MARK: - Codable Conformance
3033

3134
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
18+
struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> {
19+
let model: String
20+
let options: RequestOptions
21+
let instances: [ImageGenerationInstance]
22+
let parameters: ImageGenerationParameters
23+
24+
init(model: String, options: RequestOptions, instances: [ImageGenerationInstance],
25+
parameters: ImageGenerationParameters) {
26+
self.model = model
27+
self.options = options
28+
self.instances = instances
29+
self.parameters = parameters
30+
}
31+
}
32+
33+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
34+
extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
35+
typealias Response = ImageGenerationResponse<ImageType>
36+
37+
var url: URL {
38+
return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")!
39+
}
40+
}
41+
42+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
43+
extension ImageGenerationRequest: Encodable {
44+
enum CodingKeys: CodingKey {
45+
case instances
46+
case parameters
47+
}
48+
49+
func encode(to encoder: any Encoder) throws {
50+
var container = encoder.container(keyedBy: CodingKeys.self)
51+
try container.encode(instances, forKey: .instances)
52+
try container.encode(parameters, forKey: .parameters)
53+
}
54+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import FirebaseVertexAI
18+
19+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20+
final class ImageGenerationRequestTests: XCTestCase {
21+
let encoder = JSONEncoder()
22+
let requestOptions = RequestOptions(timeout: 30.0)
23+
let modelName = "test-model-name"
24+
let sampleCount = 4
25+
let aspectRatio = "16:9"
26+
let safetyFilterLevel = "block_low_and_above"
27+
let includeResponsibleAIFilterReason = true
28+
lazy var parameters = ImageGenerationParameters(
29+
sampleCount: sampleCount,
30+
storageURI: nil,
31+
seed: nil,
32+
negativePrompt: nil,
33+
aspectRatio: aspectRatio,
34+
safetyFilterLevel: safetyFilterLevel,
35+
personGeneration: nil,
36+
outputOptions: nil,
37+
addWatermark: nil,
38+
includeResponsibleAIFilterReason: includeResponsibleAIFilterReason
39+
)
40+
41+
let instance = ImageGenerationInstance(prompt: "test-prompt")
42+
43+
override func setUp() {
44+
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
45+
}
46+
47+
func testInitializeRequest_inlineDataImage() throws {
48+
let request = ImageGenerationRequest<ImagenInlineDataImage>(
49+
model: modelName,
50+
options: requestOptions,
51+
instances: [instance],
52+
parameters: parameters
53+
)
54+
55+
XCTAssertEqual(request.model, modelName)
56+
XCTAssertEqual(request.options, requestOptions)
57+
XCTAssertEqual(request.instances, [instance])
58+
XCTAssertEqual(request.parameters, parameters)
59+
XCTAssertEqual(
60+
request.url,
61+
URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict")
62+
)
63+
}
64+
65+
func testInitializeRequest_fileDataImage() throws {
66+
let request = ImageGenerationRequest<ImagenFileDataImage>(
67+
model: modelName,
68+
options: requestOptions,
69+
instances: [instance],
70+
parameters: parameters
71+
)
72+
73+
XCTAssertEqual(request.model, modelName)
74+
XCTAssertEqual(request.options, requestOptions)
75+
XCTAssertEqual(request.instances, [instance])
76+
XCTAssertEqual(request.parameters, parameters)
77+
XCTAssertEqual(
78+
request.url,
79+
URL(string: "\(Constants.baseURL)/\(requestOptions.apiVersion)/\(modelName):predict")
80+
)
81+
}
82+
83+
// MARK: - Encoding Tests
84+
85+
func testEncodeRequest_inlineDataImage() throws {
86+
let request = ImageGenerationRequest<ImagenInlineDataImage>(
87+
model: modelName,
88+
options: RequestOptions(),
89+
instances: [instance],
90+
parameters: parameters
91+
)
92+
93+
let jsonData = try encoder.encode(request)
94+
95+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
96+
XCTAssertEqual(json, """
97+
{
98+
"instances" : [
99+
{
100+
"prompt" : "\(instance.prompt)"
101+
}
102+
],
103+
"parameters" : {
104+
"aspectRatio" : "\(aspectRatio)",
105+
"includeRaiReason" : \(includeResponsibleAIFilterReason),
106+
"safetySetting" : "\(safetyFilterLevel)",
107+
"sampleCount" : \(sampleCount)
108+
}
109+
}
110+
""")
111+
}
112+
113+
func testEncodeRequest_fileDataImage() throws {
114+
let request = ImageGenerationRequest<ImagenFileDataImage>(
115+
model: modelName,
116+
options: RequestOptions(),
117+
instances: [instance],
118+
parameters: parameters
119+
)
120+
121+
let jsonData = try encoder.encode(request)
122+
123+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
124+
XCTAssertEqual(json, """
125+
{
126+
"instances" : [
127+
{
128+
"prompt" : "\(instance.prompt)"
129+
}
130+
],
131+
"parameters" : {
132+
"aspectRatio" : "\(aspectRatio)",
133+
"includeRaiReason" : \(includeResponsibleAIFilterReason),
134+
"safetySetting" : "\(safetyFilterLevel)",
135+
"sampleCount" : \(sampleCount)
136+
}
137+
}
138+
""")
139+
}
140+
}

0 commit comments

Comments
 (0)