Skip to content

Commit d85fed2

Browse files
committed
Imagen editing
1 parent b9bf3ad commit d85fed2

11 files changed

+547
-7
lines changed

FirebaseAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,42 @@
1515
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
1616
struct ImageGenerationInstance {
1717
let prompt: String
18+
let referenceImages: [ImagenReferenceImage]?
19+
20+
init(prompt: String, referenceImages: [ImagenReferenceImage]? = nil) {
21+
self.prompt = prompt
22+
self.referenceImages = referenceImages
23+
}
1824
}
1925

2026
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
21-
extension ImageGenerationInstance: Equatable {}
27+
extension ImageGenerationInstance: Equatable {
28+
static func == (lhs: ImageGenerationInstance, rhs: ImageGenerationInstance) -> Bool {
29+
return lhs.prompt == rhs.prompt && lhs.referenceImages?.count == rhs.referenceImages?.count
30+
}
31+
}
2232

2333
// MARK: - Codable Conformance
2434

2535
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
26-
extension ImageGenerationInstance: Encodable {}
36+
extension ImageGenerationInstance: Encodable {
37+
enum CodingKeys: String, CodingKey {
38+
case prompt
39+
case referenceImages = "image"
40+
}
41+
42+
func encode(to encoder: any Encoder) throws {
43+
var container = encoder.container(keyedBy: CodingKeys.self)
44+
try container.encode(prompt, forKey: .prompt)
45+
if let referenceImages = referenceImages {
46+
var imagesContainer = container.nestedUnkeyedContainer(forKey: .referenceImages)
47+
for image in referenceImages {
48+
if let rawImage = image as? ImagenRawImage {
49+
try imagesContainer.encode(rawImage)
50+
} else if let mask = image as? ImagenMaskReference {
51+
try imagesContainer.encode(mask)
52+
}
53+
}
54+
}
55+
}
56+
}

FirebaseAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,32 @@ struct ImageGenerationParameters {
2323
let outputOptions: ImageGenerationOutputOptions?
2424
let addWatermark: Bool?
2525
let includeResponsibleAIFilterReason: Bool?
26+
let editMode: String?
27+
let editConfig: ImageEditingParameters?
28+
29+
init(sampleCount: Int?,
30+
storageURI: String?,
31+
negativePrompt: String?,
32+
aspectRatio: String?,
33+
safetyFilterLevel: String?,
34+
personGeneration: String?,
35+
outputOptions: ImageGenerationOutputOptions?,
36+
addWatermark: Bool?,
37+
includeResponsibleAIFilterReason: Bool?,
38+
editMode: String? = nil,
39+
editConfig: ImageEditingParameters? = nil) {
40+
self.sampleCount = sampleCount
41+
self.storageURI = storageURI
42+
self.negativePrompt = negativePrompt
43+
self.aspectRatio = aspectRatio
44+
self.safetyFilterLevel = safetyFilterLevel
45+
self.personGeneration = personGeneration
46+
self.outputOptions = outputOptions
47+
self.addWatermark = addWatermark
48+
self.includeResponsibleAIFilterReason = includeResponsibleAIFilterReason
49+
self.editMode = editMode
50+
self.editConfig = editConfig
51+
}
2652
}
2753

2854
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
@@ -42,6 +68,8 @@ extension ImageGenerationParameters: Encodable {
4268
case outputOptions
4369
case addWatermark
4470
case includeResponsibleAIFilterReason = "includeRaiReason"
71+
case editMode
72+
case editConfig
4573
}
4674

