Skip to content

Commit 2fce321

Browse files
authored
Update image processing in multimodal runner tests. (pytorch#14608)
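Summary: Scale the test image so it fills a 336×336 square while preserving its aspect ratio, center-crop it to 336×336, and pass the cropped planar RGB data to the multimodal runner. Differential Revision: D83272875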
1 parent 87e9c16 commit 2fce321

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,26 @@ import XCTest
1111

1212
extension UIImage {
1313
func asImage() -> Image {
14-
let targetWidth = 336
15-
let scaledHeight = Int((Double(targetWidth) * Double(size.height) / Double(size.width)).rounded())
14+
let targetSide = CGFloat(336)
15+
let scale = max(targetSide / size.width, targetSide / size.height)
16+
let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
1617
let format = UIGraphicsImageRendererFormat.default()
1718
format.scale = 1
18-
let resizedImage = UIGraphicsImageRenderer(size: CGSize(width: targetWidth, height: scaledHeight), format: format).image { _ in
19-
draw(in: CGRect(origin: .zero, size: CGSize(width: targetWidth, height: scaledHeight)))
19+
let scaledImage = UIGraphicsImageRenderer(size: scaledSize, format: format).image { _ in
20+
draw(in: CGRect(origin: .zero, size: scaledSize))
2021
}
21-
let resizedCGImage = resizedImage.cgImage!
22-
let imageWidth = resizedCGImage.width
23-
let imageHeight = resizedCGImage.height
22+
guard let scaledCGImage = scaledImage.cgImage else {
23+
return Image(data: Data(), width: 336, height: 336, channels: 3)
24+
}
25+
let cropRect = CGRect(
26+
x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),
27+
y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),
28+
width: targetSide.rounded(.down),
29+
height: targetSide.rounded(.down)
30+
)
31+
let croppedCGImage = scaledCGImage.cropping(to: cropRect) ?? scaledCGImage
32+
let imageWidth = croppedCGImage.width
33+
let imageHeight = croppedCGImage.height
2434
let pixelCount = imageWidth * imageHeight
2535
var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
2636
let context = CGContext(
@@ -32,15 +42,15 @@ extension UIImage {
3242
space: CGColorSpaceCreateDeviceRGB(),
3343
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
3444
)!
35-
context.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
45+
context.draw(croppedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
3646
var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
3747
for pixelIndex in 0..<pixelCount {
3848
let sourceOffset = pixelIndex * 4
3949
planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
4050
planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
4151
planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
4252
}
43-
return Image(data: Data(planarRGB), width: targetWidth, height: scaledHeight, channels: 3)
53+
return Image(data: Data(planarRGB), width: 336, height: 336, channels: 3)
4454
}
4555
}
4656

@@ -55,7 +65,7 @@ class MultimodalRunnerTest: XCTestCase {
5565
guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
5666
let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
5767
let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
58-
let image = UIImage(contentsOfFile: imagePath) else {
68+
let uiImage = UIImage(contentsOfFile: imagePath) else {
5969
XCTFail("Couldn't find model or tokenizer files")
6070
return
6171
}
@@ -65,7 +75,7 @@ class MultimodalRunnerTest: XCTestCase {
6575
do {
6676
try runner.generate([
6777
MultimodalInput(systemPrompt),
68-
MultimodalInput(image.asImage()),
78+
MultimodalInput(uiImage.asImage()),
6979
MultimodalInput("\(userPrompt) \(assistantPrompt)"),
7080
], sequenceLength: sequenceLength) { token in
7181
text += token
@@ -80,7 +90,7 @@ class MultimodalRunnerTest: XCTestCase {
8090
do {
8191
try runner.generate([
8292
MultimodalInput(systemPrompt),
83-
MultimodalInput(image.asImage()),
93+
MultimodalInput(uiImage.asImage()),
8494
MultimodalInput("\(userPrompt) \(assistantPrompt)"),
8595
], sequenceLength: sequenceLength) { token in
8696
text += token

0 commit comments

Comments
 (0)