@@ -11,16 +11,26 @@ import XCTest
 
 extension UIImage {
   func asImage() -> Image {
-    let targetWidth = 336
-    let scaledHeight = Int((Double(targetWidth) * Double(size.height) / Double(size.width)).rounded())
+    let targetSide = CGFloat(336)
+    let scale = max(targetSide / size.width, targetSide / size.height)
+    let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
     let format = UIGraphicsImageRendererFormat.default()
     format.scale = 1
-    let resizedImage = UIGraphicsImageRenderer(size: CGSize(width: targetWidth, height: scaledHeight), format: format).image { _ in
-      draw(in: CGRect(origin: .zero, size: CGSize(width: targetWidth, height: scaledHeight)))
+    let scaledImage = UIGraphicsImageRenderer(size: scaledSize, format: format).image { _ in
+      draw(in: CGRect(origin: .zero, size: scaledSize))
     }
-    let resizedCGImage = resizedImage.cgImage!
-    let imageWidth = resizedCGImage.width
-    let imageHeight = resizedCGImage.height
+    guard let scaledCGImage = scaledImage.cgImage else {
+      return Image(data: Data(), width: 336, height: 336, channels: 3)
+    }
+    let cropRect = CGRect(
+      x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),
+      y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),
+      width: targetSide.rounded(.down),
+      height: targetSide.rounded(.down)
+    )
+    let croppedCGImage = scaledCGImage.cropping(to: cropRect) ?? scaledCGImage
+    let imageWidth = croppedCGImage.width
+    let imageHeight = croppedCGImage.height
     let pixelCount = imageWidth * imageHeight
     var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
     let context = CGContext(
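Note: the replacement logic in the hunk above scales the source image so that both sides reach at least 336 pixels (aspect fill) and then center-crops a 336×336 square, instead of resizing to a fixed width with a proportional height. A minimal standalone sketch of that math, using a hypothetical 1024×768 input purely for illustration:

```swift
import CoreGraphics

// Hypothetical input size, for illustration only.
let size = CGSize(width: 1024, height: 768)
let targetSide = CGFloat(336)

// Aspect fill: taking the larger of the two ratios guarantees that
// both scaled dimensions are at least targetSide.
let scale = max(targetSide / size.width, targetSide / size.height)
let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
// scaledSize is 448×336 here: the height hits 336, the width overshoots.

// Center-crop a targetSide × targetSide square out of the scaled image.
let cropRect = CGRect(
  x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),   // 56
  y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),  // 0
  width: targetSide.rounded(.down),                            // 336
  height: targetSide.rounded(.down)                            // 336
)
```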
@@ -32,15 +42,15 @@ extension UIImage {
       space: CGColorSpaceCreateDeviceRGB(),
       bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
     )!
-    context.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
+    context.draw(croppedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
     var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
     for pixelIndex in 0..<pixelCount {
       let sourceOffset = pixelIndex * 4
       planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
       planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
       planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
     }
-    return Image(data: Data(planarRGB), width: targetWidth, height: scaledHeight, channels: 3)
+    return Image(data: Data(planarRGB), width: 336, height: 336, channels: 3)
   }
 }
 
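Note: the loop kept in the hunk above repacks the interleaved RGBA buffer produced by the CGContext draw (4 bytes per pixel, pixel-major) into planar RGB (all red bytes, then all green, then all blue), dropping the alpha channel before passing the buffer to Image with channels: 3. A toy illustration of the same repacking on two made-up pixels:

```swift
// Two made-up pixels in interleaved RGBA order, for illustration only.
let rgbaBuffer: [UInt8] = [
  10, 20, 30, 255,  // pixel 0: R, G, B, A
  40, 50, 60, 255,  // pixel 1: R, G, B, A
]
let pixelCount = rgbaBuffer.count / 4

var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
for pixelIndex in 0..<pixelCount {
  let sourceOffset = pixelIndex * 4
  planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]                       // R plane
  planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]      // G plane
  planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]  // B plane; alpha dropped
}
// planarRGB == [10, 40, 20, 50, 30, 60]
```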
@@ -50,7 +60,7 @@ class MultimodalRunnerTest: XCTestCase {
     guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
           let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
           let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
-          let image = UIImage(contentsOfFile: imagePath) else {
+          let uiImage = UIImage(contentsOfFile: imagePath) else {
       XCTFail("Couldn't find model or tokenizer files")
       return
     }
@@ -59,10 +69,25 @@ class MultimodalRunnerTest: XCTestCase {
 
     do {
       try runner.generate([
-        MultimodalInput("A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "),
-        MultimodalInput(image.asImage()),
-        MultimodalInput("What's on the picture? ASSISTANT: "),
-      ], sequenceLength: 768) { token in
+        MultimodalInput(systemPrompt),
+        MultimodalInput(uiImage.asImage()),
+        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+      ], sequenceLength: sequenceLength) { token in
+        text += token
+      }
+    } catch {
+      XCTFail("Failed to generate text with error \(error)")
+    }
+    XCTAssertTrue(text.lowercased().contains("waterfall"))
+
+    text = ""
+    runner.reset()
+    do {
+      try runner.generate([
+        MultimodalInput(systemPrompt),
+        MultimodalInput(uiImage.asImage()),
+        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+      ], sequenceLength: sequenceLength) { token in
         text += token
       }
     } catch {
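Note: the test body now runs the same generation twice against a single runner, calling runner.reset() between passes and checking the decoded text for "waterfall". A hedged sketch of how the repeated pass could be factored into a helper; systemPrompt, userPrompt, assistantPrompt, and sequenceLength are the test's own constants referenced in the diff but defined outside these hunks, the runner type MultimodalRunner is inferred from the test class name, and the helper name generateOnce is hypothetical:

```swift
// Hypothetical helper (not part of the diff): one generation pass that
// returns the decoded text. Assumes systemPrompt, userPrompt, assistantPrompt,
// and sequenceLength are defined elsewhere in the test class, as the diff implies,
// and that the runner is a MultimodalRunner.
private func generateOnce(with runner: MultimodalRunner, image: UIImage) throws -> String {
  var text = ""
  try runner.generate([
    MultimodalInput(systemPrompt),
    MultimodalInput(image.asImage()),
    MultimodalInput("\(userPrompt) \(assistantPrompt)"),
  ], sequenceLength: sequenceLength) { token in
    text += token
  }
  return text
}

// Possible usage mirroring the test flow: generate, reset, generate again.
// XCTAssertTrue(try generateOnce(with: runner, image: uiImage).lowercased().contains("waterfall"))
// runner.reset()
// XCTAssertTrue(try generateOnce(with: runner, image: uiImage).lowercased().contains("waterfall"))
```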