4775
func encode(to encoder: any Encoder) throws {
@@ -58,5 +86,12 @@ extension ImageGenerationParameters: Encodable {
5886
includeResponsibleAIFilterReason,
5987
forKey: .includeResponsibleAIFilterReason
6088
)
89+
try container.encodeIfPresent(editMode, forKey: .editMode)
90+
try container.encodeIfPresent(editConfig, forKey: .editConfig)
6191
}
6292
}
93+
94+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
95+
struct ImageEditingParameters: Codable, Equatable {
96+
let editSteps: Int?
97+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Represents the dimensions of an image.
18+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
19+
public struct Dimensions: Codable, Sendable {
20+
/// The width of the image in pixels.
21+
public let width: Int
22+
23+
/// The height of the image in pixels.
24+
public let height: Int
25+
26+
public init(width: Int, height: Int) {
27+
self.width = width
28+
self.height = height
29+
}
30+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// The editing method to use.
18+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
19+
public enum ImagenEditMode: String, Codable, Sendable {
20+
/// The model should use the prompt and reference images to generate a new image.
21+
case product = "product-image"
22+
23+
/// The model should generate a new background for the given image.
24+
case background = "background-refresh"
25+
26+
/// The model should replace the masked region of the image with new content.
27+
case inpaint = "inpainting"
28+
29+
/// The model should extend the image beyond its original borders.
30+
case outpaint = "outpainting"
31+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Configuration for editing an image with Imagen.
18+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
19+
public struct ImagenEditingConfig: Codable, Sendable {
20+
/// The editing method to use.
21+
public let editMode: ImagenEditMode
22+
23+
/// The number of steps to use for the editing process.
24+
public let editSteps: Int?
25+
26+
public init(editMode: ImagenEditMode, editSteps: Int? = nil) {
27+
self.editMode = editMode
28+
self.editSteps = editSteps
29+
}
30+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
/// Represents the placement of an image within a larger canvas.
18+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
19+
public enum ImagenImagePlacement: Sendable {
20+
/// The image is placed at the top left corner of the canvas.
21+
case topLeft
22+
23+
/// The image is placed at the top center of the canvas.
24+
case topCenter
25+
26+
/// The image is placed at the top right corner of the canvas.
27+
case topRight
28+
29+
/// The image is placed at the middle left of the canvas.
30+
case middleLeft
31+
32+
/// The image is placed in the center of the canvas.
33+
case center
34+
35+
/// The image is placed at the middle right of the canvas.
36+
case middleRight
37+
38+
/// The image is placed at the bottom left corner of the canvas.
39+
case bottomLeft
40+
41+
/// The image is placed at the bottom center of the canvas.
42+
case bottomCenter
43+
44+
/// The image is placed at the bottom right corner of the canvas.
45+
case bottomRight
46+
47+
/// The image is placed at a custom offset from the top left corner of the canvas.
48+
case custom(x: Int, y: Int)
49+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import CoreGraphics
16+
import Foundation
17+
import ImageIO
18+
19+
/// A reference to a mask for inpainting.
20+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
21+
public struct ImagenMaskReference: ImagenReferenceImage, Encodable {
22+
/// The mask data.
23+
public let data: Data
24+
25+
public init(data: Data) {
26+
self.data = data
27+
}
28+
29+
enum CodingKeys: String, CodingKey {
30+
case data = "bytesBase64Encoded"
31+
}
32+
33+
public func encode(to encoder: Encoder) throws {
34+
var container = encoder.container(keyedBy: CodingKeys.self)
35+
try container.encode(data.base64EncodedString(), forKey: .data)
36+
}
37+
38+
static func generateMaskAndPadForOutpainting(image: ImagenInlineImage,
39+
newDimensions: Dimensions,
40+
newPosition: ImagenImagePlacement) throws
41+
-> [ImagenReferenceImage] {
42+
guard let cgImage = CGImage.fromData(image.data) else {
43+
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not create image from data."])
44+
}
45+
46+
let originalWidth = cgImage.width
47+
let originalHeight = cgImage.height
48+
49+
guard newDimensions.width >= originalWidth, newDimensions.height >= originalHeight else {
50+
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "New dimensions must be larger than the original image."])
51+
}
52+
53+
let offsetX: Int
54+
let offsetY: Int
55+
56+
switch newPosition {
57+
case .topLeft:
58+
offsetX = 0
59+
offsetY = 0
60+
case .topCenter:
61+
offsetX = (newDimensions.width - originalWidth) / 2
62+
offsetY = 0
63+
case .topRight:
64+
offsetX = newDimensions.width - originalWidth
65+
offsetY = 0
66+
case .middleLeft:
67+
offsetX = 0
68+
offsetY = (newDimensions.height - originalHeight) / 2
69+
case .center:
70+
offsetX = (newDimensions.width - originalWidth) / 2
71+
offsetY = (newDimensions.height - originalHeight) / 2
72+
case .middleRight:
73+
offsetX = newDimensions.width - originalWidth
74+
offsetY = (newDimensions.height - originalHeight) / 2
75+
case .bottomLeft:
76+
offsetX = 0
77+
offsetY = newDimensions.height - originalHeight
78+
case .bottomCenter:
79+
offsetX = (newDimensions.width - originalWidth) / 2
80+
offsetY = newDimensions.height - originalHeight
81+
case .bottomRight:
82+
offsetX = newDimensions.width - originalWidth
83+
offsetY = newDimensions.height - originalHeight
84+
case let .custom(x, y):
85+
offsetX = x
86+
offsetY = y
87+
}
88+
89+
let colorSpace = CGColorSpaceCreateDeviceRGB()
90+
let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue
91+
92+
// Create padded image
93+
guard let paddedContext = CGContext(data: nil, width: newDimensions.width, height: newDimensions.height, bitsPerComponent: 8, bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfo) else {
94+
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not create padded image context."])
95+
}
96+
paddedContext.draw(cgImage, in: CGRect(x: offsetX, y: offsetY, width: originalWidth, height: originalHeight))
97+
guard let paddedCGImage = paddedContext.makeImage(), let paddedImageData = paddedCGImage.toData() else {
98+
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not get padded image data."])
99+
}
100+
101+
// Create mask
102+
guard let maskContext = CGContext(data: nil, width: newDimensions.width, height: newDimensions.height, bitsPerComponent: 8, bytesPerRow: 0, space: CGColorSpaceCreateDeviceGray(), bitmapInfo: CGImageAlphaInfo.none.rawValue) else {
103+
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not create mask context."])
104+
}
105+
maskContext.setFillColor(gray: 1.0, alpha: 1.0)
106+
maskContext.fill(CGRect(x: 0, y: 0, width: newDimensions.width, height: newDimensions.height))
107+
maskContext.setFillColor(gray: 0.0, alpha: 1.0)
108+
maskContext.fill(CGRect(x: offsetX, y: offsetY, width: originalWidth, height: originalHeight))
109+
guard let maskCGImage = maskContext.makeImage(), let maskData = maskCGImage.toData() else {
110+
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not get mask data."])
111+
}
112+
113+
return [ImagenRawImage(data: paddedImageData), ImagenMaskReference(data: maskData)]
114+
}
115+
}
116+
117+
extension CGImage {
118+
static func fromData(_ data: Data) -> CGImage? {
119+
guard let provider = CGDataProvider(data: data as CFData) else { return nil }
120+
return CGImage(pngDataProviderSource: provider, decode: nil, shouldInterpolate: true, intent: .defaultIntent)
121+
}
122+
123+
func toData() -> Data? {
124+
guard let mutableData = CFDataCreateMutable(nil, 0),
125+
let destination = CGImageDestinationCreateWithData(mutableData, "public.png" as CFString, 1, nil) else { return nil }
126+
CGImageDestinationAddImage(destination, self, nil)
127+
guard CGImageDestinationFinalize(destination) else { return nil }
128+
return mutableData as Data
129+
}
130+
}

0 commit comments

Comments
 (0)