@@ -10,60 +10,157 @@ import ExecuTorchLLM
 import XCTest

 extension UIImage {
-  func asImage() -> Image {
-    let targetSide = CGFloat(336)
-    let scale = max(targetSide / size.width, targetSide / size.height)
-    let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
+  func centerCropped(to sideSize: CGFloat) -> UIImage {
+    precondition(sideSize > 0)
     let format = UIGraphicsImageRendererFormat.default()
     format.scale = 1
-    let scaledImage = UIGraphicsImageRenderer(size: scaledSize, format: format).image { _ in
-      draw(in: CGRect(origin: .zero, size: scaledSize))
-    }
-    guard let scaledCGImage = scaledImage.cgImage else {
-      return Image(data: Data(), width: 336, height: 336, channels: 3)
-    }
-    let cropRect = CGRect(
-      x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),
-      y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),
-      width: targetSide.rounded(.down),
-      height: targetSide.rounded(.down)
-    )
-    let croppedCGImage = scaledCGImage.cropping(to: cropRect) ?? scaledCGImage
-    let imageWidth = croppedCGImage.width
-    let imageHeight = croppedCGImage.height
-    let pixelCount = imageWidth * imageHeight
-    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
-    let context = CGContext(
+    format.opaque = false
+    return UIGraphicsImageRenderer(size: CGSize(width: sideSize, height: sideSize), format: format)
+      .image { _ in
+        let scaleFactor = max(sideSize / size.width, sideSize / size.height)
+        let scaledWidth = size.width * scaleFactor
+        let scaledHeight = size.height * scaleFactor
+        let originX = (sideSize - scaledWidth) / 2
+        let originY = (sideSize - scaledHeight) / 2
+        draw(in: CGRect(x: originX, y: originY, width: scaledWidth, height: scaledHeight))
+      }
+  }
+
+  func rgbBytes() -> [UInt8]? {
+    guard let cgImage = cgImage else { return nil }
+    let pixelWidth = Int(cgImage.width)
+    let pixelHeight = Int(cgImage.height)
+    let pixelCount = pixelWidth * pixelHeight
+    let bytesPerPixel = 4
+    let bytesPerRow = pixelWidth * bytesPerPixel
+    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)
+    guard let context = CGContext(
       data: &rgbaBuffer,
-      width: imageWidth,
-      height: imageHeight,
+      width: pixelWidth,
+      height: pixelHeight,
       bitsPerComponent: 8,
-      bytesPerRow: imageWidth * 4,
+      bytesPerRow: bytesPerRow,
       space: CGColorSpaceCreateDeviceRGB(),
       bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
-    )!
-    context.draw(croppedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
-    var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
+    ) else { return nil }
+
+    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: pixelWidth, height: pixelHeight))
+
+    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
     for pixelIndex in 0..<pixelCount {
-      let sourceOffset = pixelIndex * 4
-      planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
-      planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
-      planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
+      let sourceIndex = pixelIndex * bytesPerPixel
+      rgbBytes[pixelIndex] = rgbaBuffer[sourceIndex + 0]
+      rgbBytes[pixelIndex + pixelCount] = rgbaBuffer[sourceIndex + 1]
+      rgbBytes[pixelIndex + 2 * pixelCount] = rgbaBuffer[sourceIndex + 2]
     }
