Skip to content

Commit 06ef96f

Browse files
committed
[Vertex AI] Add ImageGenerationParameters for input to predict call (#14208)
1 parent 122d876 commit 06ef96f

File tree

4 files changed

+275
-0
lines changed

4 files changed

+275
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
18+
struct ImageGenerationOutputOptions {
19+
let mimeType: String
20+
let compressionQuality: Int?
21+
}
22+
23+
// MARK: - Codable Conformance
24+
25+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
26+
extension ImageGenerationOutputOptions: Encodable {}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
16+
struct ImageGenerationParameters {
17+
let sampleCount: Int?
18+
let storageURI: String?
19+
let seed: Int32?
20+
let negativePrompt: String?
21+
let aspectRatio: String?
22+
let safetyFilterLevel: String?
23+
let personGeneration: String?
24+
let outputOptions: ImageGenerationOutputOptions?
25+
let addWatermark: Bool?
26+
let includeResponsibleAIFilterReason: Bool?
27+
}
28+
29+
// MARK: - Codable Conformance
30+
31+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
32+
extension ImageGenerationParameters: Encodable {
33+
enum CodingKeys: String, CodingKey {
34+
case sampleCount
35+
case storageURI = "storageUri"
36+
case seed
37+
case negativePrompt
38+
case aspectRatio
39+
case safetyFilterLevel = "safetySetting"
40+
case personGeneration
41+
case outputOptions
42+
case addWatermark
43+
case includeResponsibleAIFilterReason = "includeRaiReason"
44+
}
45+
46+
func encode(to encoder: any Encoder) throws {
47+
var container = encoder.container(keyedBy: CodingKeys.self)
48+
try container.encodeIfPresent(sampleCount, forKey: .sampleCount)
49+
try container.encodeIfPresent(storageURI, forKey: .storageURI)
50+
try container.encodeIfPresent(seed, forKey: .seed)
51+
try container.encodeIfPresent(negativePrompt, forKey: .negativePrompt)
52+
try container.encodeIfPresent(aspectRatio, forKey: .aspectRatio)
53+
try container.encodeIfPresent(safetyFilterLevel, forKey: .safetyFilterLevel)
54+
try container.encodeIfPresent(personGeneration, forKey: .personGeneration)
55+
try container.encodeIfPresent(outputOptions, forKey: .outputOptions)
56+
try container.encodeIfPresent(addWatermark, forKey: .addWatermark)
57+
try container.encodeIfPresent(
58+
includeResponsibleAIFilterReason,
59+
forKey: .includeResponsibleAIFilterReason
60+
)
61+
}
62+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import FirebaseVertexAI
18+
19+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20+
final class ImageGenerationOutputOptionsTests: XCTestCase {
21+
let encoder = JSONEncoder()
22+
23+
override func setUp() {
24+
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
25+
}
26+
27+
// MARK: - Encoding Tests
28+
29+
func testEncodeOutputOptions_jpeg_defaultCompressionQuality() throws {
30+
let mimeType = "image/jpeg"
31+
let options = ImageGenerationOutputOptions(mimeType: mimeType, compressionQuality: nil)
32+
33+
let jsonData = try encoder.encode(options)
34+
35+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
36+
XCTAssertEqual(json, """
37+
{
38+
"mimeType" : "\(mimeType)"
39+
}
40+
""")
41+
}
42+
43+
func testEncodeOutputOptions_jpeg_customCompressionQuality() throws {
44+
let mimeType = "image/jpeg"
45+
let quality = 50
46+
let options = ImageGenerationOutputOptions(mimeType: mimeType, compressionQuality: quality)
47+
48+
let jsonData = try encoder.encode(options)
49+
50+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
51+
XCTAssertEqual(json, """
52+
{
53+
"compressionQuality" : \(quality),
54+
"mimeType" : "\(mimeType)"
55+
}
56+
""")
57+
}
58+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
17+
@testable import FirebaseVertexAI
18+
19+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20+
final class ImageGenerationParametersTests: XCTestCase {
21+
let encoder = JSONEncoder()
22+
23+
override func setUp() {
24+
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
25+
}
26+
27+
// MARK: - Encoding Tests
28+
29+
func testEncodeParameters_allSpecified() throws {
30+
let sampleCount = 4
31+
let storageURI = "gs://bucket/folder"
32+
let seed: Int32 = 1_076_107_968
33+
let negativePrompt = "test-negative-prompt"
34+
let aspectRatio = "16:9"
35+
let safetyFilterLevel = "block_low_and_above"
36+
let personGeneration = "allow_adult"
37+
let mimeType = "image/png"
38+
let outputOptions = ImageGenerationOutputOptions(mimeType: mimeType, compressionQuality: nil)
39+
let addWatermark = false
40+
let includeRAIReason = true
41+
let parameters = ImageGenerationParameters(
42+
sampleCount: sampleCount,
43+
storageURI: storageURI,
44+
seed: seed,
45+
negativePrompt: negativePrompt,
46+
aspectRatio: aspectRatio,
47+
safetyFilterLevel: safetyFilterLevel,
48+
personGeneration: personGeneration,
49+
outputOptions: outputOptions,
50+
addWatermark: addWatermark,
51+
includeResponsibleAIFilterReason: includeRAIReason
52+
)
53+
54+
let jsonData = try encoder.encode(parameters)
55+
56+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
57+
XCTAssertEqual(json, """
58+
{
59+
"addWatermark" : \(addWatermark),
60+
"aspectRatio" : "\(aspectRatio)",
61+
"includeRaiReason" : \(includeRAIReason),
62+
"negativePrompt" : "\(negativePrompt)",
63+
"outputOptions" : {
64+
"mimeType" : "\(mimeType)"
65+
},
66+
"personGeneration" : "\(personGeneration)",
67+
"safetySetting" : "\(safetyFilterLevel)",
68+
"sampleCount" : \(sampleCount),
69+
"seed" : \(seed),
70+
"storageUri" : "\(storageURI)"
71+
}
72+
""")
73+
}
74+
75+
func testEncodeParameters_someSpecified() throws {
76+
let sampleCount = 2
77+
let aspectRatio = "3:4"
78+
let safetyFilterLevel = "block_medium_and_above"
79+
let addWatermark = true
80+
let parameters = ImageGenerationParameters(
81+
sampleCount: sampleCount,
82+
storageURI: nil,
83+
seed: nil,
84+
negativePrompt: nil,
85+
aspectRatio: aspectRatio,
86+
safetyFilterLevel: safetyFilterLevel,
87+
personGeneration: nil,
88+
outputOptions: nil,
89+
addWatermark: addWatermark,
90+
includeResponsibleAIFilterReason: nil
91+
)
92+
93+
let jsonData = try encoder.encode(parameters)
94+
95+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
96+
XCTAssertEqual(json, """
97+
{
98+
"addWatermark" : \(addWatermark),
99+
"aspectRatio" : "\(aspectRatio)",
100+
"safetySetting" : "\(safetyFilterLevel)",
101+
"sampleCount" : \(sampleCount)
102+
}
103+
""")
104+
}
105+
106+
func testEncodeParameters_noneSpecified() throws {
107+
let parameters = ImageGenerationParameters(
108+
sampleCount: nil,
109+
storageURI: nil,
110+
seed: nil,
111+
negativePrompt: nil,
112+
aspectRatio: nil,
113+
safetyFilterLevel: nil,
114+
personGeneration: nil,
115+
outputOptions: nil,
116+
addWatermark: nil,
117+
includeResponsibleAIFilterReason: nil
118+
)
119+
120+
let jsonData = try encoder.encode(parameters)
121+
122+
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
123+
XCTAssertEqual(json, """
124+
{
125+
126+
}
127+
""")
128+
}
129+
}

0 commit comments

Comments
 (0)