Skip to content

Commit c652fa6

Browse files
shoumikhin authored and pytorchbot committed
Update image processing in multimodal runner tests. (#14608)
Summary: . Differential Revision: D83272875 (cherry picked from commit 2fce321)
1 parent fc48f5c commit c652fa6

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift

Lines changed: 39 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -11,16 +11,26 @@ import XCTest
1111

1212
extension UIImage {
1313
func asImage() -> Image {
14-
let targetWidth = 336
15-
let scaledHeight = Int((Double(targetWidth) * Double(size.height) / Double(size.width)).rounded())
14+
let targetSide = CGFloat(336)
15+
let scale = max(targetSide / size.width, targetSide / size.height)
16+
let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
1617
let format = UIGraphicsImageRendererFormat.default()
1718
format.scale = 1
18-
let resizedImage = UIGraphicsImageRenderer(size: CGSize(width: targetWidth, height: scaledHeight), format: format).image { _ in
19-
draw(in: CGRect(origin: .zero, size: CGSize(width: targetWidth, height: scaledHeight)))
19+
let scaledImage = UIGraphicsImageRenderer(size: scaledSize, format: format).image { _ in
20+
draw(in: CGRect(origin: .zero, size: scaledSize))
2021
}
21-
let resizedCGImage = resizedImage.cgImage!
22-
let imageWidth = resizedCGImage.width
23-
let imageHeight = resizedCGImage.height
22+
guard let scaledCGImage = scaledImage.cgImage else {
23+
return Image(data: Data(), width: 336, height: 336, channels: 3)
24+
}
25+
let cropRect = CGRect(
26+
x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),
27+
y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),
28+
width: targetSide.rounded(.down),
29+
height: targetSide.rounded(.down)
30+
)
31+
let croppedCGImage = scaledCGImage.cropping(to: cropRect) ?? scaledCGImage
32+
let imageWidth = croppedCGImage.width
33+
let imageHeight = croppedCGImage.height
2434
let pixelCount = imageWidth * imageHeight
2535
var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
2636
let context = CGContext(
@@ -32,15 +42,15 @@ extension UIImage {
3242
space: CGColorSpaceCreateDeviceRGB(),
3343
bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
3444
)!
35-
context.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
45+
context.draw(croppedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
3646
var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
3747
for pixelIndex in 0..<pixelCount {
3848
let sourceOffset = pixelIndex * 4
3949
planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
4050
planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
4151
planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
4252
}
43-
return Image(data: Data(planarRGB), width: targetWidth, height: scaledHeight, channels: 3)
53+
return Image(data: Data(planarRGB), width: 336, height: 336, channels: 3)
4454
}
4555
}
4656

@@ -50,7 +60,7 @@ class MultimodalRunnerTest: XCTestCase {
5060
guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
5161
let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
5262
let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
53-
let image = UIImage(contentsOfFile: imagePath) else {
63+
let uiImage = UIImage(contentsOfFile: imagePath) else {
5464
XCTFail("Couldn't find model or tokenizer files")
5565
return
5666
}
@@ -59,10 +69,25 @@ class MultimodalRunnerTest: XCTestCase {
5969

6070
do {
6171
try runner.generate([
62-
MultimodalInput("A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "),
63-
MultimodalInput(image.asImage()),
64-
MultimodalInput("What's on the picture? ASSISTANT: "),
65-
], sequenceLength: 768) { token in
72+
MultimodalInput(systemPrompt),
73+
MultimodalInput(uiImage.asImage()),
74+
MultimodalInput("\(userPrompt) \(assistantPrompt)"),
75+
], sequenceLength: sequenceLength) { token in
76+
text += token
77+
}
78+
} catch {
79+
XCTFail("Failed to generate text with error \(error)")
80+
}
81+
XCTAssertTrue(text.lowercased().contains("waterfall"))
82+
83+
text = ""
84+
runner.reset()
85+
do {
86+
try runner.generate([
87+
MultimodalInput(systemPrompt),
88+
MultimodalInput(uiImage.asImage()),
89+
MultimodalInput("\(userPrompt) \(assistantPrompt)"),
90+
], sequenceLength: sequenceLength) { token in
6691
text += token
6792
}
6893
} catch {

0 commit comments

Comments (0)