Skip to content

Commit 748caab

Browse files
committed
Address GCA comments
1 parent d85fed2 commit 748caab

File tree

4 files changed

+115
-29
lines changed

4 files changed

+115
-29
lines changed

FirebaseAI/Sources/Types/Internal/Imagen/ImageGenerationInstance.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,16 @@ extension ImageGenerationInstance: Encodable {
4545
if let referenceImages = referenceImages {
4646
var imagesContainer = container.nestedUnkeyedContainer(forKey: .referenceImages)
4747
for image in referenceImages {
48-
if let rawImage = image as? ImagenRawImage {
48+
switch image {
49+
case let rawImage as ImagenRawImage:
4950
try imagesContainer.encode(rawImage)
50-
} else if let mask = image as? ImagenMaskReference {
51+
case let mask as ImagenMaskReference:
5152
try imagesContainer.encode(mask)
53+
default:
54+
throw EncodingError.invalidValue(image, EncodingError.Context(
55+
codingPath: imagesContainer.codingPath,
56+
debugDescription: "Unknown ImagenReferenceImage type."
57+
))
5258
}
5359
}
5460
}

FirebaseAI/Sources/Types/Public/Imagen/ImagenMaskReference.swift

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,33 @@ public struct ImagenMaskReference: ImagenReferenceImage, Encodable {
3535
try container.encode(data.base64EncodedString(), forKey: .data)
3636
}
3737

38+
/// Errors that can occur during outpainting.
39+
public enum OutpaintingError: Error {
40+
/// The provided image data could not be decoded.
41+
case invalidImageData
42+
/// The new dimensions are smaller than the original image.
43+
case dimensionsTooSmall
44+
/// The image context could not be created.
45+
case contextCreationFailed
46+
/// The image could not be created from the context.
47+
case imageCreationFailed
48+
/// The image data could not be created from the image.
49+
case dataCreationFailed
50+
}
51+
3852
static func generateMaskAndPadForOutpainting(image: ImagenInlineImage,
3953
newDimensions: Dimensions,
4054
newPosition: ImagenImagePlacement) throws
4155
-> [ImagenReferenceImage] {
4256
guard let cgImage = CGImage.fromData(image.data) else {
43-
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not create image from data."])
57+
throw OutpaintingError.invalidImageData
4458
}
4559

4660
let originalWidth = cgImage.width
4761
let originalHeight = cgImage.height
4862

4963
guard newDimensions.width >= originalWidth, newDimensions.height >= originalHeight else {
50-
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "New dimensions must be larger than the original image."])
64+
throw OutpaintingError.dimensionsTooSmall
5165
}
5266

5367
let offsetX: Int
@@ -90,41 +104,66 @@ public struct ImagenMaskReference: ImagenReferenceImage, Encodable {
90104
let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue
91105

92106
// Create padded image
93-
guard let paddedContext = CGContext(data: nil, width: newDimensions.width, height: newDimensions.height, bitsPerComponent: 8, bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfo) else {
94-
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not create padded image context."])
107+
guard let paddedContext = CGContext(
108+
data: nil,
109+
width: newDimensions.width,
110+
height: newDimensions.height,
111+
bitsPerComponent: 8,
112+
bytesPerRow: 0,
113+
space: colorSpace,
114+
bitmapInfo: bitmapInfo
115+
) else {
116+
throw OutpaintingError.contextCreationFailed
95117
}
96-
paddedContext.draw(cgImage, in: CGRect(x: offsetX, y: offsetY, width: originalWidth, height: originalHeight))
97-
guard let paddedCGImage = paddedContext.makeImage(), let paddedImageData = paddedCGImage.toData() else {
98-
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not get padded image data."])
118+
paddedContext.draw(
119+
cgImage,
120+
in: CGRect(x: offsetX, y: offsetY, width: originalWidth, height: originalHeight)
121+
)
122+
guard let paddedCGImage = paddedContext.makeImage(),
123+
let paddedImageData = paddedCGImage.toData() else {
124+
throw OutpaintingError.imageCreationFailed
99125
}
100126

101127
// Create mask
102-
guard let maskContext = CGContext(data: nil, width: newDimensions.width, height: newDimensions.height, bitsPerComponent: 8, bytesPerRow: 0, space: CGColorSpaceCreateDeviceGray(), bitmapInfo: CGImageAlphaInfo.none.rawValue) else {
103-
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not create mask context."])
128+
guard let maskContext = CGContext(
129+
data: nil,
130+
width: newDimensions.width,
131+
height: newDimensions.height,
132+
bitsPerComponent: 8,
133+
bytesPerRow: 0,
134+
space: CGColorSpaceCreateDeviceGray(),
135+
bitmapInfo: CGImageAlphaInfo.none.rawValue
136+
) else {
137+
throw OutpaintingError.contextCreationFailed
104138
}
105139
maskContext.setFillColor(gray: 1.0, alpha: 1.0)
106140
maskContext.fill(CGRect(x: 0, y: 0, width: newDimensions.width, height: newDimensions.height))
107141
maskContext.setFillColor(gray: 0.0, alpha: 1.0)
108142
maskContext.fill(CGRect(x: offsetX, y: offsetY, width: originalWidth, height: originalHeight))
109143
guard let maskCGImage = maskContext.makeImage(), let maskData = maskCGImage.toData() else {
110-
throw NSError(domain: "com.google.firebase.ai", code: 0, userInfo: [NSLocalizedDescriptionKey: "Could not get mask data."])
144+
throw OutpaintingError.dataCreationFailed
111145
}
112146

113147
return [ImagenRawImage(data: paddedImageData), ImagenMaskReference(data: maskData)]
114148
}
115149
}
116150

117151
extension CGImage {
118-
static func fromData(_ data: Data) -> CGImage? {
119-
guard let provider = CGDataProvider(data: data as CFData) else { return nil }
120-
return CGImage(pngDataProviderSource: provider, decode: nil, shouldInterpolate: true, intent: .defaultIntent)
121-
}
152+
static func fromData(_ data: Data) -> CGImage? {
153+
guard let source = CGImageSourceCreateWithData(data as CFData, nil) else { return nil }
154+
return CGImageSourceCreateImageAtIndex(source, 0, nil)
155+
}
122156

123-
func toData() -> Data? {
124-
guard let mutableData = CFDataCreateMutable(nil, 0),
125-
let destination = CGImageDestinationCreateWithData(mutableData, "public.png" as CFString, 1, nil) else { return nil }
126-
CGImageDestinationAddImage(destination, self, nil)
127-
guard CGImageDestinationFinalize(destination) else { return nil }
128-
return mutableData as Data
129-
}
130-
}
157+
func toData() -> Data? {
158+
guard let mutableData = CFDataCreateMutable(nil, 0),
159+
let destination = CGImageDestinationCreateWithData(
160+
mutableData,
161+
"public.png" as CFString,
162+
1,
163+
nil
164+
) else { return nil }
165+
CGImageDestinationAddImage(destination, self, nil)
166+
guard CGImageDestinationFinalize(destination) else { return nil }
167+
return mutableData as Data
168+
}
169+
}

FirebaseAI/Sources/Types/Public/Imagen/ImagenModel.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,12 @@ public final class ImagenModel {
167167
public func inpaintImage(image: ImagenInlineImage,
168168
prompt: String,
169169
mask: ImagenMaskReference,
170-
config: ImagenEditingConfig) async throws
170+
editSteps: Int? = nil) async throws
171171
-> ImagenGenerationResponse<ImagenInlineImage> {
172172
return try await editImage(
173173
referenceImages: [ImagenRawImage(data: image.data), mask],
174174
prompt: prompt,
175-
config: config
175+
config: ImagenEditingConfig(editMode: .inpaint, editSteps: editSteps)
176176
)
177177
}
178178

@@ -195,7 +195,7 @@ public final class ImagenModel {
195195
newDimensions: Dimensions,
196196
newPosition: ImagenImagePlacement = .center,
197197
prompt: String = "",
198-
config: ImagenEditingConfig? = nil) async throws
198+
editSteps: Int? = nil) async throws
199199
-> ImagenGenerationResponse<ImagenInlineImage> {
200200
let referenceImages = try ImagenMaskReference.generateMaskAndPadForOutpainting(
201201
image: image,
@@ -205,7 +205,7 @@ public final class ImagenModel {
205205
return try await editImage(
206206
referenceImages: referenceImages,
207207
prompt: prompt,
208-
config: ImagenEditingConfig(editMode: .outpaint, editSteps: config?.editSteps)
208+
config: ImagenEditingConfig(editMode: .outpaint, editSteps: editSteps)
209209
)
210210
}
211211

FirebaseAI/Tests/Unit/ImagenMaskReferenceTests.swift

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,51 @@ final class ImagenMaskReferenceTests: XCTestCase {
5252
XCTAssertEqual(maskCGImage.height, newHeight)
5353
}
5454

55+
func testGenerateMaskAndPadForOutpainting_invalidData() {
56+
// Setup
57+
let newDimensions = Dimensions(width: 200, height: 200)
58+
let image = ImagenInlineImage(mimeType: "dummy-mime", data: Data())
59+
60+
// Act & Assert
61+
XCTAssertThrowsError(try ImagenMaskReference.generateMaskAndPadForOutpainting(
62+
image: image,
63+
newDimensions: newDimensions,
64+
newPosition: .center
65+
)) { error in
66+
XCTAssertEqual(error as? ImagenMaskReference.OutpaintingError, .invalidImageData)
67+
}
68+
}
69+
70+
func testGenerateMaskAndPadForOutpainting_dimensionsTooSmall() {
71+
// Setup
72+
let newDimensions = Dimensions(width: 50, height: 50)
73+
let image = ImagenInlineImage(
74+
mimeType: "dummy-mime",
75+
data: createDummyImageData(width: 100, height: 100)
76+
)
77+
78+
// Act & Assert
79+
XCTAssertThrowsError(try ImagenMaskReference.generateMaskAndPadForOutpainting(
80+
image: image,
81+
newDimensions: newDimensions,
82+
newPosition: .center
83+
)) { error in
84+
XCTAssertEqual(error as? ImagenMaskReference.OutpaintingError, .dimensionsTooSmall)
85+
}
86+
}
87+
5588
private func createDummyImageData(width: Int, height: Int) -> Data {
5689
let colorSpace = CGColorSpaceCreateDeviceRGB()
5790
let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue
58-
let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 0, space: colorSpace, bitmapInfo: bitmapInfo)!
91+
let context = CGContext(
92+
data: nil,
93+
width: width,
94+
height: height,
95+
bitsPerComponent: 8,
96+
bytesPerRow: 0,
97+
space: colorSpace,
98+
bitmapInfo: bitmapInfo
99+
)!
59100
context.setFillColor(red: 1, green: 1, blue: 1, alpha: 1)
60101
context.fill(CGRect(x: 0, y: 0, width: width, height: height))
61102
let cgImage = context.makeImage()!

0 commit comments

Comments
 (0)