-    return Image(data: Data(planarRGB), width: 336, height: 336, channels: 3)
+    return rgbBytes
+  }
+
+  func rgbBytesNormalized(mean: [Float] = [0, 0, 0], std: [Float] = [1, 1, 1]) -> [Float]? {
+    precondition(mean.count == 3 && std.count == 3)
+    precondition(std[0] != 0 && std[1] != 0 && std[2] != 0)
+    guard let rgbBytes = rgbBytes() else { return nil }
+    let pixelCount = rgbBytes.count / 3
+    var rgbBytesNormalized = [Float](repeating: 0, count: pixelCount * 3)
+    for pixelIndex in 0..<pixelCount {
+      rgbBytesNormalized[pixelIndex] =
+        (Float(rgbBytes[pixelIndex]) / 255.0 - mean[0]) / std[0]
+      rgbBytesNormalized[pixelIndex + pixelCount] =
+        (Float(rgbBytes[pixelIndex + pixelCount]) / 255.0 - mean[1]) / std[1]
+      rgbBytesNormalized[pixelIndex + 2 * pixelCount] =
+        (Float(rgbBytes[pixelIndex + 2 * pixelCount]) / 255.0 - mean[2]) / std[2]
+    }
+    return rgbBytesNormalized
+  }
+
+  func asImage(_ sideSize: CGFloat) -> Image {
+    return Image(
+      data: Data(centerCropped(to: sideSize).rgbBytes() ?? []),
+      width: Int(sideSize),
+      height: Int(sideSize),
+      channels: 3
+    )
+  }
+
+  func asNormalizedImage(
+    _ sideSize: CGFloat,
+    mean: [Float] = [0.485, 0.456, 0.406],
+    std: [Float] = [0.229, 0.224, 0.225]
+  ) -> Image {
+    return Image(
+      float: (centerCropped(to: sideSize).rgbBytesNormalized(mean: mean, std: std) ?? []).withUnsafeBufferPointer { Data(buffer: $0) },
+      width: Int(sideSize),
+      height: Int(sideSize),
+      channels: 3
+    )
   }
 }

 class MultimodalRunnerTest: XCTestCase {
-  let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
-  let assistantPrompt = "ASSISTANT: "
+  let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "
   let userPrompt = "What's on the picture?"
-  let sequenceLength = 768
+
+  func testGemma() {
+    let chatTemplate = "<start_of_turn>user\n%@<end_of_turn>\n<start_of_turn>model"
+    let sideSize: CGFloat = 896
+    let sequenceLength = 768
+    let bundle = Bundle(for: type(of: self))
+    guard let modelPath = bundle.path(forResource: "gemma3", ofType: "pte"),
+          let tokenizerPath = bundle.path(forResource: "gemma3_tokenizer", ofType: "model"),
+          let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
+          let uiImage = UIImage(contentsOfFile: imagePath) else {
+      XCTFail("Couldn't find model or tokenizer files")
+      return
+    }
+    let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
+    var text = ""
+
+    do {
+      try runner.generate([
+        MultimodalInput(systemPrompt),
+        MultimodalInput(uiImage.asNormalizedImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
+      ], Config {
+        $0.sequenceLength = sequenceLength
+      }) { token in
+        text += token
+        if token == "<end_of_turn>" {
+          runner.stop()
+        }
+      }
+    } catch {
+      XCTFail("Failed to generate text with error \(error)")
+    }
+    XCTAssertTrue(text.lowercased().contains("waterfall"))
+
+    text = ""
+    runner.reset()
+    do {
+      try runner.generate([
+        MultimodalInput(systemPrompt),
+        MultimodalInput(uiImage.asNormalizedImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
+      ], Config {
+        $0.sequenceLength = sequenceLength
+      }) { token in
+        text += token
+        if token == "<end_of_turn>" {
+          runner.stop()
+        }
+      }
+    } catch {
+      XCTFail("Failed to generate text with error \(error)")
+    }
+    XCTAssertTrue(text.lowercased().contains("waterfall"))
+  }

   func testLLaVA() {
+    let chatTemplate = "USER: %@ ASSISTANT: "
+    let sideSize: CGFloat = 336
+    let sequenceLength = 768
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
+          let tokenizerPath = bundle.path(forResource: "llava_tokenizer", ofType: "bin"),
           let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
           let uiImage = UIImage(contentsOfFile: imagePath) else {
       XCTFail("Couldn't find model or tokenizer files")
@@ -75,8 +172,8 @@ class MultimodalRunnerTest: XCTestCase {
     do {
       try runner.generate([
        MultimodalInput(systemPrompt),
-        MultimodalInput(uiImage.asImage()),
-        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+        MultimodalInput(uiImage.asImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
       ], Config {
         $0.sequenceLength = sequenceLength
       }) { token in
@@ -92,8 +189,8 @@ class MultimodalRunnerTest: XCTestCase {
     do {
       try runner.generate([
        MultimodalInput(systemPrompt),
-        MultimodalInput(uiImage.asImage()),
-        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+        MultimodalInput(uiImage.asImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
       ], Config {
         $0.sequenceLength = sequenceLength
       }) { token in