
Commit d8a2126

Add Gemma 3 test.
Differential Revision: D84001548
Pull Request resolved: #14825
1 parent 270873f commit d8a2126

4 files changed: +233 -50 lines changed

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h

Lines changed: 16 additions & 0 deletions
@@ -44,6 +44,12 @@ __attribute__((objc_subclassing_restricted))
                     channels:(NSInteger)channels
     NS_DESIGNATED_INITIALIZER;
 
+- (instancetype)initWithFloatData:(NSData *)data
+                            width:(NSInteger)width
+                           height:(NSInteger)height
+                         channels:(NSInteger)channels
+    NS_DESIGNATED_INITIALIZER;
+
 @property(nonatomic, readonly) NSData *data;
 
 @property(nonatomic, readonly) NSInteger width;
@@ -52,6 +58,8 @@ __attribute__((objc_subclassing_restricted))
 
 @property(nonatomic, readonly) NSInteger channels;
 
+@property(nonatomic, readonly) BOOL isFloat;
+
 + (instancetype)new NS_UNAVAILABLE;
 - (instancetype)init NS_UNAVAILABLE;
 
@@ -80,6 +88,12 @@ __attribute__((objc_subclassing_restricted))
                       frames:(NSInteger)frames
     NS_DESIGNATED_INITIALIZER;
 
+- (instancetype)initWithFloatData:(NSData *)data
+                        batchSize:(NSInteger)batchSize
+                             bins:(NSInteger)bins
+                           frames:(NSInteger)frames
+    NS_DESIGNATED_INITIALIZER;
+
 @property(nonatomic, readonly) NSData *data;
 
 @property(nonatomic, readonly) NSInteger batchSize;
@@ -88,6 +102,8 @@ __attribute__((objc_subclassing_restricted))
 
 @property(nonatomic, readonly) NSInteger frames;
 
+@property(nonatomic, readonly) BOOL isFloat;
+
 + (instancetype)new NS_UNAVAILABLE;
 - (instancetype)init NS_UNAVAILABLE;
 

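For context, a minimal Swift sketch of how the new float-backed initializers surface in the test target. The Image(float:width:height:channels:) spelling matches the Swift test in this commit; the Audio(float:batchSize:bins:frames:) spelling and the bridged Audio type name are assumptions by analogy and may differ in the generated interface.

import ExecuTorchLLM
import Foundation

// Hypothetical buffers; in practice these come from image preprocessing
// or an audio feature extractor.
let pixelValues = [Float](repeating: 0, count: 3 * 896 * 896)
let imageData = pixelValues.withUnsafeBufferPointer { Data(buffer: $0) }
let floatImage = Image(float: imageData, width: 896, height: 896, channels: 3)

let melFeatures = [Float](repeating: 0, count: 1 * 128 * 100)
let audioData = melFeatures.withUnsafeBufferPointer { Data(buffer: $0) }
let floatAudio = Audio(float: audioData, batchSize: 1, bins: 128, frames: 100)  // assumed spelling
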
extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm

Lines changed: 77 additions & 7 deletions
@@ -32,6 +32,22 @@ - (instancetype)initWithData:(NSData *)data
     _width = width;
     _height = height;
     _channels = channels;
+    _isFloat = NO;
+  }
+  return self;
+}
+
+- (instancetype)initWithFloatData:(NSData *)data
+                            width:(NSInteger)width
+                           height:(NSInteger)height
+                         channels:(NSInteger)channels {
+  self = [super init];
+  if (self) {
+    _data = [data copy];
+    _width = width;
+    _height = height;
+    _channels = channels;
+    _isFloat = YES;
   }
   return self;
 }
@@ -53,6 +69,22 @@ - (instancetype)initWithData:(NSData *)data
     _batchSize = batchSize;
     _bins = bins;
     _frames = frames;
+    _isFloat = NO;
+  }
+  return self;
+}
+
+- (instancetype)initWithFloatData:(NSData *)data
+                        batchSize:(NSInteger)batchSize
+                             bins:(NSInteger)bins
+                           frames:(NSInteger)frames {
+  self = [super init];
+  if (self) {
+    _data = [data copy];
+    _batchSize = batchSize;
+    _bins = bins;
+    _frames = frames;
+    _isFloat = YES;
   }
   return self;
 }
@@ -170,20 +202,58 @@ - (BOOL)generateWithInputs:(NSArray<ExecuTorchLLMMultimodalInput *> *)inputs
     return NO;
   }
   std::vector<llm::MultimodalInput> nativeInputs;
+  nativeInputs.reserve((size_t)inputs.count);
   for (ExecuTorchLLMMultimodalInput *input in inputs) {
     switch (input.type) {
       case ExecuTorchLLMMultimodalInputTypeText:
        nativeInputs.emplace_back(llm::MultimodalInput(input.text.UTF8String));
        break;
      case ExecuTorchLLMMultimodalInputTypeImage: {
        ExecuTorchLLMImage *image = input.image;
-        std::vector<uint8_t> data((uint8_t *)image.data.bytes, (uint8_t *)image.data.bytes + image.data.length);
-        nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
-          std::move(data),
-          (int32_t)image.width,
-          (int32_t)image.height,
-          (int32_t)image.channels
-        )));
+        if (image.isFloat) {
+          const float *buffer = (const float *)image.data.bytes;
+          size_t elementCount = (size_t)image.data.length / sizeof(float);
+          std::vector<float> data(buffer, buffer + elementCount);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
+            std::move(data),
+            (int32_t)image.width,
+            (int32_t)image.height,
+            (int32_t)image.channels
+          )));
+        } else {
+          const uint8_t *buffer = (const uint8_t *)image.data.bytes;
+          std::vector<uint8_t> data(buffer, buffer + image.data.length);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
+            std::move(data),
+            (int32_t)image.width,
+            (int32_t)image.height,
+            (int32_t)image.channels
+          )));
+        }
+        break;
+      }
+      case ExecuTorchLLMMultimodalInputTypeAudio: {
+        ExecuTorchLLMAudio *audio = input.audio;
+        if (audio.isFloat) {
+          const float *buffer = (const float *)audio.data.bytes;
+          size_t elementCount = (size_t)audio.data.length / sizeof(float);
+          std::vector<float> data(buffer, buffer + elementCount);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Audio(
+            std::move(data),
+            (int32_t)audio.batchSize,
+            (int32_t)audio.bins,
+            (int32_t)audio.frames
+          )));
+        } else {
+          const uint8_t *buffer = (const uint8_t *)audio.data.bytes;
+          std::vector<uint8_t> data(buffer, buffer + audio.data.length);
+          nativeInputs.emplace_back(llm::MultimodalInput(llm::Audio(
+            std::move(data),
+            (int32_t)audio.batchSize,
+            (int32_t)audio.bins,
+            (int32_t)audio.frames
+          )));
+        }
         break;
       }
       default: {

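One practical note on the conversion above: the runner infers the float element count from the byte length alone, so the caller must pass a buffer of exactly width * height * channels (or batchSize * bins * frames) floats. A minimal caller-side sketch, using a hypothetical helper name not present in this change:

import Foundation

// Hypothetical guard: verify a float-backed image payload has the expected
// byte length before constructing the input, since an undersized or oversized
// Data would otherwise be copied as-is by the runner.
func validatedFloatImageData(_ data: Data, width: Int, height: Int, channels: Int) -> Data? {
  let expectedBytes = width * height * channels * MemoryLayout<Float>.size
  return data.count == expectedBytes ? data : nil
}
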
extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift

Lines changed: 138 additions & 41 deletions
@@ -10,60 +10,157 @@ import ExecuTorchLLM
 import XCTest
 
 extension UIImage {
-  func asImage() -> Image {
-    let targetSide = CGFloat(336)
-    let scale = max(targetSide / size.width, targetSide / size.height)
-    let scaledSize = CGSize(width: size.width * scale, height: size.height * scale)
+  func centerCropped(to sideSize: CGFloat) -> UIImage {
+    precondition(sideSize > 0)
     let format = UIGraphicsImageRendererFormat.default()
     format.scale = 1
-    let scaledImage = UIGraphicsImageRenderer(size: scaledSize, format: format).image { _ in
-      draw(in: CGRect(origin: .zero, size: scaledSize))
-    }
-    guard let scaledCGImage = scaledImage.cgImage else {
-      return Image(data: Data(), width: 336, height: 336, channels: 3)
-    }
-    let cropRect = CGRect(
-      x: ((scaledSize.width - targetSide) * 0.5).rounded(.down),
-      y: ((scaledSize.height - targetSide) * 0.5).rounded(.down),
-      width: targetSide.rounded(.down),
-      height: targetSide.rounded(.down)
-    )
-    let croppedCGImage = scaledCGImage.cropping(to: cropRect) ?? scaledCGImage
-    let imageWidth = croppedCGImage.width
-    let imageHeight = croppedCGImage.height
-    let pixelCount = imageWidth * imageHeight
-    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
-    let context = CGContext(
+    format.opaque = false
+    return UIGraphicsImageRenderer(size: CGSize(width: sideSize, height: sideSize), format: format)
+      .image { _ in
+        let scaleFactor = max(sideSize / size.width, sideSize / size.height)
+        let scaledWidth = size.width * scaleFactor
+        let scaledHeight = size.height * scaleFactor
+        let originX = (sideSize - scaledWidth) / 2
+        let originY = (sideSize - scaledHeight) / 2
+        draw(in: CGRect(x: originX, y: originY, width: scaledWidth, height: scaledHeight))
+      }
+  }
+
+  func rgbBytes() -> [UInt8]? {
+    guard let cgImage = cgImage else { return nil }
+    let pixelWidth = Int(cgImage.width)
+    let pixelHeight = Int(cgImage.height)
+    let pixelCount = pixelWidth * pixelHeight
+    let bytesPerPixel = 4
+    let bytesPerRow = pixelWidth * bytesPerPixel
+    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)
+    guard let context = CGContext(
       data: &rgbaBuffer,
-      width: imageWidth,
-      height: imageHeight,
+      width: pixelWidth,
+      height: pixelHeight,
       bitsPerComponent: 8,
-      bytesPerRow: imageWidth * 4,
+      bytesPerRow: bytesPerRow,
       space: CGColorSpaceCreateDeviceRGB(),
       bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
-    )!
-    context.draw(croppedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
-    var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
+    ) else { return nil }
+
+    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: pixelWidth, height: pixelHeight))
+
+    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
     for pixelIndex in 0..<pixelCount {
-      let sourceOffset = pixelIndex * 4
-      planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
-      planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
-      planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
+      let sourceIndex = pixelIndex * bytesPerPixel
+      rgbBytes[pixelIndex] = rgbaBuffer[sourceIndex + 0]
+      rgbBytes[pixelIndex + pixelCount] = rgbaBuffer[sourceIndex + 1]
+      rgbBytes[pixelIndex + 2 * pixelCount] = rgbaBuffer[sourceIndex + 2]
     }
-    return Image(data: Data(planarRGB), width: 336, height: 336, channels: 3)
+    return rgbBytes
+  }
+
+  func rgbBytesNormalized(mean: [Float] = [0, 0, 0], std: [Float] = [1, 1, 1]) -> [Float]? {
+    precondition(mean.count == 3 && std.count == 3)
+    precondition(std[0] != 0 && std[1] != 0 && std[2] != 0)
+    guard let rgbBytes = rgbBytes() else { return nil }
+    let pixelCount = rgbBytes.count / 3
+    var rgbBytesNormalized = [Float](repeating: 0, count: pixelCount * 3)
+    for pixelIndex in 0..<pixelCount {
+      rgbBytesNormalized[pixelIndex] =
+        (Float(rgbBytes[pixelIndex]) / 255.0 - mean[0]) / std[0]
+      rgbBytesNormalized[pixelIndex + pixelCount] =
+        (Float(rgbBytes[pixelIndex + pixelCount]) / 255.0 - mean[1]) / std[1]
+      rgbBytesNormalized[pixelIndex + 2 * pixelCount] =
+        (Float(rgbBytes[pixelIndex + 2 * pixelCount]) / 255.0 - mean[2]) / std[2]
+    }
+    return rgbBytesNormalized
+  }
+
+  func asImage(_ sideSize: CGFloat) -> Image {
+    return Image(
+      data: Data(centerCropped(to: sideSize).rgbBytes() ?? []),
+      width: Int(sideSize),
+      height: Int(sideSize),
+      channels: 3
+    )
+  }
+
+  func asNormalizedImage(
+    _ sideSize: CGFloat,
+    mean: [Float] = [0.485, 0.456, 0.406],
+    std: [Float] = [0.229, 0.224, 0.225]
+  ) -> Image {
+    return Image(
+      float: (centerCropped(to: sideSize).rgbBytesNormalized(mean: mean, std: std) ?? []).withUnsafeBufferPointer { Data(buffer: $0) },
+      width: Int(sideSize),
+      height: Int(sideSize),
+      channels: 3
+    )
   }
 }
 
 class MultimodalRunnerTest: XCTestCase {
-  let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
-  let assistantPrompt = "ASSISTANT: "
+  let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
   let userPrompt = "What's on the picture?"
-  let sequenceLength = 768
+
+  func testGemma() {
+    let chatTemplate = "<start_of_turn>user\n%@<end_of_turn>\n<start_of_turn>model"
+    let sideSize: CGFloat = 896
+    let sequenceLength = 768
+    let bundle = Bundle(for: type(of: self))
+    guard let modelPath = bundle.path(forResource: "gemma3", ofType: "pte"),
+          let tokenizerPath = bundle.path(forResource: "gemma3_tokenizer", ofType: "model"),
+          let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
+          let uiImage = UIImage(contentsOfFile: imagePath) else {
+      XCTFail("Couldn't find model or tokenizer files")
+      return
+    }
+    let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
+    var text = ""
+
+    do {
+      try runner.generate([
+        MultimodalInput(systemPrompt),
+        MultimodalInput(uiImage.asNormalizedImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
+      ], Config {
+        $0.sequenceLength = sequenceLength
+      }) { token in
+        text += token
+        if token == "<end_of_turn>" {
+          runner.stop()
+        }
+      }
+    } catch {
+      XCTFail("Failed to generate text with error \(error)")
+    }
+    XCTAssertTrue(text.lowercased().contains("waterfall"))
+
+    text = ""
+    runner.reset()
+    do {
+      try runner.generate([
+        MultimodalInput(systemPrompt),
+        MultimodalInput(uiImage.asNormalizedImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
+      ], Config {
+        $0.sequenceLength = sequenceLength
+      }) { token in
+        text += token
+        if token == "<end_of_turn>" {
+          runner.stop()
+        }
+      }
+    } catch {
+      XCTFail("Failed to generate text with error \(error)")
+    }
+    XCTAssertTrue(text.lowercased().contains("waterfall"))
+  }
 
   func testLLaVA() {
+    let chatTemplate = "USER: %@ ASSISTANT: "
+    let sideSize: CGFloat = 336
+    let sequenceLength = 768
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
+          let tokenizerPath = bundle.path(forResource: "llava_tokenizer", ofType: "bin"),
           let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
           let uiImage = UIImage(contentsOfFile: imagePath) else {
       XCTFail("Couldn't find model or tokenizer files")
@@ -75,8 +172,8 @@ class MultimodalRunnerTest: XCTestCase {
     do {
       try runner.generate([
         MultimodalInput(systemPrompt),
-        MultimodalInput(uiImage.asImage()),
-        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+        MultimodalInput(uiImage.asImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
       ], Config {
        $0.sequenceLength = sequenceLength
      }) { token in
@@ -92,8 +189,8 @@ class MultimodalRunnerTest: XCTestCase {
     do {
       try runner.generate([
         MultimodalInput(systemPrompt),
-        MultimodalInput(uiImage.asImage()),
-        MultimodalInput("\(userPrompt) \(assistantPrompt)"),
+        MultimodalInput(uiImage.asImage(sideSize)),
+        MultimodalInput(String(format: chatTemplate, userPrompt)),
       ], Config {
        $0.sequenceLength = sequenceLength
      }) { token in

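Taken together, the new helpers give two preprocessing paths: asImage(_:) produces planar uint8 RGB (used by the LLaVA test at 336x336), while asNormalizedImage(_:mean:std:) produces planar floats computed as (value / 255 - mean) / std with the defaults [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225] (used by the Gemma 3 test at 896x896). A minimal usage sketch, assuming a local image file rather than the test bundle:

import ExecuTorchLLM
import UIKit

// Hypothetical path; the tests resolve IMG_0005.jpg from the test bundle instead.
let imagePath = "/path/to/IMG_0005.jpg"
guard let uiImage = UIImage(contentsOfFile: imagePath) else { fatalError("missing test image") }

// Planar uint8 RGB, center-cropped to 336x336 (the LLaVA test's input).
let llavaInput = MultimodalInput(uiImage.asImage(336))

// Planar float RGB, center-cropped to 896x896 and normalized with the
// asNormalizedImage defaults (the Gemma 3 test's input).
let gemmaInput = MultimodalInput(uiImage.asNormalizedImage(896))
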
extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ class TextRunnerTest: XCTestCase {
   func testLLaMA() {
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "llama3_2-1B", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "model") else {
+          let tokenizerPath = bundle.path(forResource: "llama_tokenizer", ofType: "model") else {
       XCTFail("Couldn't find model or tokenizer files")
       return
     }
@@ -77,7 +77,7 @@ class TextRunnerTest: XCTestCase {
   func testPhi4() {
     let bundle = Bundle(for: type(of: self))
     guard let modelPath = bundle.path(forResource: "phi4-mini", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "json") else {
+          let tokenizerPath = bundle.path(forResource: "phi_tokenizer", ofType: "json") else {
       XCTFail("Couldn't find model or tokenizer files")
       return
     }
